Source code for nabla.nn.losses.classification
# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Classification loss functions."""
import numpy as np
import nabla as nb
def cross_entropy_loss(logits: nb.Array, targets: nb.Array, axis: int = -1) -> nb.Array:
"""Compute cross-entropy loss between logits and targets.
Args:
logits: Raw model outputs (before softmax) [batch_size, num_classes]
targets: One-hot encoded targets [batch_size, num_classes]
axis: Axis along which to compute softmax
Returns:
Scalar loss value
"""
from ...ops.binary import mul
from ...ops.reduce import sum as array_sum
from ...ops.special import logsumexp
# Compute log probabilities using logsumexp for numerical stability
# log_softmax(x) = x - logsumexp(x)
log_sum_exp = logsumexp(logits, axis=axis, keep_dims=True)
log_probs = logits - log_sum_exp
# Cross-entropy: -sum(targets * log_probs)
cross_entropy = -array_sum(mul(targets, log_probs), axes=axis)
# Average over the batch (the batch dimension is assumed to be axis 0 of logits)
batch_size = nb.array([np.float32(logits.shape[0])])
return array_sum(cross_entropy) / batch_size
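
# Illustrative usage sketch (not part of the library API; the helper below and
# its data are hypothetical, added for documentation only, but it calls only
# APIs already used in this module: np.array, nb.Array.from_numpy, and
# cross_entropy_loss defined above).
def _example_cross_entropy_loss() -> nb.Array:
    # Two samples, three classes; targets are one-hot float arrays.
    logits = nb.Array.from_numpy(
        np.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]], dtype=np.float32)
    )
    targets = nb.Array.from_numpy(
        np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float32)
    )
    # Returns the cross-entropy averaged over the two samples.
    return cross_entropy_loss(logits, targets)
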
def sparse_cross_entropy_loss(
logits: nb.Array, targets: nb.Array, axis: int = -1
) -> nb.Array:
"""Compute cross-entropy loss with integer targets.
Args:
logits: Raw model outputs [batch_size, num_classes]
targets: Integer class indices [batch_size]
axis: Axis along which to compute softmax
Returns:
Scalar loss value
"""
# Convert targets to one-hot encoding
num_classes = logits.shape[axis]
batch_size = targets.shape[0]
# Create one-hot encoding
targets_np = targets.to_numpy().astype(np.int32)
one_hot_np = np.zeros((batch_size, num_classes), dtype=np.float32)
one_hot_np[np.arange(batch_size), targets_np] = 1.0
one_hot_targets = nb.Array.from_numpy(one_hot_np)
return cross_entropy_loss(logits, one_hot_targets, axis=axis)
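
# Illustrative usage sketch (hypothetical helper, added for documentation only;
# same setup as _example_cross_entropy_loss, but targets are integer class
# indices of shape [batch_size] instead of one-hot vectors).
def _example_sparse_cross_entropy_loss() -> nb.Array:
    logits = nb.Array.from_numpy(
        np.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]], dtype=np.float32)
    )
    targets = nb.Array.from_numpy(np.array([0, 1], dtype=np.int32))
    # Internally converted to one-hot and forwarded to cross_entropy_loss.
    return sparse_cross_entropy_loss(logits, targets)
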
def binary_cross_entropy_loss(
predictions: nb.Array, targets: nb.Array, eps: float = 1e-7
) -> nb.Array:
"""Compute binary cross-entropy loss.
Args:
predictions: Model predictions (after sigmoid) [batch_size]
targets: Binary targets (0 or 1) [batch_size]
eps: Small constant for numerical stability
Returns:
Scalar loss value
"""
from ...ops.binary import mul, sub
from ...ops.creation import full_like
from ...ops.reduce import mean
from ...ops.unary import log
# Clamp predictions to avoid log(0)
eps_tensor = full_like(predictions, eps)
one_minus_eps = full_like(predictions, 1.0 - eps)
# predictions = clamp(predictions, eps, 1-eps)
predictions_clamped = nb.maximum(predictions, eps_tensor)
predictions_clamped = nb.minimum(predictions_clamped, one_minus_eps)
# BCE = -[y*log(p) + (1-y)*log(1-p)]
log_p = log(predictions_clamped)
log_one_minus_p = log(sub(nb.ones_like(predictions_clamped), predictions_clamped))
# Compute binary cross-entropy
bce_per_sample = -(
mul(targets, log_p) + mul(sub(nb.ones_like(targets), targets), log_one_minus_p)
)
# Average over batch
return mean(bce_per_sample)
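
# Illustrative usage sketch (hypothetical helper, added for documentation only;
# predictions are assumed to already be sigmoid outputs in (0, 1), as the
# docstring above requires).
def _example_binary_cross_entropy_loss() -> nb.Array:
    predictions = nb.Array.from_numpy(np.array([0.9, 0.2, 0.7], dtype=np.float32))
    targets = nb.Array.from_numpy(np.array([1.0, 0.0, 1.0], dtype=np.float32))
    # Returns the mean binary cross-entropy over the three samples.
    return binary_cross_entropy_loss(predictions, targets)
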
def softmax_cross_entropy_loss(
logits: nb.Array, targets: nb.Array, axis: int = -1
) -> nb.Array:
"""Compute softmax cross-entropy loss (numerically stable).
This computes the same quantity as cross_entropy_loss; both combine softmax
and cross-entropy via the log-sum-exp trick for numerical stability. The only
difference is that the batch average here is taken with mean rather than an
explicit division by the batch size.
Args:
logits: Raw model outputs [batch_size, num_classes]
targets: One-hot encoded targets [batch_size, num_classes]
axis: Axis along which to compute softmax
Returns:
Scalar loss value
"""
from ...ops.binary import mul
from ...ops.reduce import mean
from ...ops.reduce import sum as array_sum
from ...ops.special import logsumexp
# Compute log_softmax = logits - logsumexp(logits)
log_sum_exp = logsumexp(logits, axis=axis, keep_dims=True)
log_softmax = logits - log_sum_exp
# Cross-entropy with log_softmax
cross_entropy = -array_sum(mul(targets, log_softmax), axes=axis)
# Average over batch
return mean(cross_entropy)
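
# Illustrative usage sketch (hypothetical helper, added for documentation only;
# it reuses the one-hot setup from _example_cross_entropy_loss and should give
# the same value, since both functions average the same per-sample
# cross-entropy over the batch).
def _example_softmax_cross_entropy_loss() -> nb.Array:
    logits = nb.Array.from_numpy(
        np.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]], dtype=np.float32)
    )
    targets = nb.Array.from_numpy(
        np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float32)
    )
    return softmax_cross_entropy_loss(logits, targets)
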
__all__ = [
"cross_entropy_loss",
"sparse_cross_entropy_loss",
"binary_cross_entropy_loss",
"softmax_cross_entropy_loss",
]