# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Special functions for neural networks."""
from collections.abc import Callable
from ..core.array import Array
# Public API
__all__ = ["softmax", "logsumexp", "where", "cond"]


def logsumexp(
    arg: Array, axis: int | list[int] | None = None, keep_dims: bool = False
) -> Array:
"""Compute log(sum(exp(x))) in a numerically stable way.
Args:
arg: Input array
axis: Axis along which to compute logsumexp. If None, compute over all elements.
keep_dims: Whether to keep reduced dimensions
Returns:
Array containing logsumexp values
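
    Example:
        A minimal usage sketch, assuming arrays are built with
        ``nabla.array`` and this op is re-exported at the package root::

            import nabla as nb

            x = nb.array([1000.0, 1000.0, 1000.0])
            # A naive log(sum(exp(x))) overflows in float32; the internal
            # max-shift returns the exact value 1000 + log(3) instead.
            lse = nb.logsumexp(x)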
"""
from .binary import add, sub
from .reduce import max as array_max
from .reduce import sum as array_sum
from .unary import exp, log
# For numerical stability, subtract the max before exponentiating
# logsumexp(x) = max(x) + log(sum(exp(x - max(x))))
# Find max along specified axis, keeping dimensions for broadcasting
x_max = array_max(arg, axes=axis, keep_dims=True)
# Subtract max and exponentiate
shifted = sub(arg, x_max)
exp_shifted = exp(shifted)
# Sum and take log
sum_exp = array_sum(exp_shifted, axes=axis, keep_dims=True)
log_sum_exp = log(sum_exp)
# Add back the max
result = add(x_max, log_sum_exp)
    # Remove the kept singleton dimensions if not requested
    if not keep_dims:
        from .view import squeeze

        if axis is None:
            # All axes were reduced; squeeze every (now singleton) axis
            # (assumes Array exposes a NumPy-style `.shape`)
            axes_to_squeeze = list(range(len(arg.shape)))
        else:
            axes_to_squeeze = [axis] if isinstance(axis, int) else list(axis)
        for ax in sorted(axes_to_squeeze, reverse=True):
            result = squeeze(result, [ax])  # squeeze back-to-front
return result


def softmax(arg: Array, axis: int = -1) -> Array:
"""Compute softmax function in a numerically stable way.
Args:
arg: Input array
axis: Axis along which to compute softmax
Returns:
Array containing softmax probabilities
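
    Example:
        A minimal usage sketch, assuming arrays are built with
        ``nabla.array`` and this op is re-exported at the package root::

            import nabla as nb

            logits = nb.array([1.0, 2.0, 3.0])
            probs = nb.softmax(logits)  # non-negative, sums to 1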
"""
from .binary import sub
from .unary import exp
    # For numerical stability: softmax(x) = exp(x - logsumexp(x)).
    # Since logsumexp(x) >= max(x), the argument of exp is never positive,
    # so the exponential cannot overflow.
    log_sum_exp = logsumexp(arg, axis=axis, keep_dims=True)
    normalized = sub(arg, log_sum_exp)
return exp(normalized)


def where(condition: Array, x: Array, y: Array) -> Array:
"""Element-wise selection from x or y based on condition.
Args:
condition: Boolean array for selection
x: Array to select from where condition is True
y: Array to select from where condition is False
Returns:
Array with elements selected from x or y
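
    Example:
        A minimal usage sketch, assuming arrays are built with
        ``nabla.array`` and this op is re-exported at the package root::

            import nabla as nb

            mask = nb.array([True, False, True])
            a = nb.array([1.0, 2.0, 3.0])
            b = nb.array([9.0, 9.0, 9.0])
            out = nb.where(mask, a, b)  # -> [1.0, 9.0, 3.0]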
"""
from .binary import add, mul
from .unary import cast, logical_not
# where(c, x, y) = c * x + (1 - c) * y
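    # Note: with this arithmetic formulation, NaN/inf values in the
    # unselected branch still propagate (0 * NaN == NaN), unlike a
    # true hardware select.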
# Convert boolean condition to float for arithmetic
cond_float = cast(condition, x.dtype)
inv_cond = cast(logical_not(condition), x.dtype)
x_part = mul(cond_float, x)
y_part = mul(inv_cond, y)
return add(x_part, y_part)


def cond(
condition: Array, true_fn: Callable, false_fn: Callable, *args, **kwargs
) -> Array:
"""Conditional execution based on a boolean condition.
Args:
condition: Boolean array determining which function to execute
true_fn: Function to execute if condition is True
false_fn: Function to execute if condition is False
*args, **kwargs: Arguments passed to the selected function
Returns:
Result of the executed function
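
    Example:
        A minimal usage sketch, assuming arrays are built with
        ``nabla.array``, this op is re-exported at the package root, and
        Array supports Python arithmetic operators::

            import nabla as nb

            x = nb.array([1.0, 2.0])
            flag = nb.array([True, False])
            # Both lambdas run on the full input; selection is per element.
            y = nb.cond(flag, lambda v: v + 1.0, lambda v: v - 1.0, x)
            # -> [2.0, 1.0]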
"""
from max.dtype import DType
from .unary import cast
# Convert condition to boolean if necessary
bool_condition = cast(condition, DType.bool)
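    # Both branches are evaluated eagerly; `where` performs the
    # element-wise selection between their results.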
return where(bool_condition, true_fn(*args, **kwargs), false_fn(*args, **kwargs))