# Source code for nabla.nn.optim.sgd

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""SGD optimizer implementation."""

import nabla as nb


@nb.jit
def sgd_step(
    params: list[nb.Array],
    gradients: list[nb.Array],
    momentum_states: list[nb.Array] | None = None,
    learning_rate: float = 0.01,
    momentum: float = 0.0,
    weight_decay: float = 0.0,
    dampening: float = 0.0,
    nesterov: bool = False,
) -> tuple[list[nb.Array], list[nb.Array]]:
    """Perform a single SGD optimization step.

    The step is functional: new lists are built and returned; the input
    lists are not modified by this function.

    For each ``(param, grad)`` pair the update is:

    1. ``grad = grad + weight_decay * param`` (only if ``weight_decay != 0``)
    2. If ``momentum != 0``:
       ``buf = momentum * prev_buf + (1 - dampening) * grad`` when a previous
       buffer exists for this parameter, otherwise ``buf = grad``.
       The effective gradient becomes ``grad + momentum * buf`` with
       Nesterov momentum, or ``buf`` without it.
    3. ``param = param - learning_rate * grad``

    Args:
        params: List of parameter arrays
        gradients: List of gradient arrays (same structure as params).
            Pairs are formed with ``zip(..., strict=False)``, so surplus
            entries of the longer list are ignored.
        momentum_states: List of momentum buffers, indexed like ``params``
            (None for first step). Parameters whose index has no entry are
            treated as having no previous buffer.
        learning_rate: Learning rate
        momentum: Momentum factor
        weight_decay: Weight decay (L2 penalty)
        dampening: Dampening for momentum
        nesterov: Enable Nesterov momentum

    Returns:
        Tuple of (updated_params, updated_momentum_states). Both lists have
        one entry per processed ``(param, grad)`` pair.
    """
    updated_params = []
    updated_momentum_states = []

    for i, (param, grad) in enumerate(zip(params, gradients, strict=False)):
        # Add weight decay (L2 penalty folded into the gradient)
        if weight_decay != 0:
            grad = grad + weight_decay * param

        # Whether the caller supplied a momentum buffer for THIS parameter.
        # This must depend on the supplied state, not on the parameter index:
        # the previous `i == 0` test discarded the first parameter's buffer
        # on every step.
        has_state = momentum_states is not None and i < len(momentum_states)

        if momentum != 0:
            if has_state:
                buf = momentum * momentum_states[i] + (1 - dampening) * grad
            else:
                # No previous buffer: seed it with the gradient (undampened)
                buf = grad

            # Conditional expression binds loosest:
            # (grad + momentum * buf) if nesterov else buf
            grad = grad + momentum * buf if nesterov else buf

            updated_momentum_states.append(buf)
        else:
            # Momentum disabled: pass the existing buffer through unchanged,
            # or emit a zeros_like placeholder so the returned state list
            # stays aligned with the parameters.
            updated_momentum_states.append(
                momentum_states[i] if has_state else nb.zeros_like(param)
            )

        # Update parameters
        updated_params.append(param - learning_rate * grad)

    return updated_params, updated_momentum_states


def init_sgd_state(params: list[nb.Array]) -> list[nb.Array]:
    """Initialize SGD momentum states.

    Builds one momentum buffer per parameter by calling ``nb.zeros_like``
    on each entry of ``params``, preserving order.

    Args:
        params: List of parameter arrays

    Returns:
        List of zero-initialized momentum states, one per parameter
    """
    return [nb.zeros_like(param) for param in params]
# Names exported by `from <this module> import *`.
__all__ = ["sgd_step", "init_sgd_state"]