Optim

Optimizers (SGD, Adam, etc.)

Submodule Overview

nabla.nn.optim.adamw_step(*args)

Apply a single AdamW update step.

nabla.nn.optim.init_adamw_state(params)

Initialize AdamW optimizer state.

Parameters:

params (list[Array]) – List of parameter arrays

Returns:

Tuple of (m_states, v_states) - first and second moment estimates

Return type:

tuple[list[Array], list[Array]]
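
For illustration, a minimal sketch of preparing AdamW state. Here params is a stand-in for your model's parameter Arrays, and the comment about adamw_step describes the usual AdamW calling pattern rather than a documented signature:

    from nabla.nn.optim import init_adamw_state

    params = [...]  # placeholder: replace with your model's list of Arrays
    m_states, v_states = init_adamw_state(params)
    # Each list holds one moment-estimate Array per parameter.

    # adamw_step(*args) is undocumented above; it typically consumes the
    # parameters, their gradients, and these moment states. Check the source
    # for the exact positional order before wiring up a training loop.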

nabla.nn.optim.adam_step(*args)

Apply a single Adam update step.

nabla.nn.optim.init_adam_state(params)

Initialize Adam optimizer state.

Parameters:

params (list[Array]) – List of parameter arrays

Returns:

Tuple of (m_states, v_states) - zero-initialized moment estimates

Return type:

tuple[list[Array], list[Array]]
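
The same pattern, sketched for Adam (params is again a placeholder for your parameter Arrays):

    from nabla.nn.optim import init_adam_state

    params = [...]  # placeholder: replace with your model's list of Arrays
    m_states, v_states = init_adam_state(params)
    # Both lists start at zero, with one entry per parameter, and are meant
    # to be threaded through adam_step during training.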

nabla.nn.optim.sgd_step(*args)

Apply a single SGD update step.

nabla.nn.optim.init_sgd_state(params)

Initialize SGD momentum states.

Parameters:

params (list[Array]) – List of parameter arrays

Returns:

List of zero-initialized momentum states

Return type:

list[Array]
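
And for momentum SGD, again with params as a placeholder:

    from nabla.nn.optim import init_sgd_state

    params = [...]  # placeholder: replace with your model's list of Arrays
    momentum_states = init_sgd_state(params)
    # One zero-initialized momentum buffer per parameter; sgd_step's exact
    # argument order is not documented here, so verify it against the source.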

nabla.nn.optim.learning_rate_schedule(epoch, initial_lr=0.001, decay_factor=0.95, decay_every=1000)

Learning rate schedule for complex function learning.

This is the original schedule from mlp_train_jit.py, kept for backward compatibility. Prefer exponential_decay_schedule for new code.

Parameters:
  • epoch (int) – Current epoch number

  • initial_lr (float) – Initial learning rate

  • decay_factor (float) – Factor to multiply learning rate by

  • decay_every (int) – Apply decay every N epochs

Returns:

Learning rate for the current epoch

Return type:

float
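
Unlike the *_schedule factories below, this function is called directly with the epoch number and returns a float. A quick usage sketch:

    from nabla.nn.optim import learning_rate_schedule

    for epoch in (0, 500, 1000, 5000):
        lr = learning_rate_schedule(epoch, initial_lr=1e-3,
                                    decay_factor=0.95, decay_every=1000)
        print(epoch, lr)  # the rate decays by a factor of 0.95 every 1000 epochs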

nabla.nn.optim.constant_schedule(initial_lr=0.001)

Constant learning rate schedule.

Parameters:

initial_lr (float) – The learning rate to maintain

Returns:

Function that takes epoch and returns learning rate

Return type:

Callable[[int], float]
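
This and the following factories return a callable rather than a float, so usage looks like the following sketch:

    from nabla.nn.optim import constant_schedule

    schedule = constant_schedule(initial_lr=1e-3)
    print(schedule(0), schedule(10_000))  # the same rate at every epoch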

nabla.nn.optim.exponential_decay_schedule(initial_lr=0.001, decay_factor=0.95, decay_every=1000)

Exponential decay learning rate schedule.

Parameters:
  • initial_lr (float) – Initial learning rate

  • decay_factor (float) – Factor to multiply learning rate by

  • decay_every (int) – Apply decay every N epochs

Returns:

Function that takes epoch and returns learning rate

Return type:

Callable[[int], float]
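
A short sketch; the exact decay rule (presumably something like initial_lr * decay_factor ** (epoch // decay_every), judging from the parameter names) should be verified against the source:

    from nabla.nn.optim import exponential_decay_schedule

    schedule = exponential_decay_schedule(initial_lr=1e-3,
                                          decay_factor=0.95, decay_every=1000)
    for epoch in (0, 1000, 10_000):
        print(epoch, schedule(epoch))  # rate shrinks as training progresses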

nabla.nn.optim.step_decay_schedule(initial_lr=0.001, decay_factor=0.1, step_size=30)

Step decay learning rate schedule.

Parameters:
  • initial_lr (float) – Initial learning rate

  • decay_factor (float) – Factor to multiply learning rate by at each step

  • step_size (int) – Number of epochs between each decay step

Returns:

Function that takes epoch and returns learning rate

Return type:

Callable[[int], float]
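
With step decay, the rate is typically held constant for step_size epochs and then multiplied by decay_factor; treat the exact boundary behaviour as implementation-defined. A sketch:

    from nabla.nn.optim import step_decay_schedule

    schedule = step_decay_schedule(initial_lr=1e-1, decay_factor=0.1, step_size=30)
    print([schedule(e) for e in (0, 29, 30, 60, 90)])
    # Expect roughly 0.1, 0.1, 0.01, 0.001, 0.0001 under the usual convention.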

nabla.nn.optim.cosine_annealing_schedule(initial_lr=0.001, min_lr=1e-06, period=1000)

Cosine annealing learning rate schedule.

Parameters:
  • initial_lr (float) – Initial learning rate

  • min_lr (float) – Minimum learning rate

  • period (int) – Number of epochs for one complete cosine cycle

Returns:

Function that takes epoch and returns learning rate

Return type:

Callable[[int], float]
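
Cosine annealing conventionally follows lr(e) = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(pi * e / period)) within each period; treat that formula as an assumption about this implementation. A sketch:

    from nabla.nn.optim import cosine_annealing_schedule

    schedule = cosine_annealing_schedule(initial_lr=1e-3, min_lr=1e-6, period=1000)
    print(schedule(0))    # starts near initial_lr
    print(schedule(500))  # roughly midway between initial_lr and min_lr
    print(schedule(999))  # approaches min_lr toward the end of a cycle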

nabla.nn.optim.warmup_cosine_schedule(initial_lr=0.001, warmup_epochs=100, total_epochs=1000, min_lr=1e-06)

Warmup followed by cosine annealing schedule.

Parameters:
  • initial_lr (float) – Peak learning rate after warmup

  • warmup_epochs (int) – Number of epochs for linear warmup

  • total_epochs (int) – Total number of training epochs

  • min_lr (float) – Minimum learning rate

Returns:

Function that takes epoch and returns learning rate

Return type:

Callable[[int], float]
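
A sketch of the expected warmup-then-anneal shape; the precise warmup and annealing formulas are assumptions based on the parameter descriptions above:

    from nabla.nn.optim import warmup_cosine_schedule

    schedule = warmup_cosine_schedule(initial_lr=1e-3, warmup_epochs=100,
                                      total_epochs=1000, min_lr=1e-6)
    print(schedule(0))    # small: the rate ramps up linearly during warmup
    print(schedule(100))  # near the peak initial_lr once warmup ends
    print(schedule(999))  # decays toward min_lr by the end of training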