NN Module#

Neural network components and utilities

Module Overview#

Neural Network module for Nabla.

nabla.nn.mean_squared_error(predictions, targets)[source]#

Compute mean squared error loss.

Parameters:
  • predictions (Array) – Predicted values of shape (batch_size, …)

  • targets (Array) – Target values of shape (batch_size, …)

Returns:

Scalar loss value

Return type:

Array

nabla.nn.mean_absolute_error(predictions, targets)[source]#

Compute mean absolute error loss.

Parameters:
  • predictions (Array) – Predicted values of shape (batch_size, …)

  • targets (Array) – Target values of shape (batch_size, …)

Returns:

Scalar loss value

Return type:

Array

nabla.nn.huber_loss(predictions, targets, delta=1.0)[source]#

Compute Huber loss (smooth L1 loss).

Parameters:
  • predictions (Array) – Predicted values of shape (batch_size, …)

  • targets (Array) – Target values of shape (batch_size, …)

  • delta (float) – Threshold for switching between L1 and L2 loss

Returns:

Scalar loss value

Return type:

Array
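
Example (a minimal usage sketch of the three regression losses above; he_normal is documented later on this page and is used here only to fabricate two Arrays of matching shape):

    import nabla.nn as nn

    predictions = nn.he_normal((32, 1), seed=0)   # stand-in model outputs
    targets = nn.he_normal((32, 1), seed=1)       # stand-in ground truth

    mse = nn.mean_squared_error(predictions, targets)      # scalar Array
    mae = nn.mean_absolute_error(predictions, targets)     # scalar Array
    hub = nn.huber_loss(predictions, targets, delta=1.0)   # quadratic near zero, linear beyond delta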

nabla.nn.cross_entropy_loss(logits, targets, axis=-1)[source]#

Compute cross-entropy loss between logits and targets.

Parameters:
  • logits (Array) – Raw model outputs (before softmax) [batch_size, num_classes]

  • targets (Array) – One-hot encoded targets [batch_size, num_classes]

  • axis (int) – Axis along which to compute softmax

Returns:

Scalar loss value

Return type:

Array

nabla.nn.sparse_cross_entropy_loss(logits, targets, axis=-1)[source]#

Compute cross-entropy loss with integer targets.

Parameters:
  • logits (Array) – Raw model outputs [batch_size, num_classes]

  • targets (Array) – Integer class indices [batch_size]

  • axis (int) – Axis along which to compute softmax

Returns:

Scalar loss value

Return type:

Array

nabla.nn.binary_cross_entropy_loss(predictions, targets, eps=1e-07)[source]#

Compute binary cross-entropy loss.

Parameters:
  • predictions (Array) – Model predictions (after sigmoid) [batch_size]

  • targets (Array) – Binary targets (0 or 1) [batch_size]

  • eps (float) – Small constant for numerical stability

Returns:

Scalar loss value

Return type:

Array

nabla.nn.softmax_cross_entropy_loss(logits, targets, axis=-1)[source]#

Compute softmax cross-entropy loss (numerically stable).

This is equivalent to cross_entropy_loss but more numerically stable, since the softmax and cross-entropy computations are combined into a single step.

Parameters:
  • logits (Array) – Raw model outputs [batch_size, num_classes]

  • targets (Array) – One-hot encoded targets [batch_size, num_classes]

  • axis (int) – Axis along which to compute softmax

Returns:

Scalar loss value

Return type:

Array
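
Example (a minimal sketch of the cross-entropy variants above; he_normal is used only to fabricate logits, and nb.array is an assumed constructor for building target Arrays from Python lists, so adapt it to your nabla version):

    import nabla as nb
    import nabla.nn as nn

    logits = nn.he_normal((4, 3), seed=0)        # raw scores, [batch=4, num_classes=3]

    # Array construction from Python data is an assumption here.
    one_hot = nb.array([[1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0],
                        [1.0, 0.0, 0.0]])
    labels = nb.array([0, 1, 2, 0])

    dense_loss  = nn.softmax_cross_entropy_loss(logits, one_hot)   # numerically stable variant
    sparse_loss = nn.sparse_cross_entropy_loss(logits, labels)     # integer class indices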

nabla.nn.adamw_step(*args)#
nabla.nn.init_adamw_state(params)[source]#

Initialize AdamW optimizer state.

Parameters:

params (list[Array]) – List of parameter arrays

Returns:

Tuple of (m_states, v_states) - first and second moment estimates

Return type:

tuple[list[Array], list[Array]]

nabla.nn.adam_step(*args)#
nabla.nn.init_adam_state(params)[source]#

Initialize Adam optimizer states.

Parameters:

params (list[Array]) – List of parameter arrays

Returns:

Tuple of (m_states, v_states) - zero-initialized moment estimates

Return type:

tuple[list[Array], list[Array]]

nabla.nn.sgd_step(*args)#
nabla.nn.init_sgd_state(params)[source]#

Initialize SGD momentum states.

Parameters:

params (list[Array]) – List of parameter arrays

Returns:

List of zero-initialized momentum states

Return type:

list[Array]
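
Example (initializing optimizer state for a small parameter list; initialize_mlp_params is documented later on this page, and the *_step functions above are documented only as (*args), so their call signatures are not shown):

    import nabla.nn as nn

    params = nn.initialize_mlp_params([2, 16, 1], seed=0)   # [W1, b1, W2, b2]

    m_states, v_states = nn.init_adam_state(params)    # zero-initialized moments, one pair per parameter
    mw_states, vw_states = nn.init_adamw_state(params) # same structure for AdamW
    momentum = nn.init_sgd_state(params)                # one zero Array per parameter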

nabla.nn.learning_rate_schedule(epoch, initial_lr=0.001, decay_factor=0.95, decay_every=1000)[source]#

Learning rate schedule for complex function learning.

This is the original function from mlp_train_jit.py, kept for backward compatibility. Prefer exponential_decay_schedule for new code.

Parameters:
  • epoch (int) – Current epoch number

  • initial_lr (float) – Initial learning rate

  • decay_factor (float) – Factor to multiply learning rate by

  • decay_every (int) – Apply decay every N epochs

Returns:

Learning rate for the current epoch

Return type:

float

nabla.nn.exponential_decay_schedule(initial_lr=0.001, decay_factor=0.95, decay_every=1000)[source]#

Exponential decay learning rate schedule.

Parameters:
  • initial_lr (float) – Initial learning rate

  • decay_factor (float) – Factor to multiply learning rate by

  • decay_every (int) – Apply decay every N epochs

Returns:

Function that takes epoch and returns learning rate

Return type:

Callable[[int], float]

nabla.nn.cosine_annealing_schedule(initial_lr=0.001, min_lr=1e-06, period=1000)[source]#

Cosine annealing learning rate schedule.

Parameters:
  • initial_lr (float) – Initial learning rate

  • min_lr (float) – Minimum learning rate

  • period (int) – Number of epochs for one complete cosine cycle

Returns:

Function that takes epoch and returns learning rate

Return type:

Callable[[int], float]

nabla.nn.warmup_cosine_schedule(initial_lr=0.001, warmup_epochs=100, total_epochs=1000, min_lr=1e-06)[source]#

Warmup followed by cosine annealing schedule.

Parameters:
  • initial_lr (float) – Peak learning rate after warmup

  • warmup_epochs (int) – Number of epochs for linear warmup

  • total_epochs (int) – Total number of training epochs

  • min_lr (float) – Minimum learning rate

Returns:

Function that takes epoch and returns learning rate

Return type:

Callable[[int], float]
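
Example (the factory-style schedules above return a callable of the epoch index, while the legacy helper is called with the epoch directly):

    import nabla.nn as nn

    exp_lr  = nn.exponential_decay_schedule(initial_lr=1e-3, decay_factor=0.95, decay_every=1000)
    cos_lr  = nn.cosine_annealing_schedule(initial_lr=1e-3, min_lr=1e-6, period=1000)
    warm_lr = nn.warmup_cosine_schedule(initial_lr=1e-3, warmup_epochs=100, total_epochs=1000)

    for epoch in range(3):
        lr = warm_lr(epoch)        # float learning rate for this epoch
        # ... use lr in the parameter update

    # Legacy helper: pass the epoch directly instead of building a schedule function.
    lr0 = nn.learning_rate_schedule(0, initial_lr=1e-3)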

nabla.nn.initialize_mlp_params(layers, seed=42)[source]#

Initialize MLP parameters with a specialized strategy for complex functions.

This is the original initialization strategy from mlp_train_jit.py, optimized for learning high-frequency functions.

Parameters:
  • layers (list[int]) – List of layer sizes [input, hidden1, hidden2, …, output]

  • seed (int) – Random seed for reproducibility

Returns:

List of parameter arrays [W1, b1, W2, b2, …]

Return type:

list[Array]

nabla.nn.he_normal(shape, seed=None)[source]#

He normal initialization for ReLU networks.

Uses a normal distribution with std = sqrt(2/fan_in), which is well suited to ReLU activations.

Parameters:
  • shape (tuple[int, ...]) – Shape of the parameter tensor

  • seed (int | None) – Random seed for reproducibility

Returns:

Initialized parameter array

Return type:

Array

nabla.nn.xavier_normal(shape, seed=None)[source]#

Xavier/Glorot normal initialization.

Uses a normal distribution with std = sqrt(2/(fan_in + fan_out)), which is well suited to sigmoid/tanh activations.

Parameters:
  • shape (tuple[int, ...]) – Shape of the parameter tensor

  • seed (int | None) – Random seed for reproducibility

Returns:

Initialized parameter array

Return type:

Array

nabla.nn.lecun_normal(shape, seed=None)[source]#

LeCun normal initialization.

Uses a normal distribution with std = sqrt(1/fan_in), which is well suited to SELU activations.

Parameters:
  • shape (tuple[int, ...]) – Shape of the parameter tensor

  • seed (int | None) – Random seed for reproducibility

Returns:

Initialized parameter array

Return type:

Array
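
Example (drawing individual weight tensors with the three initializers above, or initializing a whole MLP in one call):

    import nabla.nn as nn

    w_relu = nn.he_normal((64, 128), seed=0)       # for ReLU layers
    w_tanh = nn.xavier_normal((128, 128), seed=1)  # for sigmoid/tanh layers
    w_selu = nn.lecun_normal((128, 10), seed=2)    # for SELU layers

    # Or initialize every layer of an MLP at once: [W1, b1, W2, b2, W3, b3]
    params = nn.initialize_mlp_params([64, 128, 128, 10], seed=42)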

nabla.nn.mlp_forward(x, params)[source]#

MLP forward pass through all layers.

This is the original MLP forward function from mlp_train_jit.py. It applies a ReLU activation to every layer except the last.

Parameters:
  • x (Array) – Input tensor of shape (batch_size, input_dim)

  • params (list[Array]) – List of parameters [W1, b1, W2, b2, …, Wn, bn]

Returns:

Output tensor of shape (batch_size, output_dim)

Return type:

Array

nabla.nn.linear_forward(x, weight, bias=None)[source]#

Forward pass through a linear layer.

Computes: output = x @ weight + bias

Parameters:
  • x (Array) – Input tensor of shape (batch_size, in_features)

  • weight (Array) – Weight tensor of shape (in_features, out_features)

  • bias (Array | None) – Optional bias tensor of shape (1, out_features) or (out_features,)

Returns:

Output tensor of shape (batch_size, out_features)

Return type:

Array

nabla.nn.mlp_forward_with_activations(x, params, activation='relu', final_activation=None)[source]#

MLP forward pass with configurable activations.

Parameters:
  • x (Array) – Input tensor of shape (batch_size, input_dim)

  • params (list[Array]) – List of parameters [W1, b1, W2, b2, …, Wn, bn]

  • activation (str) – Activation function for hidden layers (“relu”, “tanh”, “sigmoid”)

  • final_activation (str | None) – Optional activation for final layer

Returns:

Output tensor of shape (batch_size, output_dim)

Return type:

Array
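
Example (a forward pass through a small MLP; create_dataset is documented later on this page and is assumed here to return inputs of shape (batch_size, input_dim)):

    import nabla.nn as nn

    x, _ = nn.create_dataset(batch_size=8, input_dim=2, seed=0)
    params = nn.initialize_mlp_params([2, 32, 32, 1], seed=42)   # [W1, b1, ..., W3, b3]

    out = nn.mlp_forward(x, params)                              # ReLU hidden layers, linear output
    out2 = nn.mlp_forward_with_activations(x, params, activation="tanh",
                                           final_activation="sigmoid")

    # A single layer can also be applied directly:
    h = nn.linear_forward(x, params[0], params[1])               # x @ W1 + b1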

nabla.nn.relu(x)[source]#

Rectified Linear Unit activation function.

Parameters:

x (Array) – Input array

Returns:

Array with ReLU applied element-wise

Return type:

Array

nabla.nn.leaky_relu(x, negative_slope=0.01)[source]#

Leaky ReLU activation function.

Parameters:
  • x (Array) – Input array

  • negative_slope (float) – Slope for negative values

Returns:

Array with Leaky ReLU applied element-wise

Return type:

Array

nabla.nn.sigmoid(x)[source]#

Sigmoid activation function.

Parameters:

x (Array) – Input array

Returns:

Array with sigmoid applied element-wise

Return type:

Array

nabla.nn.tanh(x)[source]#

Hyperbolic tangent activation function.

Parameters:

x (Array) – Input array

Returns:

Array with tanh applied element-wise

Return type:

Array

nabla.nn.gelu(x)[source]#

Gaussian Error Linear Unit activation function.

GELU(x) = x * Φ(x), where Φ(x) is the CDF of the standard normal distribution. Approximation: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))

Parameters:

x (Array) – Input array

Returns:

Array with GELU applied element-wise

Return type:

Array

nabla.nn.silu(x)[source]#

Sigmoid Linear Unit (SiLU) activation function.

SiLU(x) = x * sigmoid(x) = Swish(x, β=1)

Parameters:

x (Array) – Input array

Returns:

Array with SiLU applied element-wise

Return type:

Array

nabla.nn.swish(x, beta=1.0)[source]#

Swish (SiLU) activation function.

Swish(x) = x * sigmoid(β * x). When β = 1, this is SiLU (Sigmoid Linear Unit).

Parameters:
  • x (Array) – Input array

  • beta (float) – Scaling factor for sigmoid

Returns:

Array with Swish applied element-wise

Return type:

Array

nabla.nn.softmax(x, axis=-1)[source]#

Softmax activation function.

Parameters:
  • x (Array) – Input array

  • axis (int) – Axis along which to compute softmax

Returns:

Array with softmax applied along specified axis

Return type:

Array

nabla.nn.log_softmax(x, axis=-1)[source]#

Log-softmax activation function.

Parameters:
  • x (Array) – Input array

  • axis (int) – Axis along which to compute log-softmax

Returns:

Array with log-softmax applied along specified axis

Return type:

Array
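
Example (applying the element-wise activations and the softmax family above; he_normal is used only to fabricate an input Array):

    import nabla.nn as nn

    x = nn.he_normal((4, 5), seed=0)

    h = nn.relu(x)                     # max(x, 0) element-wise
    g = nn.gelu(x)                     # smooth ReLU-like curve
    s = nn.swish(x, beta=1.0)          # equals silu(x) when beta is 1

    probs = nn.softmax(x, axis=-1)     # each row sums to 1 along the last axis
    logp  = nn.log_softmax(x, axis=-1) # log of the above, computed stably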

nabla.nn.create_mlp_config(layers, activation='relu', final_activation=None, init_method='he_normal', seed=42)[source]#

Create MLP configuration dictionary.

Parameters:
  • layers (list[int]) – List of layer sizes [input, hidden1, hidden2, …, output]

  • activation (str) – Activation function for hidden layers

  • final_activation (str | None) – Optional activation for final layer

  • init_method (str) – Weight initialization method

  • seed (int) – Random seed for reproducibility

Returns:

Configuration dictionary with params and forward function

Return type:

dict
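
Example (building a configuration; the exact dictionary keys are not documented on this page, so the result is only created here, not unpacked):

    import nabla.nn as nn

    config = nn.create_mlp_config([2, 64, 64, 1],
                                  activation="relu",
                                  final_activation=None,
                                  init_method="he_normal",
                                  seed=42)
    # config bundles the initialized parameters and a forward function, per the description above.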

class nabla.nn.MLPBuilder[source]#

Bases: object

Builder class for creating MLP configurations.

__init__()[source]#
with_layers(layers)[source]#

Set layer sizes.

with_activation(activation)[source]#

Set hidden layer activation function.

with_final_activation(activation)[source]#

Set final layer activation function.

with_init_method(method)[source]#

Set weight initialization method.

with_seed(seed)[source]#

Set random seed.

build()[source]#

Build the MLP configuration.
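
Example (a fluent construction sketch; it assumes each with_* method returns the builder so calls can be chained, which is conventional for builders but not stated above):

    import nabla.nn as nn

    config = (
        nn.MLPBuilder()
        .with_layers([2, 64, 64, 1])
        .with_activation("relu")
        .with_init_method("he_normal")
        .with_seed(42)
        .build()
    )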

nabla.nn.create_sin_dataset(batch_size=256, sin_periods=8)[source]#

Create the complex 8-period sin dataset from mlp_train_jit.py.

Parameters:
  • batch_size (int) – Number of samples to generate

  • sin_periods (int) – Number of sin periods in [0, 1] interval

Returns:

Tuple of (x, targets) where targets are sin function values

Return type:

tuple[Array, Array]

nabla.nn.create_dataset(batch_size, input_dim, seed=None)[source]#

Create a simple random dataset for testing.

Parameters:
  • batch_size (int) – Number of samples

  • input_dim (int) – Input dimension

  • seed (int | None) – Random seed for reproducibility

Returns:

Tuple of (inputs, targets)

Return type:

tuple[Array, Array]
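
Example (generating both datasets with the documented signatures):

    import nabla.nn as nn

    x_sin, y_sin = nn.create_sin_dataset(batch_size=256, sin_periods=8)   # sine regression targets
    x, y = nn.create_dataset(batch_size=32, input_dim=4, seed=0)          # random smoke-test data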

nabla.nn.accuracy(predictions, targets)[source]#

Compute classification accuracy.

Parameters:
  • predictions (Array) – Model predictions - either logits/probabilities [batch_size, num_classes] or class indices [batch_size]

  • targets (Array) – True labels - either one-hot [batch_size, num_classes] or indices [batch_size]

Returns:

Scalar accuracy value between 0 and 1

Return type:

Array

nabla.nn.precision(predictions, targets, num_classes, class_idx=0)[source]#

Compute precision for a specific class.

Precision = TP / (TP + FP)

Parameters:
  • predictions (Array) – Model predictions (logits) [batch_size, num_classes]

  • targets (Array) – True labels (sparse) [batch_size]

  • num_classes (int) – Total number of classes

  • class_idx (int) – Class index to compute precision for

Returns:

Scalar precision value for the specified class

Return type:

Array

nabla.nn.recall(predictions, targets, num_classes, class_idx=0)[source]#

Compute recall for a specific class.

Recall = TP / (TP + FN)

Parameters:
  • predictions (Array) – Model predictions (logits) [batch_size, num_classes]

  • targets (Array) – True labels (sparse) [batch_size]

  • num_classes (int) – Total number of classes

  • class_idx (int) – Class index to compute recall for

Returns:

Scalar recall value for the specified class

Return type:

Array

nabla.nn.f1_score(predictions, targets, num_classes, class_idx=0)[source]#

Compute F1 score for a specific class.

F1 = 2 * (precision * recall) / (precision + recall)

Parameters:
  • predictions (Array) – Model predictions (logits) [batch_size, num_classes]

  • targets (Array) – True labels (sparse) [batch_size]

  • num_classes (int) – Total number of classes

  • class_idx (int) – Class index to compute F1 score for

Returns:

Scalar F1 score for the specified class

Return type:

Array
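
Example (a minimal sketch of the classification metrics above; he_normal fabricates logits, and nb.array is an assumed constructor for building the sparse label Array):

    import nabla as nb
    import nabla.nn as nn

    logits = nn.he_normal((6, 3), seed=0)      # [batch=6, num_classes=3]
    labels = nb.array([0, 1, 2, 0, 1, 2])      # sparse targets (constructor is an assumption)

    acc = nn.accuracy(logits, labels)
    p0  = nn.precision(logits, labels, num_classes=3, class_idx=0)
    r0  = nn.recall(logits, labels, num_classes=3, class_idx=0)
    f0  = nn.f1_score(logits, labels, num_classes=3, class_idx=0)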

nabla.nn.mean_squared_error_metric(predictions, targets)[source]#

Compute MSE metric for regression tasks.

Parameters:
  • predictions (Array) – Model predictions [batch_size, …]

  • targets (Array) – True targets [batch_size, …]

Returns:

Scalar MSE value

Return type:

Array

nabla.nn.mean_absolute_error_metric(predictions, targets)[source]#

Compute MAE metric for regression tasks.

Parameters:
  • predictions (Array) – Model predictions [batch_size, …]

  • targets (Array) – True targets [batch_size, …]

Returns:

Scalar MAE value

Return type:

Array

nabla.nn.r_squared(predictions, targets)[source]#

Compute R-squared (coefficient of determination) for regression tasks.

R² = 1 - (SS_res / SS_tot) where SS_res = Σ(y_true - y_pred)² and SS_tot = Σ(y_true - y_mean)²

Parameters:
  • predictions (Array) – Model predictions [batch_size, …]

  • targets (Array) – True targets [batch_size, …]

Returns:

Scalar R² value

Return type:

Array

nabla.nn.pearson_correlation(predictions, targets)[source]#

Compute Pearson correlation coefficient.

Parameters:
  • predictions (Array) – Model predictions [batch_size, …]

  • targets (Array) – True targets [batch_size, …]

Returns:

Scalar correlation coefficient

Return type:

Array
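
Example (the regression metrics above share one signature; he_normal is used only to fabricate two Arrays of matching shape):

    import nabla.nn as nn

    preds   = nn.he_normal((32, 1), seed=0)
    targets = nn.he_normal((32, 1), seed=1)

    mse = nn.mean_squared_error_metric(preds, targets)
    mae = nn.mean_absolute_error_metric(preds, targets)
    r2  = nn.r_squared(preds, targets)            # 1.0 is a perfect fit; can go negative
    rho = nn.pearson_correlation(preds, targets)  # in [-1, 1]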

nabla.nn.dropout(x, p=0.5, training=True, seed=None)[source]#

Apply dropout regularization.

During training, randomly sets elements to zero with probability p. During inference, scales all elements by (1-p) to maintain expected values.

Parameters:
  • x (Array) – Input array

  • p (float) – Dropout probability (fraction of elements to set to zero)

  • training (bool) – Whether in training mode (apply dropout) or inference mode

  • seed (int | None) – Random seed for reproducibility

Returns:

Array with dropout applied

Return type:

Array
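
Example (the same call in training and inference mode; he_normal only fabricates an input Array):

    import nabla.nn as nn

    x = nn.he_normal((8, 16), seed=0)

    h_train = nn.dropout(x, p=0.5, training=True, seed=0)   # roughly half the elements zeroed
    h_eval  = nn.dropout(x, p=0.5, training=False)          # deterministic at inference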

nabla.nn.l1_regularization(params, weight=0.01)[source]#

Compute L1 (Lasso) regularization loss.

L1 regularization adds a penalty equal to the sum of absolute values of parameters. This encourages sparsity in the model parameters.

Parameters:
  • params (list[Array]) – List of parameter arrays (typically weights)

  • weight (float) – Regularization strength

Returns:

Scalar L1 regularization loss

Return type:

Array

nabla.nn.l2_regularization(params, weight=0.01)[source]#

Compute L2 (Ridge) regularization loss.

L2 regularization adds a penalty equal to the sum of squares of parameters. This encourages small parameter values and helps prevent overfitting.

Parameters:
  • params (list[Array]) – List of parameter arrays (typically weights)

  • weight (float) – Regularization strength

Returns:

Scalar L2 regularization loss

Return type:

Array

nabla.nn.elastic_net_regularization(params, l1_weight=0.01, l2_weight=0.01, l1_ratio=0.5)[source]#

Compute Elastic Net regularization loss.

Elastic Net combines L1 and L2 regularization: ElasticNet = l1_ratio * L1 + (1 - l1_ratio) * L2

Parameters:
  • params (list[Array]) – List of parameter arrays (typically weights)

  • l1_weight (float) – L1 regularization strength

  • l2_weight (float) – L2 regularization strength

  • l1_ratio (float) – Ratio of L1 to L2 regularization (0 = pure L2, 1 = pure L1)

Returns:

Scalar Elastic Net regularization loss

Return type:

Array
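
Example (adding a weight penalty to a training loss; the params[0::2] slice relies on the [W1, b1, W2, b2, …] layout produced by initialize_mlp_params, so only weight matrices are penalized):

    import nabla.nn as nn

    params = nn.initialize_mlp_params([2, 32, 1], seed=0)
    weights = params[0::2]   # W1, W2, ... (skip biases)

    l2_pen = nn.l2_regularization(weights, weight=1e-4)
    l1_pen = nn.l1_regularization(weights, weight=1e-5)
    en_pen = nn.elastic_net_regularization(weights, l1_weight=1e-5, l2_weight=1e-4, l1_ratio=0.5)

    # total_loss = data_loss + l2_pen   (add the chosen penalty to the data loss)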

nabla.nn.gradient_clipping(gradients, max_norm=1.0, norm_type='l2')[source]#

Apply gradient clipping to prevent exploding gradients.

Parameters:
  • gradients (list[Array]) – List of gradient arrays

  • max_norm (float) – Maximum allowed gradient norm

  • norm_type (str) – Type of norm to use (“l2” or “l1”)

Returns:

Tuple of (clipped_gradients, total_norm)

Return type:

tuple[list[Array], Array]
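
Example (clipping a gradient list before an optimizer step; he_normal merely fabricates gradient-shaped Arrays in place of real autodiff output):

    import nabla.nn as nn

    grads = [nn.he_normal((2, 32), seed=0), nn.he_normal((1, 32), seed=1)]

    clipped, total_norm = nn.gradient_clipping(grads, max_norm=1.0, norm_type="l2")
    # `clipped` has the same structure as `grads`; `total_norm` reports the gradient norm.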