Optim#
Optimizer#
class Optimizer(params: 'Any') -> 'None':
Base class for stateful optimizers backed by pure functional steps.
Methods#
step#
def step(self, grads: 'Any' = None) -> 'Any':
SGD#
class SGD(params: 'Any', *, lr: 'float', momentum: 'float' = 0.0, weight_decay: 'float' = 0.0) -> 'None':
Stateful SGD optimizer with optional momentum and weight decay.
Usage::
optimizer = SGD(params, lr=0.01, momentum=0.9)
new_params = optimizer.step(grads)
Methods#
step#
def step(self, grads: 'Any' = None) -> 'Any':
AdamW#
class AdamW(params: 'Any', *, lr: 'float', betas: 'tuple[float, float]' = (0.9, 0.999), eps: 'float' = 1e-08, weight_decay: 'float' = 0.0) -> 'None':
Stateful AdamW optimizer with decoupled weight decay.
Implements the Adam algorithm with decoupled weight decay regularisation from Loshchilov & Hutter (2019).
Parameters
params– Model parameters (a tensor or pytree of tensors).lr– Learning rate.betas– Coefficients for computing running averages of gradient and its square. Default:(0.9, 0.999).eps– Small constant for numerical stability. Default:1e-8.weight_decay– Decoupled weight decay coefficient. Default:0.0.
Methods#
step#
def step(self, grads: 'Any' = None) -> 'Any':
sgd_step#
def sgd_step(param: 'Tensor', grad: 'Tensor', momentum_buffer: 'Tensor | None' = None, *, lr: 'float', weight_decay: 'float' = 0.0, momentum: 'float' = 0.0) -> 'tuple[Tensor, Tensor | None]':
Single-tensor SGD update.
Returns (new_param, new_momentum_buffer).
adamw_step#
def adamw_step(param: 'Tensor', grad: 'Tensor', m: 'Tensor', v: 'Tensor', step: 'int | float | Tensor', *, lr: 'float', beta1: 'float' = 0.9, beta2: 'float' = 0.999, eps: 'float' = 1e-08, weight_decay: 'float' = 0.0, bias_correction: 'bool' = True) -> 'tuple[Tensor, Tensor, Tensor]':
Single-tensor AdamW update.
Handles both scalar and tensor step (the latter is needed inside
@nb.compile where the step counter lives as a 0-D tensor).
sgd_update#
def sgd_update(params: 'Any', grads: 'Any', state: 'dict[str, Any] | None' = None, *, lr: 'float', momentum: 'float' = 0.0, weight_decay: 'float' = 0.0) -> 'tuple[Any, dict[str, Any]]':
Functional SGD update on pytrees (mirrors adamw_update).
Parameters
params:pytree– Current model parameters.grads:pytree– Gradients matching the params structure.state:dict, optional – Optimizer state containing"momentum_buffers"and"step". If None a fresh state is created.lr, momentum, weight_decay:float– Standard SGD hyper-parameters.
Returns
tuple – Updated parameters and optimizer state, with tensors realized
according to the global Optimizer execution policy.
adamw_init#
def adamw_init(params: 'Any') -> 'dict[str, Any]':
Functional AdamW state init for pytree params.
adamw_update#
def adamw_update(params: 'Any', grads: 'Any', state: 'dict[str, Any]', *, lr: 'float', weight_decay: 'float' = 0.0, beta1: 'float' = 0.9, beta2: 'float' = 0.999, eps: 'float' = 1e-08, bias_correction: 'bool' = True, realize: 'bool | None' = None) -> 'tuple[Any, dict[str, Any]]':
Functional AdamW update on pytrees.
Delegates per-leaf math to :func:adamw_step so the update logic
lives in one place. Handles both scalar and tensor step (the
latter is produced by _normalize_optimizer_state_for_compile
inside @nb.compile).