Source code for nabla.nn.optim.adamw
# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""AdamW optimizer implementation."""
import numpy as np
import nabla as nb
@nb.jit
def adamw_step(
    params: list[nb.Array],
    gradients: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
    learning_rate: float = 0.001,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.01,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array]]:
"""JIT-compiled AdamW optimizer step with weight decay.
AdamW decouples weight decay from the gradient-based update, applying
weight decay directly to parameters rather than adding L2 regularization
to the loss function.
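
    In update form, for each parameter with gradient ``g`` at step ``t``
    (mirroring the implementation below):

        m_t   = beta1 * m_{t-1} + (1 - beta1) * g
        v_t   = beta2 * v_{t-1} + (1 - beta2) * g**2
        m_hat = m_t / (1 - beta1**t)
        v_hat = v_t / (1 - beta2**t)
        param = param * (1 - learning_rate * weight_decay)
                - learning_rate * m_hat / (sqrt(v_hat) + eps)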

    Args:
        params: List of parameter arrays.
        gradients: List of gradient arrays (same structure as params).
        m_states: List of first moment estimates.
        v_states: List of second moment estimates.
        step: Current step number, starting at 1 (used for bias correction).
        learning_rate: Learning rate.
        beta1: First moment decay rate.
        beta2: Second moment decay rate.
        eps: Small constant for numerical stability.
        weight_decay: Weight decay coefficient.

    Returns:
        Tuple of (updated_params, updated_m_states, updated_v_states).
    """
    updated_params = []
    updated_m = []
    updated_v = []

    for param, grad, m, v in zip(params, gradients, m_states, v_states, strict=False):
        # Update biased first and second moment estimates
        new_m = beta1 * m + (1.0 - beta1) * grad
        new_v = beta2 * v + (1.0 - beta2) * (grad * grad)

        # Bias correction
        m_hat = new_m / (1.0 - beta1**step)
        v_hat = new_v / (1.0 - beta2**step)

        # AdamW update: weight decay applied directly to parameters
        new_param = param * (
            1.0 - weight_decay * learning_rate
        ) - learning_rate * m_hat / (nb.sqrt(v_hat) + eps)
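        # Note: classic Adam with L2 regularization would instead fold the decay
        # into the gradient (grad + weight_decay * param) before the moment
        # updates; decoupling it here keeps the decay out of the adaptive scaling.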
        updated_params.append(new_param)
        updated_m.append(new_m)
        updated_v.append(new_v)

    return updated_params, updated_m, updated_v


def init_adamw_state(params: list[nb.Array]) -> tuple[list[nb.Array], list[nb.Array]]:
"""Initialize AdamW optimizer state.
Args:
params: List of parameter arrays
Returns:
Tuple of (m_states, v_states) - first and second moment estimates
"""
    m_states = []
    v_states = []

    for param in params:
        # Initialize first and second moments to zero
        m_np = np.zeros_like(param.to_numpy())
        v_np = np.zeros_like(param.to_numpy())
        m_states.append(nb.Array.from_numpy(m_np))
        v_states.append(nb.Array.from_numpy(v_np))

    return m_states, v_states
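

# Minimal usage sketch (illustrative only, not part of the library API).
# It assumes the two functions above, that ``nb.Array.from_numpy`` /
# ``to_numpy`` round-trip plain NumPy arrays, and that the ``@nb.jit``-wrapped
# step accepts Python scalars for ``step`` and the hyperparameters. The
# gradients here are random placeholders standing in for a real backward pass.
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Two toy parameter tensors.
    params = [
        nb.Array.from_numpy(rng.standard_normal((4, 3)).astype(np.float32)),
        nb.Array.from_numpy(np.zeros(3, dtype=np.float32)),
    ]
    m_states, v_states = init_adamw_state(params)

    for step in range(1, 4):  # bias correction expects steps starting at 1
        # Placeholder gradients; a real training loop would compute these.
        gradients = [
            nb.Array.from_numpy(
                rng.standard_normal(p.to_numpy().shape).astype(np.float32)
            )
            for p in params
        ]
        params, m_states, v_states = adamw_step(
            params,
            gradients,
            m_states,
            v_states,
            step,
            learning_rate=1e-3,
            weight_decay=0.01,
        )

    print([p.to_numpy().shape for p in params])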