Training

Training utilities for neural networks.

nabla.nn.utils.training.create_dataset(batch_size, input_dim, seed=None)

Create a simple random dataset for testing.

Parameters:
  • batch_size (int) – Number of samples

  • input_dim (int) – Number of features per input sample

  • seed (int | None) – Random seed for reproducibility

Returns:

Tuple of (inputs, targets)

Return type:

tuple[Array, Array]
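The seed parameter exists so that repeated calls return identical data. The pure-Python sketch below illustrates only that reproducibility property. make_random_dataset is a hypothetical name, it returns plain lists rather than nabla Arrays, and both its Gaussian inputs and its sum-of-features targets are assumptions, not the documented behaviour of create_dataset.

```python
import random


def make_random_dataset(batch_size, input_dim, seed=None):
    """Hypothetical stand-in for create_dataset, using plain Python lists."""
    # A private generator: the same seed always yields the same sequence,
    # while seed=None seeds it from system entropy.
    rng = random.Random(seed)
    inputs = [[rng.gauss(0.0, 1.0) for _ in range(input_dim)] for _ in range(batch_size)]
    # ASSUMPTION: this target rule is invented purely for illustration.
    targets = [sum(row) for row in inputs]
    return inputs, targets


x1, y1 = make_random_dataset(4, 3, seed=0)
x2, y2 = make_random_dataset(4, 3, seed=0)
print(x1 == x2 and y1 == y2)  # True: same seed, same data
```

Using a dedicated random.Random instance, rather than the module-level generator, keeps the dataset reproducible even if other code draws random numbers in between calls.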

nabla.nn.utils.training.create_sin_dataset(batch_size=256, sin_periods=8)

Create the sine-wave dataset used in mlp_train_jit.py, with sin_periods full periods over the [0, 1] interval (8 by default).

Parameters:
  • batch_size (int) – Number of samples to generate

  • sin_periods (int) – Number of sine periods over the [0, 1] interval

Returns:

Tuple of (x, targets), where targets are the sine values corresponding to x

Return type:

tuple[Array, Array]
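To make a sine complete exactly sin_periods full cycles as x runs from 0 to 1, its argument can be scaled by 2π·sin_periods. The pure-Python sketch below assumes x is drawn uniformly from [0, 1) and that the targets are sin(2π·sin_periods·x) with no extra scaling or offset. make_sin_dataset and its seed argument are hypothetical; seed is not part of the documented create_sin_dataset signature.

```python
import math
import random


def make_sin_dataset(batch_size=256, sin_periods=8, seed=None):
    """Hypothetical sketch: x uniform in [0, 1), target = sin(2*pi*sin_periods*x)."""
    rng = random.Random(seed)
    # random() returns floats in the half-open interval [0.0, 1.0).
    xs = [rng.random() for _ in range(batch_size)]
    # ASSUMPTION: no amplitude scaling or offset is applied to the sine values.
    targets = [math.sin(2.0 * math.pi * sin_periods * x) for x in xs]
    return xs, targets


xs, ys = make_sin_dataset(batch_size=5, sin_periods=8, seed=42)
```

With sin_periods=8, one period spans 1/8 of the input range, so a model has to fit 16 alternating peaks and troughs over [0, 1]; raising sin_periods makes the regression task harder.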

nabla.nn.utils.training.compute_accuracy(predictions, targets, threshold=0.5)

Compute binary classification accuracy.

Parameters:
  • predictions (Array) – Model predictions

  • targets (Array) – True labels

  • threshold (float) – Decision threshold used to convert predictions into binary class labels

Returns:

Accuracy as a float between 0 and 1

Return type:

float
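Thresholded accuracy binarizes each prediction at threshold and reports the fraction that match the true labels. The pure-Python sketch below shows that computation on plain lists. thresholded_accuracy is a hypothetical name, and treating a prediction as class 1 only when it is strictly greater than threshold is an assumption; the source does not say which comparison compute_accuracy uses.

```python
def thresholded_accuracy(predictions, targets, threshold=0.5):
    """Sketch: binarize predictions at `threshold`, then compare with 0/1 targets."""
    if not targets or len(predictions) != len(targets):
        raise ValueError("predictions and targets must be non-empty and equal in length")
    # ASSUMPTION: a prediction counts as class 1 only if it is strictly above threshold.
    predicted_labels = [1 if p > threshold else 0 for p in predictions]
    correct = sum(int(label == int(t)) for label, t in zip(predicted_labels, targets))
    return correct / len(targets)


print(thresholded_accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0]))  # 0.75
```

Here the predicted labels are [1, 0, 1, 0], so three of the four match the targets. The choice between > and >= only matters for predictions that fall exactly on the threshold.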

nabla.nn.utils.training.compute_correlation(predictions, targets)

Compute Pearson correlation coefficient.

Parameters:
  • predictions (Array) – Model predictions

  • targets (Array) – True values

Returns:

Correlation coefficient as a float in [-1, 1]

Return type:

float
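The Pearson coefficient is r = Σ(p_i − p̄)(t_i − t̄) / sqrt(Σ(p_i − p̄)² · Σ(t_i − t̄)²), where p̄ and t̄ are the means of the predictions and targets. The pure-Python sketch below computes it directly. pearson_correlation is a hypothetical name, and raising an error for constant inputs (where r is undefined) is a choice made here, not documented behaviour of compute_correlation.

```python
import math


def pearson_correlation(predictions, targets):
    """Pearson r = cov(p, t) / (std(p) * std(t)), computed with plain Python."""
    n = len(predictions)
    if n != len(targets) or n < 2:
        raise ValueError("need two equal-length sequences with at least 2 elements")
    mean_p = sum(predictions) / n
    mean_t = sum(targets) / n
    # The 1/n normalisation cancels between numerator and denominator,
    # so raw sums of centred products are enough.
    cov = sum((p - mean_p) * (t - mean_t) for p, t in zip(predictions, targets))
    var_p = sum((p - mean_p) ** 2 for p in predictions)
    var_t = sum((t - mean_t) ** 2 for t in targets)
    if var_p == 0 or var_t == 0:
        raise ValueError("correlation is undefined for constant inputs")
    return cov / math.sqrt(var_p * var_t)


print(pearson_correlation([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # 1.0
```

Because r is invariant to shifting and positively scaling either input, perfectly linear targets give 1.0 regardless of slope, which makes it a useful scale-free check on regression fits such as the sine task above.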