Optim
Optimizers (SGD, Adam, etc.)
Submodule Overview
-
nabla.nn.optim.adamw_step(*args)
Perform a single AdamW optimizer update step.
-
nabla.nn.optim.init_adamw_state(params)[source]
Initialize AdamW optimizer state.
- Parameters:
params (list[Array]) – List of parameter arrays
- Returns:
Tuple of (m_states, v_states) - first and second moment estimates
- Return type:
tuple[list[Array], list[Array]]
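A minimal sketch of wiring the AdamW state initializer into a training loop. The array-creation call (nb.randn) and the positional argument order of adamw_step are assumptions, not part of the documentation above, and should be verified against the nabla source.

```python
import nabla as nb
from nabla.nn.optim import init_adamw_state, adamw_step

# Two toy parameter arrays; nb.randn is assumed here for Array creation.
params = [nb.randn((64, 32)), nb.randn((32,))]

# Zero-initialized first and second moment estimates, one per parameter.
m_states, v_states = init_adamw_state(params)

# grads = ...  # gradients with the same shapes as params
# adamw_step takes *args, so its exact ordering is not documented above; a
# call along these (assumed) lines should be checked against the source:
# params, m_states, v_states = adamw_step(params, grads, m_states, v_states, ...)
```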
-
nabla.nn.optim.adam_step(*args)
Perform a single Adam optimizer update step.
-
nabla.nn.optim.init_adam_state(params)[source]
Initialize Adam optimizer states.
- Parameters:
params (list[Array]) – List of parameter arrays
- Returns:
Tuple of (m_states, v_states) - zero-initialized moment estimates
- Return type:
tuple[list[Array], list[Array]]
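A short sketch of Adam state initialization. As with AdamW, the moment estimates come back zero-initialized, one pair per parameter; the nb.zeros call used to build the toy parameters is an assumption about the array-creation API.

```python
import nabla as nb
from nabla.nn.optim import init_adam_state

params = [nb.zeros((10, 10)), nb.zeros((10,))]   # nb.zeros assumed for Array creation
m_states, v_states = init_adam_state(params)

# One zero-initialized moment buffer of each kind per parameter.
assert len(m_states) == len(v_states) == len(params)
```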
-
nabla.nn.optim.sgd_step(*args)
Perform a single SGD optimizer update step.
-
nabla.nn.optim.init_sgd_state(params)[source]
Initialize SGD momentum states.
- Parameters:
params (list[Array]) – List of parameter arrays
- Returns:
List of zero-initialized momentum states
- Return type:
list[Array]
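For SGD with momentum, the state is a single list of zero-initialized velocity buffers. The sketch below again assumes nb.zeros for Array creation; sgd_step's argument order is unspecified above.

```python
import nabla as nb
from nabla.nn.optim import init_sgd_state

params = [nb.zeros((128, 64)), nb.zeros((64,))]  # nb.zeros assumed for Array creation
momentum = init_sgd_state(params)                # one zero velocity buffer per parameter

# sgd_step(*args) then applies one update; check the source for the expected
# ordering of params, grads, momentum states, and learning rate.
```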
-
nabla.nn.optim.learning_rate_schedule(epoch, initial_lr=0.001, decay_factor=0.95, decay_every=1000)[source]
Learning rate schedule used in the complex-function-learning example.
This is the original function from mlp_train_jit.py, kept for backward compatibility.
Prefer exponential_decay_schedule for new code.
- Parameters:
epoch (int) – Current epoch number
initial_lr (float) – Initial learning rate
decay_factor (float) – Factor to multiply learning rate by
decay_every (int) – Apply decay every N epochs
- Returns:
Learning rate for the current epoch
- Return type:
float
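Unlike the factory-style schedules below, this legacy helper takes the epoch directly and returns a float, so it is called inside the training loop. The expected values here assume the decay is applied multiplicatively once every decay_every epochs.

```python
from nabla.nn.optim import learning_rate_schedule

# With the defaults, the rate is assumed to shrink by 0.95x every 1000 epochs.
print(learning_rate_schedule(0))     # 0.001
print(learning_rate_schedule(1500))  # expected 0.001 * 0.95 = 0.00095
```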
-
nabla.nn.optim.constant_schedule(initial_lr=0.001)[source]
Constant learning rate schedule.
- Parameters:
initial_lr (float) – The learning rate to maintain
- Returns:
Function that takes epoch and returns learning rate
- Return type:
Callable[[int], float]
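The factory-style schedules all return a Callable[[int], float], so they can be swapped interchangeably in a training loop. The constant schedule is the simplest case:

```python
from nabla.nn.optim import constant_schedule

schedule = constant_schedule(initial_lr=0.01)
assert schedule(0) == schedule(500) == 0.01   # same rate at every epoch
```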
-
nabla.nn.optim.exponential_decay_schedule(initial_lr=0.001, decay_factor=0.95, decay_every=1000)[source]
Exponential decay learning rate schedule.
- Parameters:
initial_lr (float) – Initial learning rate
decay_factor (float) – Factor to multiply learning rate by
decay_every (int) – Apply decay every N epochs
- Returns:
Function that takes epoch and returns learning rate
- Return type:
Callable[[int], float]
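A usage sketch, assuming the decay is applied multiplicatively once per decay_every epochs (i.e. lr = initial_lr * decay_factor ** (epoch // decay_every)):

```python
from nabla.nn.optim import exponential_decay_schedule

schedule = exponential_decay_schedule(initial_lr=0.001, decay_factor=0.5, decay_every=100)
for epoch in (0, 100, 200):
    print(epoch, schedule(epoch))   # expected: 0.001, 0.0005, 0.00025
```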
-
nabla.nn.optim.step_decay_schedule(initial_lr=0.001, decay_factor=0.1, step_size=30)[source]
Step decay learning rate schedule.
- Parameters:
initial_lr (float) – Initial learning rate
decay_factor (float) – Factor to multiply learning rate by at each step
step_size (int) – Number of epochs between each decay step
- Returns:
Function that takes epoch and returns learning rate
- Return type:
Callable[[int], float]
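Step decay behaves the same way but with a much sharper default drop (decay_factor=0.1 every 30 epochs), matching the classic "divide by 10 every N epochs" recipe. The expected values assume the standard formula noted in the comment:

```python
from nabla.nn.optim import step_decay_schedule

schedule = step_decay_schedule(initial_lr=0.1, decay_factor=0.1, step_size=30)
# Expected with lr = initial_lr * decay_factor ** (epoch // step_size):
# epochs 0-29 -> 0.1, 30-59 -> 0.01, 60+ -> 0.001
print([schedule(e) for e in (0, 30, 60)])
```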
-
nabla.nn.optim.cosine_annealing_schedule(initial_lr=0.001, min_lr=1e-06, period=1000)[source]
Cosine annealing learning rate schedule.
- Parameters:
initial_lr (float) – Initial learning rate
min_lr (float) – Minimum learning rate
period (int) – Number of epochs for one complete cosine cycle
- Returns:
Function that takes epoch and returns learning rate
- Return type:
Callable[[int], float]
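Cosine annealing sweeps the rate from initial_lr down toward min_lr over each period. The comments assume the standard cosine formula; the behavior at period boundaries (whether the cycle restarts) should be checked in the source.

```python
from nabla.nn.optim import cosine_annealing_schedule

schedule = cosine_annealing_schedule(initial_lr=0.001, min_lr=1e-6, period=1000)
print(schedule(0))    # ~0.001 at the start of the cycle
print(schedule(500))  # roughly midway between initial_lr and min_lr
```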
-
nabla.nn.optim.warmup_cosine_schedule(initial_lr=0.001, warmup_epochs=100, total_epochs=1000, min_lr=1e-06)[source]
Warmup followed by cosine annealing schedule.
- Parameters:
initial_lr (float) – Peak learning rate after warmup
warmup_epochs (int) – Number of epochs for linear warmup
total_epochs (int) – Total number of training epochs
min_lr (float) – Minimum learning rate
- Returns:
Function that takes epoch and returns learning rate
- Return type:
Callable[[int], float]
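Warmup plus cosine annealing is a common default for larger models: the rate ramps linearly to its peak over warmup_epochs, then decays along a cosine curve toward min_lr by total_epochs. The expected values below assume that standard warmup behavior.

```python
from nabla.nn.optim import warmup_cosine_schedule

schedule = warmup_cosine_schedule(
    initial_lr=3e-4, warmup_epochs=100, total_epochs=1000, min_lr=1e-6
)
print(schedule(50))    # mid-warmup: expected around half of the peak rate
print(schedule(100))   # end of warmup: expected ~3e-4
print(schedule(999))   # near the end of training: expected close to min_lr
```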