Source code for nabla.nn.optim.adam

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Adam optimizer implementation."""

import nabla as nb


@nb.jit
def adam_step(
    params: list[nb.Array],
    gradients: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
    learning_rate: float = 0.001,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    amsgrad: bool = False,
    maximize: bool = False,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array]]:
    """Perform a single Adam optimization step.

    Args:
        params: List of parameter arrays
        gradients: List of gradient arrays (same structure as params)
        m_states: List of first moment estimates
        v_states: List of second moment estimates
        step: Current step number, starting at 1 (used for bias correction)
        learning_rate: Step size for the parameter update
        beta1: Exponential decay rate for first moment estimates
        beta2: Exponential decay rate for second moment estimates
        eps: Small constant for numerical stability
        weight_decay: Weight decay (L2 penalty)
        amsgrad: Whether to use the AMSGrad variant (not applied here; it
            would require tracking the running maximum of v_hat as extra state)
        maximize: Maximize the objective instead of minimizing it

    Returns:
        Tuple of (updated_params, updated_m_states, updated_v_states)
    """
    updated_params = []
    updated_m_states = []
    updated_v_states = []

    # Bias correction terms (step starts at 1; a step of 0 would make these zero)
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
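    # e.g. with beta1 = 0.9 at step 1: m_new = 0.1 * grad while
    # bias_correction1 = 0.1, so m_hat equals grad exactly, compensating for
    # the zero-initialized moment estimates.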

    for param, grad, m_state, v_state in zip(
        params, gradients, m_states, v_states, strict=False
    ):
        # Maximize by negating the gradient before applying weight decay,
        # so the L2 penalty still pulls parameters toward zero
        if maximize:
            grad = -grad

        # Add weight decay (L2 penalty)
        if weight_decay != 0:
            grad = grad + weight_decay * param

        # Update biased first moment estimate
        m_new = beta1 * m_state + (1 - beta1) * grad

        # Update biased second raw moment estimate
        v_new = beta2 * v_state + (1 - beta2) * (grad * grad)

        # Compute bias-corrected first moment estimate
        m_hat = m_new / bias_correction1

        # Compute bias-corrected second raw moment estimate
        v_hat = v_new / bias_correction2

        # Parameter update: param <- param - lr * m_hat / (sqrt(v_hat) + eps)
        denom = nb.sqrt(v_hat) + eps
        updated_param = param - learning_rate * m_hat / denom

        updated_params.append(updated_param)
        updated_m_states.append(m_new)
        updated_v_states.append(v_new)

    return updated_params, updated_m_states, updated_v_states


def init_adam_state(params: list[nb.Array]) -> tuple[list[nb.Array], list[nb.Array]]:
    """Initialize Adam optimizer states.

    Args:
        params: List of parameter arrays

    Returns:
        Tuple of (m_states, v_states) - zero-initialized moment estimates
    """
    m_states = [nb.zeros_like(param) for param in params]
    v_states = [nb.zeros_like(param) for param in params]
    return m_states, v_states


__all__ = ["adam_step", "init_adam_state"]
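
# --- Usage sketch ------------------------------------------------------------
# A minimal sketch of how init_adam_state and adam_step are meant to be
# combined in a training loop. `nb.array` and `compute_gradients` are
# illustrative assumptions here (the gradients would typically come from
# nabla's autodiff), not guarantees about the library surface.
#
#     params = [nb.array([1.0, 2.0]), nb.array([[0.5], [0.25]])]
#     m_states, v_states = init_adam_state(params)
#
#     for step in range(1, num_steps + 1):          # step is 1-indexed
#         gradients = compute_gradients(params)     # hypothetical helper
#         params, m_states, v_states = adam_step(
#             params,
#             gradients,
#             m_states,
#             v_states,
#             step,
#             learning_rate=1e-3,
#         )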