Utils

Neural network utilities and training helpers

Submodule Overview

nabla.nn.utils.create_dataset(batch_size, input_dim, seed=None)

Create a simple random dataset for testing.

Parameters:
  • batch_size (int) – Number of samples

  • input_dim (int) – Input dimension

  • seed (int | None) – Random seed for reproducibility

Returns:

Tuple of (inputs, targets)

Return type:

tuple[Array, Array]

nabla.nn.utils.create_sin_dataset(batch_size=256, sin_periods=8)

Create the high-frequency sine dataset used in mlp_train_jit.py. The number of periods is set by sin_periods (8 by default).

Parameters:
  • batch_size (int) – Number of samples to generate

  • sin_periods (int) – Number of sine periods over the interval [0, 1]

Returns:

Tuple of (x, targets), where targets are sine function values

Return type:

tuple[Array, Array]
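The sketch below shows one plausible way to build such a dataset in plain Python. It assumes x is sampled uniformly from [0, 1) and that each target is sin(2π · sin_periods · x). The name make_sin_dataset is hypothetical, and the real create_sin_dataset may scale or shift its targets differently.

```python
import math
import random


def make_sin_dataset(batch_size=256, sin_periods=8, seed=None):
    """Hypothetical pure-Python stand-in for create_sin_dataset.

    Assumption: x ~ Uniform[0, 1) and target = sin(2 * pi * sin_periods * x).
    The actual nabla function may transform the targets differently.
    """
    rng = random.Random(seed)  # local RNG so a fixed seed is reproducible
    xs = [rng.random() for _ in range(batch_size)]
    targets = [math.sin(2 * math.pi * sin_periods * x) for x in xs]
    return xs, targets


x, y = make_sin_dataset(batch_size=4, seed=0)
```

With sin_periods=8, the target completes eight full oscillations across [0, 1]. That makes it a harder regression target than a single sine wave.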

nabla.nn.utils.compute_accuracy(predictions, targets, threshold=0.5)

Compute binary classification accuracy by thresholding predictions.

Parameters:
  • predictions (Array) – Model predictions

  • targets (Array) – True labels

  • threshold (float) – Decision threshold used to convert predictions into binary labels

Returns:

Accuracy as a float between 0 and 1

Return type:

float
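Below is a minimal sketch of threshold-based binary accuracy. It assumes that predictions are scores or probabilities, that targets are 0/1 labels, and that a score strictly greater than threshold counts as positive. The nabla function may handle the boundary case differently, and binary_accuracy is a hypothetical name.

```python
def binary_accuracy(predictions, targets, threshold=0.5):
    """Hypothetical sketch: fraction of thresholded predictions equal to targets.

    Assumption: a prediction is positive when it is strictly above threshold.
    """
    if not targets or len(predictions) != len(targets):
        raise ValueError("inputs must be non-empty and equal in length")
    correct = sum(
        int((p > threshold) == bool(t)) for p, t in zip(predictions, targets)
    )
    return correct / len(targets)


print(binary_accuracy([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 0]))  # 0.75
```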

nabla.nn.utils.compute_correlation(predictions, targets)

Compute the Pearson correlation coefficient between predictions and targets.

Parameters:
  • predictions (Array) – Model predictions

  • targets (Array) – True values

Returns:

Correlation coefficient as a float

Return type:

float
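The Pearson coefficient is the covariance of the two inputs divided by the product of their standard deviations. In the sketch below, the 1/n normalization factors cancel, so sums of centered products are enough. The function name pearson is hypothetical. This is a plain-Python illustration of the formula, not nabla's implementation.

```python
import math


def pearson(predictions, targets):
    """Hypothetical sketch: r = cov(p, t) / (std(p) * std(t))."""
    n = len(predictions)
    if n == 0 or n != len(targets):
        raise ValueError("inputs must be non-empty and equal in length")
    mean_p = sum(predictions) / n
    mean_t = sum(targets) / n
    # Centered cross-products and sums of squares (1/n factors cancel in r)
    cov = sum((p - mean_p) * (t - mean_t) for p, t in zip(predictions, targets))
    ss_p = sum((p - mean_p) ** 2 for p in predictions)
    ss_t = sum((t - mean_t) ** 2 for t in targets)
    if ss_p == 0 or ss_t == 0:
        raise ValueError("correlation is undefined for constant inputs")
    return cov / math.sqrt(ss_p * ss_t)


print(pearson([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # 1.0
```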

nabla.nn.utils.accuracy(predictions, targets)

Compute multi-class classification accuracy.

Parameters:
  • predictions (Array) – Model predictions: either logits or probabilities of shape [batch_size, num_classes], or class indices of shape [batch_size]

  • targets (Array) – True labels: either one-hot of shape [batch_size, num_classes], or class indices of shape [batch_size]

Returns:

Scalar accuracy value between 0 and 1

Return type:

Array
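The dual input formats can be handled by reducing any 2-D input to class indices with an argmax and leaving 1-D input unchanged. The sketch below does this on nested Python lists. The helper names are hypothetical, and first-index tie-breaking is an assumption of this sketch.

```python
def _argmax(row):
    # Index of the largest entry; ties resolve to the first index.
    return max(range(len(row)), key=row.__getitem__)


def _to_indices(labels):
    # 2-D input (one row per sample) -> argmax of each row;
    # 1-D input -> already class indices.
    if labels and isinstance(labels[0], (list, tuple)):
        return [_argmax(row) for row in labels]
    return [int(v) for v in labels]


def multiclass_accuracy(predictions, targets):
    """Hypothetical sketch: accepts dense or sparse predictions and targets."""
    pred_idx = _to_indices(predictions)
    true_idx = _to_indices(targets)
    if not true_idx or len(pred_idx) != len(true_idx):
        raise ValueError("inputs must be non-empty and equal in length")
    return sum(p == t for p, t in zip(pred_idx, true_idx)) / len(true_idx)


logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 3.0, 0.5]]
one_hot = [[1, 0, 0], [0, 1, 0], [0, 1, 0]]
sparse = [0, 1, 1]
```

Here, multiclass_accuracy(logits, one_hot) and multiclass_accuracy(logits, sparse) both give 2/3, because the second sample's argmax (class 2) does not match its label (class 1).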

nabla.nn.utils.top_k_accuracy(predictions, targets, k=5)

Compute top-k classification accuracy.

Parameters:
  • predictions (Array) – Model predictions (logits or probabilities) [batch_size, num_classes]

  • targets (Array) – True labels [batch_size] (sparse format)

  • k (int) – Number of highest-scoring classes to consider; a sample counts as correct if its true label is among them

Returns:

Scalar top-k accuracy value between 0 and 1

Return type:

Array
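The following is a minimal sketch of top-k accuracy with sparse (index) targets. For each sample, it ranks class indices by descending score and checks whether the true label appears among the first k. The name top_k_acc is hypothetical. Tie handling here follows Python's stable sort and may differ from nabla's.

```python
def top_k_acc(predictions, targets, k=5):
    """Hypothetical sketch of top-k accuracy with integer class targets."""
    if not targets or len(predictions) != len(targets):
        raise ValueError("inputs must be non-empty and equal in length")
    hits = 0
    for row, label in zip(predictions, targets):
        # Class indices by descending score (stable sort keeps ties in index order)
        ranked = sorted(range(len(row)), key=row.__getitem__, reverse=True)
        hits += int(label in ranked[:k])
    return hits / len(targets)


scores = [[0.1, 0.5, 0.4], [0.7, 0.2, 0.1], [0.35, 0.25, 0.4]]
labels = [1, 1, 0]
```

With these inputs, k=1 gives 1/3 (only the first sample's top class is correct) and k=2 gives 1.0.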

nabla.nn.utils.precision(predictions, targets, num_classes, class_idx=0)

Compute precision for a specific class.

Precision = TP / (TP + FP), where TP and FP are the numbers of true positives and false positives for class_idx.

Parameters:
  • predictions (Array) – Model predictions (logits) [batch_size, num_classes]

  • targets (Array) – True labels (sparse) [batch_size]

  • num_classes (int) – Total number of classes

  • class_idx (int) – Class index to compute precision for

Returns:

Scalar precision value for the specified class

Return type:

Array

nabla.nn.utils.recall(predictions, targets, num_classes, class_idx=0)

Compute recall for a specific class.

Recall = TP / (TP + FN), where TP and FN are the numbers of true positives and false negatives for class_idx.

Parameters:
  • predictions (Array) – Model predictions (logits) [batch_size, num_classes]

  • targets (Array) – True labels (sparse) [batch_size]

  • num_classes (int) – Total number of classes

  • class_idx (int) – Class index to compute recall for

Returns:

Scalar recall value for the specified class

Return type:

Array

nabla.nn.utils.f1_score(predictions, targets, num_classes, class_idx=0)

Compute F1 score for a specific class.

F1 = 2 * (precision * recall) / (precision + recall)

Parameters:
  • predictions (Array) – Model predictions (logits) [batch_size, num_classes]

  • targets (Array) – True labels (sparse) [batch_size]

  • num_classes (int) – Total number of classes

  • class_idx (int) – Class index to compute F1 score for

Returns:

Scalar F1 score for the specified class

Return type:

Array
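Precision, recall, and F1 for one class can all be derived from the same TP/FP/FN counts. The sketch below assumes that predicted classes come from an argmax over the logits. It also returns 0.0 whenever a denominator is zero; that convention is an assumption of this sketch, not documented nabla behaviour. The name class_metrics is hypothetical, and it omits num_classes because the argmax does not need it.

```python
def _argmax(row):
    # Predicted class = index of the largest logit (first index on ties)
    return max(range(len(row)), key=row.__getitem__)


def class_metrics(predictions, targets, class_idx=0):
    """Hypothetical sketch: (precision, recall, f1) for a single class."""
    pred = [_argmax(row) for row in predictions]
    pairs = list(zip(pred, targets))
    tp = sum(1 for p, t in pairs if p == class_idx and t == class_idx)
    fp = sum(1 for p, t in pairs if p == class_idx and t != class_idx)
    fn = sum(1 for p, t in pairs if p != class_idx and t == class_idx)
    # Assumed convention: an undefined ratio (0 / 0) is reported as 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


logits = [[2.0, 1.0], [0.0, 3.0], [1.0, 0.0], [2.0, 0.5]]
labels = [0, 0, 1, 0]
```

For class 0, the argmax predictions [0, 1, 0, 0] give TP = 2, FP = 1, and FN = 1, so precision, recall, and F1 all equal 2/3.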

nabla.nn.utils.mean_squared_error_metric(predictions, targets)

Compute the mean squared error (MSE) for regression tasks.

Parameters:
  • predictions (Array) – Model predictions [batch_size, …]

  • targets (Array) – True targets [batch_size, …]

Returns:

Scalar MSE value

Return type:

Array

nabla.nn.utils.mean_absolute_error_metric(predictions, targets)

Compute the mean absolute error (MAE) for regression tasks.

Parameters:
  • predictions (Array) – Model predictions [batch_size, …]

  • targets (Array) – True targets [batch_size, …]

Returns:

Scalar MAE value

Return type:

Array
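MSE averages the squared differences between predictions and targets, while MAE averages the absolute differences. The sketch below works on flat lists. It assumes the nabla metrics average over every element of the [batch_size, …] arrays. The names mse and mae are hypothetical.

```python
def mse(predictions, targets):
    """Hypothetical sketch: mean of squared differences."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def mae(predictions, targets):
    """Hypothetical sketch: mean of absolute differences."""
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


preds = [2.5, 0.0, 2.0, 8.0]
truth = [3.0, -0.5, 2.0, 7.0]
print(mse(preds, truth))  # 0.375
print(mae(preds, truth))  # 0.5
```

MSE weights large errors more heavily than MAE. Here the single error of 1.0 contributes two thirds of the MSE but only half of the MAE.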

nabla.nn.utils.r_squared(predictions, targets)

Compute R-squared (coefficient of determination) for regression tasks.

R² = 1 - (SS_res / SS_tot) where SS_res = Σ(y_true - y_pred)² and SS_tot = Σ(y_true - y_mean)²

Parameters:
  • predictions (Array) – Model predictions [batch_size, …]

  • targets (Array) – True targets [batch_size, …]

Returns:

Scalar R² value

Return type:

Array
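The formula above translates directly into code. SS_res sums the squared residuals, and SS_tot sums the squared deviations of the targets from their mean. The sketch uses flat lists. The name r_squared_sketch is hypothetical, and raising an error for constant targets (SS_tot = 0) is a choice made here, not documented behaviour.

```python
def r_squared_sketch(predictions, targets):
    """Hypothetical sketch: R^2 = 1 - SS_res / SS_tot."""
    n = len(targets)
    if n == 0 or len(predictions) != n:
        raise ValueError("inputs must be non-empty and equal in length")
    mean_t = sum(targets) / n
    ss_res = sum((t - p) ** 2 for p, t in zip(predictions, targets))
    ss_tot = sum((t - mean_t) ** 2 for t in targets)
    if ss_tot == 0:
        raise ValueError("R^2 is undefined for constant targets")
    return 1.0 - ss_res / ss_tot


preds = [2.5, 0.0, 2.0, 8.0]
truth = [3.0, -0.5, 2.0, 7.0]
```

For these values, SS_res = 1.5 and SS_tot = 29.1875, so R² ≈ 0.9486. A perfect prediction gives exactly 1.0.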

nabla.nn.utils.pearson_correlation(predictions, targets)

Compute the Pearson correlation coefficient between predictions and targets.

Parameters:
  • predictions (Array) – Model predictions [batch_size, …]

  • targets (Array) – True targets [batch_size, …]

Returns:

Scalar correlation coefficient

Return type:

Array

nabla.nn.utils.l1_regularization(params, weight=0.01)

Compute L1 (Lasso) regularization loss.

L1 regularization adds a penalty proportional to the sum of the absolute values of the parameters, scaled by weight. This encourages sparsity in the model parameters.

Parameters:
  • params (list[Array]) – List of parameter arrays (typically weights)

  • weight (float) – Regularization strength

Returns:

Scalar L1 regularization loss

Return type:

Array

nabla.nn.utils.l2_regularization(params, weight=0.01)

Compute L2 (Ridge) regularization loss.

L2 regularization adds a penalty proportional to the sum of the squared parameter values, scaled by weight. This encourages small parameter values and helps prevent overfitting.

Parameters:
  • params (list[Array]) – List of parameter arrays (typically weights)

  • weight (float) – Regularization strength

Returns:

Scalar L2 regularization loss

Return type:

Array

nabla.nn.utils.elastic_net_regularization(params, l1_weight=0.01, l2_weight=0.01, l1_ratio=0.5)

Compute Elastic Net regularization loss.

Elastic Net combines L1 and L2 regularization: ElasticNet = l1_ratio * L1 + (1 - l1_ratio) * L2, where L1 and L2 are the L1 and L2 penalties with strengths l1_weight and l2_weight, respectively.

Parameters:
  • params (list[Array]) – List of parameter arrays (typically weights)

  • l1_weight (float) – L1 regularization strength

  • l2_weight (float) – L2 regularization strength

  • l1_ratio (float) – Fraction of the combined penalty assigned to the L1 term (0 = pure L2, 1 = pure L1)

Returns:

Scalar Elastic Net regularization loss

Return type:

Array
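The three penalties can be sketched on parameters given as flat Python lists. This sketch assumes that weight multiplies a plain sum (no 0.5 factor on the L2 term) and that elastic net mixes the two weighted penalties by l1_ratio. Both points are assumptions about nabla's exact scaling, and all three function names are hypothetical.

```python
def l1_penalty(params, weight=0.01):
    """Hypothetical sketch: weight * sum of |w| over all parameter values."""
    return weight * sum(abs(v) for p in params for v in p)


def l2_penalty(params, weight=0.01):
    """Hypothetical sketch: weight * sum of w**2 (assumed: no 0.5 factor)."""
    return weight * sum(v * v for p in params for v in p)


def elastic_net_penalty(params, l1_weight=0.01, l2_weight=0.01, l1_ratio=0.5):
    """Hypothetical sketch: l1_ratio * L1 + (1 - l1_ratio) * L2."""
    return (l1_ratio * l1_penalty(params, l1_weight)
            + (1 - l1_ratio) * l2_penalty(params, l2_weight))


params = [[1.0, -2.0], [3.0]]  # two "parameter arrays", flattened
```

For these parameters, the absolute values sum to 6 and the squares sum to 14. With weight 0.1, the L1 penalty is 0.6, the L2 penalty is 1.4, and an even elastic-net mix is 1.0.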

nabla.nn.utils.dropout(x, p=0.5, training=True, seed=None)

Apply dropout regularization.

During training, each element is independently set to zero with probability p. During inference, no elements are dropped; instead, every element is scaled by (1 - p) so that expected values match those seen during training.

Parameters:
  • x (Array) – Input array

  • p (float) – Dropout probability (fraction of elements to set to zero)

  • training (bool) – Whether in training mode (apply dropout) or inference mode

  • seed (int | None) – Random seed for reproducibility

Returns:

Array with dropout applied

Return type:

Array
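The behaviour described above is the classic, non-inverted form of dropout. The sketch below implements that convention on a flat list. The name dropout_sketch is hypothetical. Many frameworks use inverted dropout instead: they scale the kept elements by 1 / (1 - p) during training and leave inference untouched. The two conventions agree in expectation.

```python
import random


def dropout_sketch(x, p=0.5, training=True, seed=None):
    """Hypothetical sketch of non-inverted dropout on a flat list.

    Training: each element is zeroed independently with probability p.
    Inference: every element is multiplied by (1 - p).
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if not training:
        return [v * (1.0 - p) for v in x]
    rng = random.Random(seed)
    # rng.random() is uniform on [0, 1), so "< p" holds with probability p
    return [0.0 if rng.random() < p else v for v in x]


x = [1.0, 2.0, 3.0, 4.0]
```

In inference mode with p=0.5, dropout_sketch(x, training=False) returns [0.5, 1.0, 1.5, 2.0]. In training mode, each output element is either 0.0 or the original value.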

nabla.nn.utils.spectral_normalization(weight, u=None, n_iterations=1)

Apply spectral normalization to weight matrix.

Spectral normalization divides a weight matrix by an estimate of its spectral norm (largest singular value), so the spectral norm of the result is kept close to 1. It is commonly used to stabilize GAN training.

Parameters:
  • weight (Array) – Weight matrix to normalize [out_features, in_features]

  • u (Array | None) – Left singular vector estimate (updated during training)

  • n_iterations (int) – Number of power iterations used to approximate the largest singular value

Returns:

Tuple of (normalized_weight, updated_u)

Return type:

tuple[Array, Array]
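Power iteration estimates the largest singular value σ from the left vector u: alternately compute v = normalize(Wᵀu) and u = normalize(Wv), then take σ ≈ uᵀWv and return W / σ together with the updated u. Below is a pure-Python sketch of this technique on a nested-list matrix [out_features][in_features]. The function name, the random Gaussian initialization when u is None, and the 1e-12 epsilon are all assumptions of this sketch.

```python
import math
import random


def _matvec(w, v):
    # W v: one dot product per row (result has out_features entries)
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def _rmatvec(w, u):
    # W^T u: one dot product per column (result has in_features entries)
    return [sum(c * ui for c, ui in zip(col, u)) for col in zip(*w)]


def _normalize(v, eps=1e-12):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / (norm + eps) for x in v]


def spectral_normalize(weight, u=None, n_iterations=1, seed=None):
    """Hypothetical sketch: returns (weight / sigma_estimate, updated_u)."""
    if n_iterations < 1:
        raise ValueError("n_iterations must be at least 1")
    if u is None:
        rng = random.Random(seed)  # assumed initialization: Gaussian u
        u = [rng.gauss(0.0, 1.0) for _ in range(len(weight))]
    u = _normalize(u)
    for _ in range(n_iterations):
        v = _normalize(_rmatvec(weight, u))  # right singular vector estimate
        u = _normalize(_matvec(weight, v))   # left singular vector estimate
    sigma = sum(ui * wvi for ui, wvi in zip(u, _matvec(weight, v)))
    normalized = [[wij / sigma for wij in row] for row in weight]
    return normalized, u
```

For the rank-1 matrix [[1, 2], [2, 4]], σ = 5. Starting from u = [1, 0], a single iteration recovers it exactly, so the normalized matrix is [[0.2, 0.4], [0.4, 0.8]]. Reusing the returned u on the next call is what lets one iteration per training step stay accurate as the weights change slowly.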

nabla.nn.utils.gradient_clipping(gradients, max_norm=1.0, norm_type='l2')

Apply gradient clipping to prevent exploding gradients.

Parameters:
  • gradients (list[Array]) – List of gradient arrays

  • max_norm (float) – Maximum allowed gradient norm

  • norm_type (str) – Type of norm to use ('l2' or 'l1')

Returns:

Tuple of (clipped_gradients, total_norm)

Return type:

tuple[list[Array], Array]
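Norm-based clipping computes one global norm over all gradients. If that norm exceeds max_norm, every gradient is rescaled by the same factor, which preserves the overall direction. The sketch below follows the formula used by PyTorch's torch.nn.utils.clip_grad_norm_, scale = min(1, max_norm / (total_norm + 1e-6)), and returns the pre-clipping norm. It assumes nabla does the same, but that is unconfirmed. The name clip_gradients is hypothetical.

```python
import math


def clip_gradients(gradients, max_norm=1.0, norm_type="l2"):
    """Hypothetical sketch: global-norm clipping over flat gradient lists."""
    flat = [g for grad in gradients for g in grad]
    if norm_type == "l2":
        total_norm = math.sqrt(sum(g * g for g in flat))
    elif norm_type == "l1":
        total_norm = sum(abs(g) for g in flat)
    else:
        raise ValueError("norm_type must be 'l2' or 'l1'")
    # Scale factor <= 1; the small epsilon avoids division by zero
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    clipped = [[g * scale for g in grad] for grad in gradients]
    return clipped, total_norm  # assumed: norm measured before clipping


clipped, norm = clip_gradients([[3.0], [4.0]], max_norm=1.0)
```

Here the global L2 norm is 5.0, so both gradients are scaled by about 0.2, giving approximately [[0.6], [0.8]]. Gradients already within max_norm are returned unchanged.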