Source code for nabla.nn.architectures.mlp

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Multi-Layer Perceptron (MLP) architectures."""

from collections.abc import Callable

import nabla as nb

from ..init.variance_scaling import he_normal, initialize_mlp_params, xavier_normal
from ..layers.linear import mlp_forward_with_activations
from ..losses.regression import mean_squared_error


def create_mlp_forward_and_loss(activation: str = "relu") -> Callable:
    """Create a combined forward pass and loss computation function.

    This function factory creates the forward_and_loss function needed
    for VJP computation in training loops.

    Args:
        activation: Activation function for hidden layers

    Returns:
        Function that takes inputs and returns loss
    """

    def mlp_forward_and_loss(inputs: list[nb.Array]) -> list[nb.Array]:
        """Combined forward pass and loss computation for VJP."""
        x, targets, *params = inputs
        predictions = mlp_forward_with_activations(x, params, activation)
        loss = mean_squared_error(predictions, targets)
        return [loss]

    return mlp_forward_and_loss
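
# Illustrative sketch (not part of the original module): the returned closure expects
# its inputs packed as [x, targets, *params], the calling convention used for VJP in
# the training loops. Shapes and layer sizes below are hypothetical; nb.zeros and
# initialize_mlp_params are used only because they already appear in this module.
#
#   forward_and_loss = create_mlp_forward_and_loss("relu")
#   params = initialize_mlp_params([4, 16, 1], 0)
#   x = nb.zeros((8, 4))        # batch of 8 inputs with 4 features
#   targets = nb.zeros((8, 1))  # matching regression targets
#   [loss] = forward_and_loss([x, targets, *params])
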
def create_mlp_config(
    layers: list[int],
    activation: str = "relu",
    final_activation: str | None = None,
    init_method: str = "he_normal",
    seed: int = 42,
) -> dict:
    """Create MLP configuration dictionary.

    Args:
        layers: List of layer sizes [input, hidden1, hidden2, ..., output]
        activation: Activation function for hidden layers
        final_activation: Optional activation for final layer
        init_method: Weight initialization method
        seed: Random seed for reproducibility

    Returns:
        Configuration dictionary with params and forward function
    """
    # Initialize parameters
    if init_method == "mlp_specialized":
        # Use the specialized initialization from mlp_train_jit.py
        params = initialize_mlp_params(layers, seed)
    elif init_method == "he_normal":
        params = []
        for i in range(len(layers) - 1):
            w = he_normal((layers[i], layers[i + 1]), seed + i)
            b = nb.zeros((1, layers[i + 1]))
            params.extend([w, b])
    elif init_method == "xavier_normal":
        params = []
        for i in range(len(layers) - 1):
            w = xavier_normal((layers[i], layers[i + 1]), seed + i)
            b = nb.zeros((1, layers[i + 1]))
            params.extend([w, b])
    else:
        raise ValueError(f"Unsupported init_method: {init_method}")

    # Create forward function
    def forward_fn(x: nb.Array, params: list[nb.Array]) -> nb.Array:
        return mlp_forward_with_activations(x, params, activation, final_activation)

    # Create forward and loss function for training
    forward_and_loss_fn = create_mlp_forward_and_loss(activation)

    return {
        "params": params,
        "forward": forward_fn,
        "forward_and_loss": forward_and_loss_fn,
        "layers": layers,
        "activation": activation,
        "final_activation": final_activation,
        "init_method": init_method,
    }
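
# Illustrative sketch (not part of the original module): build a config and run a
# forward pass through it. The layer sizes are hypothetical; "params" and "forward"
# are keys of the dictionary returned above.
#
#   config = create_mlp_config([4, 32, 32, 1], activation="relu")
#   x = nb.zeros((8, 4))
#   predictions = config["forward"](x, config["params"])
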
class MLPBuilder:
    """Builder class for creating MLP configurations."""

    def __init__(self):
        self.layers = None
        self.activation = "relu"
        self.final_activation = None
        self.init_method = "he_normal"
        self.seed = 42

    def with_layers(self, layers: list[int]) -> "MLPBuilder":
        """Set layer sizes."""
        self.layers = layers
        return self

    def with_activation(self, activation: str) -> "MLPBuilder":
        """Set hidden layer activation function."""
        self.activation = activation
        return self

    def with_final_activation(self, activation: str | None) -> "MLPBuilder":
        """Set final layer activation function."""
        self.final_activation = activation
        return self

    def with_init_method(self, method: str) -> "MLPBuilder":
        """Set weight initialization method."""
        self.init_method = method
        return self

    def with_seed(self, seed: int) -> "MLPBuilder":
        """Set random seed."""
        self.seed = seed
        return self

    def build(self) -> dict:
        """Build the MLP configuration."""
        if self.layers is None:
            raise ValueError("Must specify layers with .with_layers()")
        return create_mlp_config(
            self.layers,
            self.activation,
            self.final_activation,
            self.init_method,
            self.seed,
        )
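
# Illustrative sketch (not part of the original module): the builder produces the
# same configuration dictionary as create_mlp_config via a fluent interface. The
# layer sizes and seed below are hypothetical.
#
#   config = (
#       MLPBuilder()
#       .with_layers([4, 32, 32, 1])
#       .with_activation("relu")
#       .with_seed(0)
#       .build()
#   )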