Optimizers (nabla.nn.optim)#
Optimizer#
class Optimizer(params: 'Any') -> 'None':
Base class for stateful optimizers backed by pure functional steps.
Methods#
step#
def step(self, grads: 'Any') -> 'Any':
Apply one optimizer update for grads and return the updated parameters (as in the SGD usage example below).
SGD#
class SGD(params: 'Any', *, lr: 'float', momentum: 'float' = 0.0, weight_decay: 'float' = 0.0) -> 'None':
Stateful SGD optimizer with optional momentum and weight decay.
Usage::

    optimizer = SGD(params, lr=0.01, momentum=0.9)
    new_params = optimizer.step(grads)
Methods#
step#
def step(self, grads: 'Any') -> 'Any':
AdamW#
class AdamW(params: 'Any', *, lr: 'float', betas: 'tuple[float, float]' = (0.9, 0.999), eps: 'float' = 1e-08, weight_decay: 'float' = 0.0) -> 'None':
Stateful AdamW optimizer (Adam with decoupled weight decay).
Methods#
step#
def step(self, grads: 'Any') -> 'Any':
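A minimal usage sketch mirroring the SGD example above; it assumes params is a pytree of model parameters and grads is a matching pytree of gradients::

    optimizer = AdamW(params, lr=1e-3, weight_decay=0.01)
    new_params = optimizer.step(grads)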
sgd_step#
def sgd_step(param: 'Tensor', grad: 'Tensor', momentum_buffer: 'Tensor | None' = None, *, lr: 'float', weight_decay: 'float' = 0.0, momentum: 'float' = 0.0) -> 'tuple[Tensor, Tensor | None]':
Single-tensor SGD update.
Returns (new_param, new_momentum_buffer).
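For intuition, here is a plain-NumPy sketch of the conventional SGD-with-momentum update that sgd_step performs on one tensor. The exact formulation in nabla (in particular how weight decay and the momentum buffer interact) is not spelled out here, so treat this as an illustrative assumption rather than the library's code::

    import numpy as np

    def sgd_step_sketch(param, grad, momentum_buffer=None, *,
                        lr, weight_decay=0.0, momentum=0.0):
        # Weight decay folded into the gradient.
        g = grad + weight_decay * param
        if momentum != 0.0:
            # buf <- momentum * buf + g (fresh buffer on the first step).
            momentum_buffer = g if momentum_buffer is None else momentum * momentum_buffer + g
            g = momentum_buffer
        # param <- param - lr * g
        return param - lr * g, momentum_buffer

    p, buf = sgd_step_sketch(np.ones(3), np.full(3, 0.5), lr=0.1, momentum=0.9)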
adamw_step#
def adamw_step(param: 'Tensor', grad: 'Tensor', m: 'Tensor', v: 'Tensor', step: 'int', *, lr: 'float', beta1: 'float' = 0.9, beta2: 'float' = 0.999, eps: 'float' = 1e-08, weight_decay: 'float' = 0.0) -> 'tuple[Tensor, Tensor, Tensor]':
Single-tensor AdamW update.
Returns (new_param, new_m, new_v).
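Likewise, a NumPy sketch of the standard AdamW update (bias-corrected moments, decoupled weight decay) that adamw_step corresponds to; the parameter names follow the signature above, but the exact details are an assumption::

    import numpy as np

    def adamw_step_sketch(param, grad, m, v, step, *, lr, beta1=0.9,
                          beta2=0.999, eps=1e-8, weight_decay=0.0):
        # Exponential moving averages of the gradient and its square.
        m = beta1 * m + (1.0 - beta1) * grad
        v = beta2 * v + (1.0 - beta2) * grad * grad
        # Bias correction (step is assumed to be 1-based).
        m_hat = m / (1.0 - beta1 ** step)
        v_hat = v / (1.0 - beta2 ** step)
        # Adam direction plus decoupled weight decay.
        param = param - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param)
        return param, m, v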
sgd_update#
def sgd_update(params: 'Any', grads: 'Any', state: 'dict[str, Any] | None' = None, *, lr: 'float', momentum: 'float' = 0.0, weight_decay: 'float' = 0.0) -> 'tuple[Any, dict[str, Any]]':
Functional SGD update on pytrees (mirrors adamw_update).
Parameters
params : pytree – Current model parameters.
grads : pytree – Gradients matching the params structure.
state : dict, optional – Optimizer state containing "momentum_buffers" and "step". If None, a fresh state is created.
lr, momentum, weight_decay : float – Standard SGD hyper-parameters.
Returns
tuple – Updated parameters and optimizer state, with tensors realized according to the global Optimizer execution policy.
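A training-loop sketch using only the documented signature; params and gradient_stream are placeholders for a real parameter pytree and a source of gradient pytrees::

    state = None  # first call creates fresh momentum buffers and a step counter
    for grads in gradient_stream:
        params, state = sgd_update(params, grads, state, lr=0.01, momentum=0.9)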
adamw_init#
def adamw_init(params: 'Any') -> 'dict[str, Any]':
Functional AdamW state init for pytree params.
adamw_update#
def adamw_update(params: 'Any', grads: 'Any', state: 'dict[str, Any]', *, lr: 'float', weight_decay: 'float' = 0.0, beta1: 'float' = 0.9, beta2: 'float' = 0.999, eps: 'float' = 1e-08, bias_correction: 'bool' = True, realize: 'bool | None' = None) -> 'tuple[Any, dict[str, Any]]':
Functional AdamW update on pytrees.
Kept for compatibility and reused by fine-tuning workloads.
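The two functions pair up as init/update. A sketch, again assuming params and grads are matching pytrees::

    state = adamw_init(params)
    new_params, state = adamw_update(
        params, grads, state,
        lr=3e-4, weight_decay=0.01, beta1=0.9, beta2=0.999,
    )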