nn Module
Neural network components and utilities
Module Overview
Neural Network module for Nabla.
-
nabla.nn.mean_squared_error(predictions, targets)[source]
Compute mean squared error loss.
- Parameters:
predictions (Array) – Predicted values of shape (batch_size, …)
targets (Array) – Target values of shape (batch_size, …)
- Returns:
Scalar loss value
- Return type:
Array
-
nabla.nn.mean_absolute_error(predictions, targets)[source]
Compute mean absolute error loss.
- Parameters:
predictions (Array) – Predicted values of shape (batch_size, …)
targets (Array) – Target values of shape (batch_size, …)
- Returns:
Scalar loss value
- Return type:
Array
-
nabla.nn.huber_loss(predictions, targets, delta=1.0)[source]
Compute Huber loss (smooth L1 loss).
- Parameters:
predictions (Array) – Predicted values of shape (batch_size, …)
targets (Array) – Target values of shape (batch_size, …)
delta (float) – Threshold for switching between L1 and L2 loss
- Returns:
Scalar loss value
- Return type:
Array
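For reference, the standard smooth-L1 form the description above refers to, written per residual r = prediction − target (the returned scalar reduces this over all elements):
huber_δ(r) = 0.5 * r² if |r| ≤ δ, else δ * (|r| − 0.5 * δ)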
-
nabla.nn.cross_entropy_loss(logits, targets, axis=-1)[source]
Compute cross-entropy loss between logits and targets.
- Parameters:
logits (Array) – Raw model outputs (before softmax) [batch_size, num_classes]
targets (Array) – One-hot encoded targets [batch_size, num_classes]
axis (int) – Axis along which to compute softmax
- Returns:
Scalar loss value
- Return type:
Array
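With one-hot targets t and logits z this corresponds to the standard definition L = −(1/B) * Σ_i Σ_c t_ic * log(softmax(z_i)_c). Whether the batch reduction is a mean or a sum is not stated above; the mean form shown here is the common convention.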
-
nabla.nn.sparse_cross_entropy_loss(logits, targets, axis=-1)[source]
Compute cross-entropy loss with integer targets.
- Parameters:
logits (Array) – Raw model outputs [batch_size, num_classes]
targets (Array) – Integer class indices [batch_size]
axis (int) – Axis along which to compute softmax
- Returns:
Scalar loss value
- Return type:
Array
-
nabla.nn.binary_cross_entropy_loss(predictions, targets, eps=1e-07)[source]
Compute binary cross-entropy loss.
- Parameters:
predictions (Array) – Model predictions (after sigmoid) [batch_size]
targets (Array) – Binary targets (0 or 1) [batch_size]
eps (float) – Small constant for numerical stability
- Returns:
Scalar loss value
- Return type:
Array
-
nabla.nn.softmax_cross_entropy_loss(logits, targets, axis=-1)[source]
Compute softmax cross-entropy loss (numerically stable).
This is equivalent to cross_entropy_loss but is more numerically stable because it combines the softmax and cross-entropy computations.
- Parameters:
logits (Array) – Raw model outputs [batch_size, num_classes]
targets (Array) – One-hot encoded targets [batch_size, num_classes]
axis (int) – Axis along which to compute softmax
- Returns:
Scalar loss value
- Return type:
Array
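The extra stability typically comes from folding the softmax into the logarithm via the log-sum-exp identity, so the probabilities are never materialized:
log_softmax(z)_c = z_c − log Σ_k exp(z_k) = (z_c − m) − log Σ_k exp(z_k − m), with m = max_k z_k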
-
nabla.nn.adamw_step(*args)
-
nabla.nn.init_adamw_state(params)[source]
Initialize AdamW optimizer state.
- Parameters:
params (list[Array]) – List of parameter arrays
- Returns:
Tuple of (m_states, v_states) - first and second moment estimates
- Return type:
tuple[list[Array], list[Array]]
-
nabla.nn.adam_step(*args)
-
nabla.nn.init_adam_state(params)[source]
Initialize Adam optimizer states.
- Parameters:
params (list[Array]) – List of parameter arrays
- Returns:
Tuple of (m_states, v_states) - zero-initialized moment estimates
- Return type:
tuple[list[Array], list[Array]]
-
nabla.nn.sgd_step(*args)
-
nabla.nn.init_sgd_state(params)[source]
Initialize SGD momentum states.
- Parameters:
params (list[Array]) – List of parameter arrays
- Returns:
List of zero-initialized momentum states
- Return type:
list[Array]
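A minimal training-step sketch built on the state initializers above. The step functions are documented only as *args, so the positional call convention used below is an assumption; gradients are taken as given (e.g. produced by Nabla's autodiff transforms, which are not covered on this page).

```python
import nabla.nn as nn

def make_adam_train_step(params, lr=1e-3):
    # Zero-initialized first/second-moment estimates, one per parameter.
    m_states, v_states = nn.init_adam_state(params)
    step = 0

    def train_step(params, grads):
        nonlocal m_states, v_states, step
        step += 1
        # Assumed call convention for the undocumented adam_step(*args);
        # check its actual signature in your Nabla version.
        params, m_states, v_states = nn.adam_step(
            params, grads, m_states, v_states, step, lr
        )
        return params

    return train_step
```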
-
nabla.nn.learning_rate_schedule(epoch, initial_lr=0.001, decay_factor=0.95, decay_every=1000)[source]
Learning rate schedule for complex function learning.
This is the original function from mlp_train_jit.py for backward compatibility.
Consider using exponential_decay_schedule instead for new code.
- Parameters:
epoch (int) – Current epoch number
initial_lr (float) – Initial learning rate
decay_factor (float) – Factor to multiply learning rate by
decay_every (int) – Apply decay every N epochs
- Returns:
Learning rate for the current epoch
- Return type:
float
-
nabla.nn.exponential_decay_schedule(initial_lr=0.001, decay_factor=0.95, decay_every=1000)[source]
Exponential decay learning rate schedule.
- Parameters:
initial_lr (float) – Initial learning rate
decay_factor (float) – Factor to multiply learning rate by
decay_every (int) – Apply decay every N epochs
- Returns:
Function that takes epoch and returns learning rate
- Return type:
Callable[[int], float]
-
nabla.nn.cosine_annealing_schedule(initial_lr=0.001, min_lr=1e-06, period=1000)[source]
Cosine annealing learning rate schedule.
- Parameters:
initial_lr (float) – Initial learning rate
min_lr (float) – Minimum learning rate
period (int) – Number of epochs for one complete cosine cycle
- Returns:
Function that takes epoch and returns learning rate
- Return type:
Callable[[int], float]
-
nabla.nn.warmup_cosine_schedule(initial_lr=0.001, warmup_epochs=100, total_epochs=1000, min_lr=1e-06)[source]
Warmup followed by cosine annealing schedule.
- Parameters:
initial_lr (float) – Peak learning rate after warmup
warmup_epochs (int) – Number of epochs for linear warmup
total_epochs (int) – Total number of training epochs
min_lr (float) – Minimum learning rate
- Returns:
Function that takes epoch and returns learning rate
- Return type:
Callable[[int], float]
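The schedule factories return plain callables mapping an epoch to a learning rate, so they can be sampled inside a training loop; a short sketch:

```python
import nabla.nn as nn

# Linear warmup to the peak LR over 100 epochs, then cosine annealing to min_lr.
schedule = nn.warmup_cosine_schedule(
    initial_lr=1e-3, warmup_epochs=100, total_epochs=1000, min_lr=1e-6
)

for epoch in (0, 50, 100, 500, 1000):
    print(epoch, schedule(epoch))

# The legacy helper takes the epoch directly instead of returning a callable.
lr_at_2500 = nn.learning_rate_schedule(
    2500, initial_lr=1e-3, decay_factor=0.95, decay_every=1000
)
```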
-
nabla.nn.initialize_mlp_params(layers, seed=42)[source]
Initialize MLP parameters with a specialized strategy for complex functions.
This is the original initialization strategy from mlp_train_jit.py,
optimized for learning high-frequency functions.
- Parameters:
layers (list[int]) – List of layer sizes [input, hidden1, hidden2, …, output]
seed (int) – Random seed for reproducibility
- Returns:
List of parameter arrays [W1, b1, W2, b2, …]
- Return type:
list[Array]
-
nabla.nn.he_normal(shape, seed=None)[source]
He normal initialization for ReLU networks.
Uses a normal distribution with std = sqrt(2/fan_in), the standard variance scaling for ReLU activations.
- Parameters:
shape (tuple[int, ...]) – Shape of the parameter tensor
seed (int | None) – Random seed for reproducibility
- Returns:
Initialized parameter array
- Return type:
Array
-
nabla.nn.xavier_normal(shape, seed=None)[source]
Xavier/Glorot normal initialization.
Uses a normal distribution with std = sqrt(2/(fan_in + fan_out)), the standard variance scaling for sigmoid/tanh activations.
- Parameters:
shape (tuple[int, ...]) – Shape of the parameter tensor
seed (int | None) – Random seed for reproducibility
- Returns:
Initialized parameter array
- Return type:
Array
-
nabla.nn.lecun_normal(shape, seed=None)[source]
LeCun normal initialization.
Uses a normal distribution with std = sqrt(1/fan_in), the standard variance scaling for SELU activations.
- Parameters:
shape (tuple[int, ...]) – Shape of the parameter tensor
seed (int | None) – Random seed for reproducibility
- Returns:
Initialized parameter array
- Return type:
Array
-
nabla.nn.mlp_forward(x, params)[source]
MLP forward pass through all layers.
This is the original MLP forward function from mlp_train_jit.py.
Applies ReLU activation to all layers except the last.
- Parameters:
x (Array) – Input tensor of shape (batch_size, input_dim)
params (list[Array]) – List of parameters [W1, b1, W2, b2, …, Wn, bn]
- Returns:
Output tensor of shape (batch_size, output_dim)
- Return type:
Array
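A small end-to-end regression sketch composed from the helpers documented on this page (an untrained forward pass plus the MSE loss). The sin dataset's input shape is assumed to be (batch_size, 1).

```python
import nabla.nn as nn

# 1D sin-regression toy data: x is assumed (256, 1), targets are sin values.
x, targets = nn.create_sin_dataset(batch_size=256, sin_periods=8)

# MLP with two hidden layers; parameters are [W1, b1, W2, b2, W3, b3].
params = nn.initialize_mlp_params([1, 64, 64, 1], seed=42)

# Forward pass (ReLU on hidden layers, linear output) and scalar loss.
predictions = nn.mlp_forward(x, params)
loss = nn.mean_squared_error(predictions, targets)
```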
-
nabla.nn.linear_forward(x, weight, bias=None)[source]
Forward pass through a linear layer.
Computes: output = x @ weight + bias
- Parameters:
x (Array) – Input tensor of shape (batch_size, in_features)
weight (Array) – Weight tensor of shape (in_features, out_features)
bias (Array | None) – Optional bias tensor of shape (1, out_features) or (out_features,)
- Returns:
Output tensor of shape (batch_size, out_features)
- Return type:
Array
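For a single layer, linear_forward pairs naturally with one of the initializers above; a brief sketch (bias omitted here):

```python
import nabla.nn as nn

in_features, out_features = 128, 64

# He-normal weights suit a ReLU layer: std = sqrt(2 / fan_in).
weight = nn.he_normal((in_features, out_features), seed=0)

# Random inputs of shape (32, 128) from the testing helper below.
x, _ = nn.create_dataset(batch_size=32, input_dim=in_features, seed=0)

hidden = nn.relu(nn.linear_forward(x, weight, bias=None))  # (32, 64)
```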
-
nabla.nn.mlp_forward_with_activations(x, params, activation='relu', final_activation=None)[source]
MLP forward pass with configurable activations.
- Parameters:
x (Array) – Input tensor of shape (batch_size, input_dim)
params (list[Array]) – List of parameters [W1, b1, W2, b2, …, Wn, bn]
activation (str) – Activation function for hidden layers (“relu”, “tanh”, “sigmoid”)
final_activation (str | None) – Optional activation for final layer
- Returns:
Output tensor of shape (batch_size, output_dim)
- Return type:
Array
-
nabla.nn.relu(x)[source]
Rectified Linear Unit activation function.
- Parameters:
x (Array) – Input array
- Returns:
Array with ReLU applied element-wise
- Return type:
Array
-
nabla.nn.leaky_relu(x, negative_slope=0.01)[source]
Leaky ReLU activation function.
- Parameters:
x (Array) – Input array
negative_slope (float) – Slope applied to negative input values
- Returns:
Array with Leaky ReLU applied element-wise
- Return type:
Array
-
nabla.nn.sigmoid(x)[source]
Sigmoid activation function.
- Parameters:
x (Array) – Input array
- Returns:
Array with sigmoid applied element-wise
- Return type:
Array
-
nabla.nn.tanh(x)[source]
Hyperbolic tangent activation function.
- Parameters:
x (Array) – Input array
- Returns:
Array with tanh applied element-wise
- Return type:
Array
-
nabla.nn.gelu(x)[source]
Gaussian Error Linear Unit activation function.
GELU(x) = x * Φ(x) where Φ(x) is the CDF of standard normal distribution.
Approximation: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))
- Parameters:
x (Array) – Input array
- Returns:
Array with GELU applied element-wise
- Return type:
Array
-
nabla.nn.silu(x)[source]
Sigmoid Linear Unit (SiLU) activation function.
SiLU(x) = x * sigmoid(x) = Swish(x, β=1)
- Parameters:
x (Array) – Input array
- Returns:
Array with SiLU applied element-wise
- Return type:
Array
-
nabla.nn.swish(x, beta=1.0)[source]
Swish (SiLU) activation function.
Swish(x) = x * sigmoid(β * x)
When β = 1, this is SiLU (Sigmoid Linear Unit).
- Parameters:
x (Array) – Input array
beta (float) – Scaling factor applied inside the sigmoid
- Returns:
Array with Swish applied element-wise
- Return type:
Array
-
nabla.nn.softmax(x, axis=-1)[source]
Softmax activation function.
- Parameters:
x (Array) – Input array
axis (int) – Axis along which to compute softmax
- Returns:
Array with softmax applied along specified axis
- Return type:
Array
-
nabla.nn.log_softmax(x, axis=-1)[source]
Log-softmax activation function.
- Parameters:
x (Array) – Input array
axis (int) – Axis along which to compute log-softmax
- Returns:
Array with log-softmax applied along specified axis
- Return type:
Array
-
nabla.nn.create_mlp_config(layers, activation='relu', final_activation=None, init_method='he_normal', seed=42)[source]
Create MLP configuration dictionary.
- Parameters:
layers (list[int]) – List of layer sizes [input, hidden1, hidden2, …, output]
activation (str) – Activation function for hidden layers
final_activation (str | None) – Optional activation for final layer
init_method (str) – Weight initialization method
seed (int) – Random seed for reproducibility
- Returns:
Configuration dictionary with params and forward function
- Return type:
dict
-
class nabla.nn.MLPBuilder[source]
Bases: object
Builder class for creating MLP configurations.
-
__init__()[source]
-
with_layers(layers)[source]
Set layer sizes.
-
with_activation(activation)[source]
Set hidden layer activation function.
-
with_final_activation(activation)[source]
Set final layer activation function.
-
with_init_method(method)[source]
Set weight initialization method.
-
with_seed(seed)[source]
Set random seed.
-
build()[source]
Build the MLP configuration.
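MLPBuilder mirrors create_mlp_config. The sketch below assumes the with_* methods return the builder itself so calls can be chained (a common builder convention, not stated explicitly above) and that build() produces the same configuration dictionary as create_mlp_config.

```python
import nabla.nn as nn

config = (
    nn.MLPBuilder()
    .with_layers([1, 64, 64, 1])
    .with_activation("relu")
    .with_init_method("he_normal")
    .with_seed(42)
    .build()
)

# Equivalent one-call form:
config = nn.create_mlp_config([1, 64, 64, 1], activation="relu",
                              init_method="he_normal", seed=42)
```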
-
nabla.nn.create_sin_dataset(batch_size=256, sin_periods=8)[source]
Create the complex sin dataset from mlp_train_jit.py (8 periods by default).
- Parameters:
batch_size (int) – Number of samples to generate
sin_periods (int) – Number of sin periods in [0, 1] interval
- Returns:
Tuple of (x, targets) where targets are sin function values
- Return type:
tuple[Array, Array]
-
nabla.nn.create_dataset(batch_size, input_dim, seed=None)[source]
Create a simple random dataset for testing.
- Parameters:
batch_size (int) – Number of samples
input_dim (int) – Input dimension
seed (int | None) – Random seed for reproducibility
- Returns:
Tuple of (inputs, targets)
- Return type:
tuple[Array, Array]
-
nabla.nn.accuracy(predictions, targets)[source]
Compute classification accuracy.
- Parameters:
predictions (Array) – Model predictions - either logits/probabilities [batch_size, num_classes]
or class indices [batch_size]
targets (Array) – True labels - either one-hot [batch_size, num_classes] or indices [batch_size]
- Returns:
Scalar accuracy value between 0 and 1
- Return type:
Array
-
nabla.nn.precision(predictions, targets, num_classes, class_idx=0)[source]
Compute precision for a specific class.
Precision = TP / (TP + FP)
- Parameters:
predictions (Array) – Model predictions (logits) [batch_size, num_classes]
targets (Array) – True labels (sparse) [batch_size]
num_classes (int) – Total number of classes
class_idx (int) – Class index to compute precision for
- Returns:
Scalar precision value for the specified class
- Return type:
Array
-
nabla.nn.recall(predictions, targets, num_classes, class_idx=0)[source]
Compute recall for a specific class.
Recall = TP / (TP + FN)
- Parameters:
predictions (Array) – Model predictions (logits) [batch_size, num_classes]
targets (Array) – True labels (sparse) [batch_size]
num_classes (int) – Total number of classes
class_idx (int) – Class index to compute recall for
- Returns:
Scalar recall value for the specified class
- Return type:
Array
-
nabla.nn.f1_score(predictions, targets, num_classes, class_idx=0)[source]
Compute F1 score for a specific class.
F1 = 2 * (precision * recall) / (precision + recall)
- Parameters:
predictions (Array) – Model predictions (logits) [batch_size, num_classes]
targets (Array) – True labels (sparse) [batch_size]
num_classes (int) – Total number of classes
class_idx (int) – Class index to compute F1 score for
- Returns:
Scalar F1 score for the specified class
- Return type:
Array
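The classification metrics above share the same (logits, sparse integer labels) convention, so a single evaluation helper covers them; how the two Arrays are obtained (model output and data pipeline) is outside this sketch.

```python
import nabla.nn as nn

def evaluate_classification(logits, labels, num_classes, class_idx=0):
    """Summarize classification quality for one class.

    logits: (batch_size, num_classes) raw model outputs
    labels: (batch_size,) integer class indices
    """
    return {
        "loss": nn.sparse_cross_entropy_loss(logits, labels),
        "accuracy": nn.accuracy(logits, labels),
        "precision": nn.precision(logits, labels, num_classes, class_idx),
        "recall": nn.recall(logits, labels, num_classes, class_idx),
        "f1": nn.f1_score(logits, labels, num_classes, class_idx),
    }
```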
-
nabla.nn.mean_squared_error_metric(predictions, targets)[source]
Compute MSE metric for regression tasks.
- Parameters:
predictions (Array) – Model predictions [batch_size, …]
targets (Array) – True targets [batch_size, …]
- Returns:
Scalar MSE value
- Return type:
Array
-
nabla.nn.mean_absolute_error_metric(predictions, targets)[source]
Compute MAE metric for regression tasks.
- Parameters:
predictions (Array) – Model predictions [batch_size, …]
targets (Array) – True targets [batch_size, …]
- Returns:
Scalar MAE value
- Return type:
Array
-
nabla.nn.r_squared(predictions, targets)[source]
Compute R-squared (coefficient of determination) for regression tasks.
R² = 1 - (SS_res / SS_tot)
where SS_res = Σ(y_true - y_pred)² and SS_tot = Σ(y_true - y_mean)²
- Parameters:
predictions (Array) – Model predictions [batch_size, …]
targets (Array) – True targets [batch_size, …]
- Returns:
Scalar R² value
- Return type:
Array
-
nabla.nn.pearson_correlation(predictions, targets)[source]
Compute Pearson correlation coefficient.
- Parameters:
predictions (Array) – Model predictions [batch_size, …]
targets (Array) – True targets [batch_size, …]
- Returns:
Scalar correlation coefficient
- Return type:
Array
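For reference, the standard Pearson coefficient over predictions p and targets t (with means p̄ and t̄):
ρ = Σ(p − p̄)(t − t̄) / (sqrt(Σ(p − p̄)²) * sqrt(Σ(t − t̄)²))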
-
nabla.nn.dropout(x, p=0.5, training=True, seed=None)[source]
Apply dropout regularization.
During training, randomly sets elements to zero with probability p.
During inference, scales all elements by (1-p) to maintain expected values.
- Parameters:
x (Array) – Input array
p (float) – Dropout probability (fraction of elements to set to zero)
training (bool) – Whether in training mode (apply dropout) or inference mode
seed (int | None) – Random seed for reproducibility
- Returns:
Array with dropout applied
- Return type:
Array
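A brief usage sketch showing the training/inference split; the input is just a random Array from the create_dataset helper documented above.

```python
import nabla.nn as nn

x, _ = nn.create_dataset(batch_size=32, input_dim=128, seed=0)

# Training: roughly half of the elements are zeroed at random.
h_train = nn.dropout(x, p=0.5, training=True, seed=0)

# Inference: no elements are dropped; outputs are scaled by (1 - p)
# to keep expected values consistent with training (per the note above).
h_eval = nn.dropout(x, p=0.5, training=False)
```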
-
nabla.nn.l1_regularization(params, weight=0.01)[source]
Compute L1 (Lasso) regularization loss.
L1 regularization adds a penalty equal to the sum of absolute values of parameters.
This encourages sparsity in the model parameters.
- Parameters:
params (list[Array]) – List of parameter arrays (typically weights)
weight (float) – Regularization strength
- Returns:
Scalar L1 regularization loss
- Return type:
Array
-
nabla.nn.l2_regularization(params, weight=0.01)[source]
Compute L2 (Ridge) regularization loss.
L2 regularization adds a penalty equal to the sum of squares of parameters.
This encourages small parameter values and helps prevent overfitting.
- Parameters:
params (list[Array]) – List of parameter arrays (typically weights)
weight (float) – Regularization strength
- Returns:
Scalar L2 regularization loss
- Return type:
Array
-
nabla.nn.elastic_net_regularization(params, l1_weight=0.01, l2_weight=0.01, l1_ratio=0.5)[source]
Compute Elastic Net regularization loss.
Elastic Net combines L1 and L2 regularization:
ElasticNet = l1_ratio * L1 + (1 - l1_ratio) * L2
- Parameters:
params (list[Array]) – List of parameter arrays (typically weights)
l1_weight (float) – L1 regularization strength
l2_weight (float) – L2 regularization strength
l1_ratio (float) – Ratio of L1 to L2 regularization (0 = pure L2, 1 = pure L1)
- Returns:
Scalar Elastic Net regularization loss
- Return type:
Array
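The regularization helpers return scalar Arrays that can be added to a task loss (assuming Array supports the + operator, as array types typically do); a sketch using the MLP parameters from initialize_mlp_params, noting that weights and biases are penalized alike here:

```python
import nabla.nn as nn

params = nn.initialize_mlp_params([1, 64, 64, 1], seed=42)
x, targets = nn.create_sin_dataset(batch_size=256)

predictions = nn.mlp_forward(x, params)
data_loss = nn.mean_squared_error(predictions, targets)

# Weight-decay-style penalty; `weight` is the regularization strength.
total_loss = data_loss + nn.l2_regularization(params, weight=1e-4)
```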
-
nabla.nn.gradient_clipping(gradients, max_norm=1.0, norm_type='l2')[source]
Apply gradient clipping to prevent exploding gradients.
- Parameters:
gradients (list[Array]) – List of gradient arrays
max_norm (float) – Maximum allowed gradient norm
norm_type (str) – Type of norm to use (“l2” or “l1”)
- Returns:
Tuple of (clipped_gradients, total_norm)
- Return type:
tuple[list[Array], Array]
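gradient_clipping returns both the clipped list and the pre-clipping norm, which is convenient for logging; the gradients themselves are assumed to come from Nabla's autodiff transforms (not covered on this page).

```python
import nabla.nn as nn

def clip_and_report(grads, max_norm=1.0):
    # grads: list[Array], e.g. produced by a gradient transform over the loss.
    clipped, total_norm = nn.gradient_clipping(grads, max_norm=max_norm, norm_type="l2")
    return clipped, total_norm
```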