# Source code for nabla.nn.utils.training

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Training utilities for neural networks."""

import nabla as nb


def create_dataset(
    batch_size: int, input_dim: int, seed: int | None = None
) -> tuple[nb.Array, nb.Array]:
    """Build a small random regression dataset for tests.

    Inputs are drawn with ``nb.rand`` (``lower=-1.0``, ``upper=1.0``,
    float32) and each target is the sum of the corresponding input row,
    reduced over axis 1 with ``keepdims=True``.

    Args:
        batch_size: Number of samples (rows) to generate.
        input_dim: Number of features (columns) per sample.
        seed: Optional seed. When given, NumPy's global random generator
            is seeded before the inputs are drawn.

    Returns:
        Tuple of (inputs, targets).
    """
    if seed is not None:
        import numpy as np

        # NOTE(review): this seeds NumPy's *global* generator only; it makes
        # the data reproducible only if nb.rand draws from that generator --
        # confirm against the nb.rand implementation.
        np.random.seed(seed)

    inputs = nb.rand(
        (batch_size, input_dim), lower=-1.0, upper=1.0, dtype=nb.DType.float32
    )

    # Simple learnable target: row-wise sum of the inputs.
    return inputs, nb.sum(inputs, axis=1, keepdims=True)
def create_sin_dataset(
    batch_size: int = 256, sin_periods: int = 8
) -> tuple[nb.Array, nb.Array]:
    """Generate the multi-period sine regression dataset.

    Samples ``x`` with ``nb.rand`` (``lower=0.0``, ``upper=1.0``, float32,
    one column) and evaluates ``sin(sin_periods * 2 * pi * x) / 2 + 0.5``,
    i.e. a sine wave rescaled from [-1, 1] to [0, 1].

    Args:
        batch_size: Number of samples to generate.
        sin_periods: Number of full sine periods across the [0, 1] interval
            (8 by default).

    Returns:
        Tuple of (x, targets) where targets are the rescaled sine values.
    """
    import numpy as np

    # Angular frequency such that x in [0, 1] spans `sin_periods` periods.
    angular_freq = sin_periods * 2.0 * np.pi

    x = nb.rand((batch_size, 1), lower=0.0, upper=1.0, dtype=nb.DType.float32)
    wave = nb.sin(angular_freq * x)

    # Halve the amplitude and shift up by 0.5.
    return x, wave / 2.0 + 0.5
def compute_accuracy(
    predictions: nb.Array, targets: nb.Array, threshold: float = 0.5
) -> float:
    """Compute thresholded classification accuracy.

    Both predictions and targets are binarized with ``> threshold``; the
    accuracy is the fraction of elements whose binarized values agree.

    Args:
        predictions: Model predictions.
        targets: True labels.
        threshold: Cut-off applied to both predictions and targets.

    Returns:
        Accuracy as a float between 0 and 1.
    """
    # Element-wise agreement between the two binarized arrays.
    agree = (predictions > threshold) == (targets > threshold)

    # Map the boolean mask to 1.0 / 0.0 so it can be averaged.
    hit_rate = nb.mean(nb.where(agree, 1.0, 0.0))
    return hit_rate.to_numpy().item()
def compute_correlation(predictions: nb.Array, targets: nb.Array) -> float:
    """Compute the Pearson correlation coefficient.

    Both arrays are converted to NumPy and flattened, so only the element
    order matters, not the original shape.

    Args:
        predictions: Model predictions.
        targets: True values.

    Returns:
        Correlation coefficient as a builtin Python float. NumPy produces
        ``nan`` when either input has zero variance.
    """
    import numpy as np

    pred_flat = predictions.to_numpy().flatten()
    target_flat = targets.to_numpy().flatten()

    # corrcoef returns the 2x2 correlation matrix; the off-diagonal entry is
    # the correlation between the two inputs.
    coefficient = np.corrcoef(pred_flat, target_flat)[0, 1]

    # Convert the NumPy scalar to a plain float so the return value matches
    # the annotation and the builtin float returned by compute_accuracy.
    return float(coefficient)