Training
Training utilities for neural networks.
-
nabla.nn.utils.training.create_dataset(batch_size, input_dim, seed=None)
Create a simple random dataset for testing.
- Parameters:
batch_size (int) – Number of samples
input_dim (int) – Input dimension
seed (int | None) – Random seed for reproducibility
- Returns:
Tuple of (inputs, targets)
- Return type:
tuple[Array, Array]
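To illustrate what a seeded random dataset looks like, here is a minimal NumPy sketch. It is not nabla's implementation: the helper name `make_random_dataset`, the normal distribution, and the one-scalar-target-per-sample shape are all assumptions made for illustration.

```python
import numpy as np


def make_random_dataset(batch_size, input_dim, seed=None):
    """Hypothetical NumPy stand-in for a seeded random dataset."""
    # A seeded generator makes the draw reproducible; seed=None draws fresh entropy.
    rng = np.random.default_rng(seed)
    inputs = rng.standard_normal((batch_size, input_dim)).astype(np.float32)
    # Assumption: one scalar target per sample.
    targets = rng.standard_normal((batch_size, 1)).astype(np.float32)
    return inputs, targets


# Two calls with the same seed produce identical arrays.
x1, y1 = make_random_dataset(4, 3, seed=0)
x2, y2 = make_random_dataset(4, 3, seed=0)
```

Passing the same `seed` twice yields the same data, which is what makes a test that depends on the dataset repeatable.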
-
nabla.nn.utils.training.create_sin_dataset(batch_size=256, sin_periods=8)
Create the multi-period sine dataset used in mlp_train_jit.py (8 periods by default).
- Parameters:
batch_size (int) – Number of samples to generate
sin_periods (int) – Number of sine periods over the interval [0, 1]
- Returns:
Tuple of (x, targets) where targets are sin function values
- Return type:
tuple[Array, Array]
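The relationship between `sin_periods` and the targets can be sketched in NumPy as below. This is an illustration under assumptions, not the library's code: the helper name `make_sin_dataset`, the uniform sampling of `x`, the `seed` argument, and the absence of any scaling or offset on the targets are all hypothetical.

```python
import numpy as np


def make_sin_dataset(batch_size=256, sin_periods=8, seed=0):
    """Hypothetical NumPy sketch of a multi-period sine dataset."""
    rng = np.random.default_rng(seed)
    # Assumption: inputs are sampled uniformly from [0, 1).
    x = rng.uniform(0.0, 1.0, size=(batch_size, 1)).astype(np.float32)
    # sin_periods full oscillations fit into the unit interval.
    targets = np.sin(2.0 * np.pi * sin_periods * x)
    return x, targets


x, targets = make_sin_dataset(batch_size=256, sin_periods=8)
```

A higher `sin_periods` packs more oscillations into the same interval, which makes the function harder for a small network to fit.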
-
nabla.nn.utils.training.compute_accuracy(predictions, targets, threshold=0.5)
Compute classification accuracy.
- Parameters:
predictions (Array) – Model predictions
targets (Array) – True labels
threshold (float) – Classification threshold
- Returns:
Accuracy as a float between 0 and 1
- Return type:
float
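A thresholded accuracy of this kind can be sketched in NumPy as follows. The helper name `binary_accuracy`, the assumption that targets are 0/1 labels, and the strict `>` comparison against the threshold are illustrative choices; the library may handle ties or multi-class inputs differently.

```python
import numpy as np


def binary_accuracy(predictions, targets, threshold=0.5):
    """Hypothetical NumPy sketch of thresholded binary accuracy."""
    predictions = np.asarray(predictions, dtype=np.float64)
    # Assumption: targets are 0/1 labels.
    true_labels = np.asarray(targets).astype(bool)
    # A prediction counts as the positive class when it exceeds the threshold.
    predicted_labels = predictions > threshold
    # Fraction of samples where the predicted label matches the true label.
    return float(np.mean(predicted_labels == true_labels))


# Three of the four predictions land on the correct side of 0.5.
acc = binary_accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0])
```

Raising `threshold` makes the positive class harder to predict, which trades false positives for false negatives.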
-
nabla.nn.utils.training.compute_correlation(predictions, targets)
Compute Pearson correlation coefficient.
- Parameters:
predictions (Array) – Model predictions
targets (Array) – True target values
- Returns:
Correlation coefficient as a float
- Return type:
float
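For reference, the Pearson correlation coefficient is the covariance of the two series divided by the product of their standard deviations. The NumPy sketch below shows that formula; the helper name `pearson_correlation` and the flattening of inputs to 1-D are assumptions, and the library's handling of edge cases (such as constant inputs) is not covered here.

```python
import numpy as np


def pearson_correlation(predictions, targets):
    """Hypothetical NumPy sketch of the Pearson correlation coefficient."""
    p = np.asarray(predictions, dtype=np.float64).ravel()
    t = np.asarray(targets, dtype=np.float64).ravel()
    # Center both series around their means.
    p_centered = p - p.mean()
    t_centered = t - t.mean()
    # Normalize the summed cross-products by the product of the two norms.
    # The denominator is zero (result undefined) if either input is constant.
    denom = np.sqrt(np.sum(p_centered**2) * np.sum(t_centered**2))
    return float(np.sum(p_centered * t_centered) / denom)


r_pos = pearson_correlation([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])
r_neg = pearson_correlation([1.0, 2.0, 3.0, 4.0], [8.0, 6.0, 4.0, 2.0])
```

Unlike thresholded accuracy, correlation is better suited to regression targets such as the sine dataset, because it measures how closely predictions track the targets rather than whether they cross a cutoff.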