Source code for nabla.ops.reduce

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Reduction operations."""

from __future__ import annotations

import numpy as np
from max.dtype import DType
from max.graph import TensorValue, ops

from ..core.array import Array, Shape
from .operation import ReductionOperation
from .view import squeeze, squeeze_batch_dims

# Public API
__all__ = ["sum", "sum_batch_dims", "mean", "max", "argmax"]


def _normalize_axes(
    axes: int | list[int] | tuple[int, ...] | None, ndim: int
) -> list[int]:
    """Normalize axes parameter to a list of integers."""
    if axes is None:
        return list(range(ndim))
    elif isinstance(axes, int):
        return [axes]
    elif isinstance(axes, (list, tuple)):
        return list(axes)
    else:
        raise TypeError(f"axes must be int, list, tuple, or None, got {type(axes)}")
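
# Illustrative behaviour of the helper above (plain Python, derived directly
# from the code; no Array objects needed):
#
#   _normalize_axes(None, 3)    -> [0, 1, 2]   (reduce over every axis)
#   _normalize_axes(1, 3)       -> [1]
#   _normalize_axes((0, 2), 3)  -> [0, 2]
#
# Negative axes are passed through unchanged.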


class SumOp(ReductionOperation):
    """sum reduction operation."""

    def __init__(
        self,
        arg_shape: Shape,
        axes: int | list[int] | tuple[int, ...] | None = None,
        keep_dims: bool = False,
    ):
        super().__init__(f"sum[axes={axes}]", axes, keep_dims=True)
        self.arg_shape = arg_shape
        self.axes = axes
        self.keep_dims = keep_dims

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        output_symbol = args[0]

        # Normalize axes to handle None, int, or collections
        normalized_axes = _normalize_axes(self.axes, len(args[0].shape))

        for axis in normalized_axes:
            output_symbol = ops.sum(output_symbol, axis=axis)

        output.tensor_value = output_symbol

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        if isinstance(self.axes, list):
            numpy_axes: int | tuple[int, ...] | None = tuple(self.axes)
        else:
            numpy_axes = self.axes

        np_result = np.sum(args[0].to_numpy(), axis=numpy_axes, keepdims=True)
        if np_result.ndim == 0:
            np_result = np.array(np_result)
        output.impl_(np_result)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        if len(cotangent.shape) > len(primals[0].shape):
            return [cotangent]

        if output.shape != cotangent.shape:
            raise ValueError(
                f"In VJP rule for SumOp, "
                f"output shape {output.shape} "
                f"does not match cotangent shape {cotangent.shape}. "
                f"Primal shape: {primals[0].shape}."
            )

        from .view import broadcast_to

        return [broadcast_to(cotangent, self.arg_shape)]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        return sum(tangents[0], axes=self.axes, keep_dims=True)
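
# Numerics sketch for the VJP/JVP rules above: the VJP of a sum broadcasts the
# output cotangent back to the input shape, and the JVP sums the input tangent
# over the same axes. Illustrative NumPy-only values (not nabla Arrays):
#
#   x = np.arange(6.0).reshape(2, 3)        # primal, shape (2, 3)
#   y = np.sum(x, axis=1, keepdims=True)    # forward result, shape (2, 1)
#   g = np.ones_like(y)                     # cotangent for the output
#   vjp = np.broadcast_to(g, x.shape)       # gradient of sum: all ones
#   t = np.ones_like(x)                     # tangent for the input
#   jvp = np.sum(t, axis=1, keepdims=True)  # [[3.], [3.]]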


# noqa: A001 - Intentionally shadowing built-in 'sum' for API consistency
def sum(
    arg: Array,
    axes: int | list[int] | tuple[int, ...] | None = None,
    keep_dims: bool = False,
) -> Array:
    """Sum array elements over given axes."""
    if axes is not None:
        if isinstance(axes, int):
            axes = [axes]
        elif isinstance(axes, list | tuple):
            axes = [int(axis) for axis in axes]
        ndim = len(arg.shape)
        for axis in axes:
            if not -ndim <= axis < ndim:
                raise ValueError(
                    f"axis {axis} is out of bounds for array of dimension {ndim}"
                )
        axes = [axis if axis < 0 else axis - len(arg.shape) for axis in axes]
    else:
        axes = []
        for i in range(-len(arg.shape), 0):
            axes.append(i)

    axes = sorted(axes)

    op = SumOp(arg.shape, axes, keep_dims=keep_dims)
    res = op.forward(arg)

    if not keep_dims:
        # manually use the squeeze operation to squeeze remaining axes
        for axis in axes:
            res = squeeze(res, [axis])  # axes always negative

    return res


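# Usage sketch (illustrative; assumes an Array created from NumPy data via
# nabla's creation helpers -- the exact constructor is outside this module):
#
#   x : Array with shape (2, 3) holding [[0, 1, 2], [3, 4, 5]]
#   sum(x)                         -> scalar Array, value 15.0
#   sum(x, axes=1)                 -> shape (2,), values [3.0, 12.0]
#   sum(x, axes=1, keep_dims=True) -> shape (2, 1)
#   sum(x, axes=-1)                -> same as axes=1 for a 2-D input

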
def mean(
    arg: Array,
    axes: int | list[int] | tuple[int, ...] | None = None,
    keep_dims: bool = False,
) -> Array:
    """Compute mean of array elements over given axes."""
    from .binary import div
    from .creation import array

    # First compute the sum
    sum_result = sum(arg, axes=axes, keep_dims=keep_dims)

    # Calculate the number of elements being averaged
    if axes is not None:
        if isinstance(axes, int):
            axes = [axes]
        elif isinstance(axes, list | tuple):
            axes = [int(axis) for axis in axes]

        # Handle negative axes
        ndim = len(arg.shape)
        normalized_axes = []
        for axis in axes:
            if not -ndim <= axis < ndim:
                raise ValueError(
                    f"axis {axis} is out of bounds for array of dimension {ndim}"
                )
            if axis < 0:
                normalized_axes.append(len(arg.shape) + axis)
            else:
                normalized_axes.append(axis)

        # Count elements along reduced axes
        count = 1
        for axis in normalized_axes:
            if axis < len(arg.shape):
                count *= arg.shape[axis]
    else:
        # All axes - total number of elements
        count = 1
        for dim in arg.shape:
            count *= dim

    # Create count as a scalar array
    count_array = array(float(count), dtype=arg.dtype)

    # Divide sum by count
    return div(sum_result, count_array)


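# Numerics sketch for `mean` (NumPy reference values; illustrative only):
#
#   x : Array with shape (2, 3) holding [[0, 1, 2], [3, 4, 5]]
#   mean(x)                              -> 2.5 (sum 15.0 divided by count 6)
#   mean(x, axes=0)                      -> [1.5, 2.5, 3.5]
#   mean(x, axes=(0, 1), keep_dims=True) -> shape (1, 1), value 2.5

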
class SumBatchDimsOp(ReductionOperation):
    """Sum reduction operation over batch dimensions."""

    def __init__(
        self,
        arg_batch_dims: Shape,
        axes: int | list[int] | tuple[int, ...] | None = None,
        keep_dims: bool = False,
    ):
        super().__init__(f"sum_batch_dims[axes={axes}]")
        self.arg_batch_dims = arg_batch_dims
        self.axes = axes
        self.keep_dims = keep_dims

    def compute_output_shape(self, *input_shapes):
        return input_shapes[0]

    def compute_output_batch_dims(self, *input_batch_dims):
        return self._compute_reduction_shape(input_batch_dims[0], self.axes)

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        # First we must subtract len(output.shape) from each axis value
        normalized_axes = _normalize_axes(self.axes, len(args[0].shape))
        axes = [ax - len(output.shape) for ax in normalized_axes]

        output_symbol = args[0]
        for axis in axes:
            output_symbol = ops.sum(output_symbol, axis=axis)

        output.tensor_value = output_symbol

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        normalized_axes = _normalize_axes(self.axes, len(args[0].shape))
        axes = [ax - len(output.shape) for ax in normalized_axes]

        np_result = np.sum(
            args[0].to_numpy(), axis=tuple(axes) if axes else None, keepdims=True
        )
        if np_result.ndim == 0:
            np_result = np.array(np_result)
        output.impl_(np_result)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        from .view import broadcast_batch_dims

        if len(cotangent.batch_dims) > len(primals[0].batch_dims):
            return [cotangent]

        if output.batch_dims != cotangent.batch_dims:
            raise ValueError(
                f"In VJP rule for SumBatchDimsOp, "
                f"output batch_dims {output.batch_dims} "
                f"do not match cotangent batch_dims {cotangent.batch_dims}. "
                f"Primal batch_dims: {primals[0].batch_dims}."
            )

        return [broadcast_batch_dims(cotangent, self.arg_batch_dims)]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        return sum_batch_dims(tangents[0], axes=self.axes, keep_dims=True)


def sum_batch_dims(
    arg: Array,
    axes: int | list[int] | tuple[int, ...] | None = None,
    keep_dims: bool = False,
) -> Array:
    """Sum array elements over given batch dimension axes."""
    if axes is not None:
        if isinstance(axes, int):
            axes = [axes]
        elif isinstance(axes, list | tuple):
            axes = [int(axis) for axis in axes]
        batch_dims_len = len(arg.batch_dims)
        for axis in axes:
            if not -batch_dims_len <= axis < batch_dims_len:
                raise ValueError(
                    f"axis {axis} is out of bounds for array with "
                    f"{batch_dims_len} batch dimensions"
                )
        axes = [axis if axis < 0 else axis - batch_dims_len for axis in axes]
    else:
        axes = []
        for i in range(-len(arg.batch_dims), 0):
            axes.append(i)

    axes = sorted(axes)

    op = SumBatchDimsOp(arg.batch_dims, axes, keep_dims)
    res = op.forward(arg)

    if not keep_dims:
        for axis in axes:
            res = squeeze_batch_dims(res, [axis])

    return res


class MaxOp(ReductionOperation):
    """Max reduction operation."""

    def __init__(
        self,
        arg_shape: Shape,
        axes: int | list[int] | tuple[int, ...] | None = None,
        keep_dims: bool = False,
    ):
        super().__init__(f"max[axes={axes}]", axes, keep_dims=True)
        self.arg_shape = arg_shape
        self.axes = axes
        self.keep_dims = keep_dims

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        output_symbol = args[0]

        # Normalize axes to handle None, int, or collections
        normalized_axes = _normalize_axes(self.axes, len(args[0].shape))

        for axis in normalized_axes:
            output_symbol = ops.max(output_symbol, axis=axis)

        output.tensor_value = output_symbol

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        if isinstance(self.axes, list):
            numpy_axes: int | tuple[int, ...] | None = tuple(self.axes)
        else:
            numpy_axes = self.axes

        np_result = np.max(args[0].to_numpy(), axis=numpy_axes, keepdims=True)
        if np_result.ndim == 0:
            np_result = np.array(np_result)
        output.impl_(np_result)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        from .binary import equal
        from .view import broadcast_to

        # Get the primal input
        primal = primals[0]

        # Broadcast cotangent to match primal shape
        cotangent_broadcasted = broadcast_to(cotangent, self.arg_shape)

        # Broadcast the output (max values) to match primal shape
        output_broadcasted = broadcast_to(output, self.arg_shape)

        # Create mask where primal equals the max value (output)
        mask = equal(primal, output_broadcasted)

        # Convert mask to float and multiply with broadcasted cotangent
        mask_float = mask.astype(primal.dtype)
        result = cotangent_broadcasted * mask_float

        return [result]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        from .binary import equal, mul
        from .view import broadcast_to

        # Create mask where input equals the max value
        primal = primals[0]
        max_result = max(primal, axes=self.axes, keep_dims=True)
        max_broadcasted = broadcast_to(max_result, self.arg_shape)
        mask = equal(primal, max_broadcasted)

        # Convert mask to float for arithmetic operations
        mask_float = mask.astype(primal.dtype)

        # Apply mask to tangents and sum over the reduced axes
        masked_tangents = mul(tangents[0], mask_float)
        return sum(masked_tangents, axes=self.axes, keep_dims=True)


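# Gradient-mask sketch for MaxOp.vjp_rule (NumPy only, illustrative values):
#
#   x = np.array([[1.0, 3.0, 3.0],
#                 [2.0, 0.0, 5.0]])
#   m = np.max(x, axis=1, keepdims=True)       # [[3.], [5.]]
#   mask = (x == m).astype(x.dtype)            # [[0., 1., 1.], [0., 0., 1.]]
#   grad = np.broadcast_to(np.ones_like(m), x.shape) * mask
#
# Note: with ties (the two 3.0s in the first row) this rule routes the full
# cotangent to every tied position rather than splitting it between them.

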
def max(
    arg: Array,
    axes: int | list[int] | tuple[int, ...] | None = None,
    keep_dims: bool = False,
) -> Array:
    """Find maximum array elements over given axes."""
    if axes is not None:
        if isinstance(axes, int):
            axes = [axes]
        elif isinstance(axes, list | tuple):
            axes = [int(axis) for axis in axes]
        ndim = len(arg.shape)
        for axis in axes:
            if not -ndim <= axis < ndim:
                raise ValueError(
                    f"axis {axis} is out of bounds for array of dimension {ndim}"
                )
        axes = [axis if axis < 0 else axis - len(arg.shape) for axis in axes]
    else:
        axes = []
        for i in range(-len(arg.shape), 0):
            axes.append(i)

    axes = sorted(axes)

    op = MaxOp(arg.shape, axes, keep_dims=keep_dims)
    res = op.forward(arg)

    if not keep_dims:
        # manually use the squeeze operation to squeeze remaining axes
        for axis in axes:
            res = squeeze(res, [axis])  # axes always negative

    return res


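# Usage sketch (illustrative values only):
#
#   x : Array with shape (2, 3) holding [[1, 3, 3], [2, 0, 5]]
#   max(x)                         -> 5.0
#   max(x, axes=1)                 -> [3.0, 5.0]
#   max(x, axes=0, keep_dims=True) -> shape (1, 3), values [[2.0, 3.0, 5.0]]

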
class ArgMaxOp(ReductionOperation):
    """ArgMax reduction operation.

    Batch-aware; operates on the physical axis. This op internally behaves
    as if keep_dims=True.
    """

    def __init__(
        self,
        arg_shape: Shape,
        logical_axis: int | None,
    ):
        super().__init__(
            f"argmax[axis={logical_axis}]",
            axes=[logical_axis] if logical_axis is not None else None,
            keep_dims=True,
        )
        self.arg_shape = arg_shape
        self.logical_axis = logical_axis

    def compute_output_dtype(self, arg: Array) -> DType:
        return DType.int64

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        input_symbol = args[0]
        # physical_axis = self._get_physical_axis(output.batch_dims)

        if self.logical_axis is None:
            # Flatten everything except batch dims for reduction
            tmp_shape = output.batch_dims + (-1,)
            tmp_arg = ops.reshape(input_symbol, tmp_shape)
            result = ops.argmax(tmp_arg, axis=-1)
            res_shape = output.batch_dims + (1,) * len(self.arg_shape)
            output.tensor_value = ops.reshape(result, res_shape)
        else:
            # Assume that the logical axis is always negative
            output.tensor_value = ops.argmax(input_symbol, axis=self.logical_axis)

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        primal = args[0].to_numpy()

        if self.logical_axis is None:
            tmp_shape = output.batch_dims + (-1,)
            tmp_arg = primal.reshape(tmp_shape)
            np_result = np.argmax(tmp_arg, axis=-1)
            res_shape = output.batch_dims + (1,) * len(self.arg_shape)
            res = np_result.reshape(res_shape)
            if res.ndim == 0:
                res = np.array(res)
            output.impl_(res)
        else:
            # Assume that the logical axis is always negative
            res = np.argmax(primal, axis=self.logical_axis, keepdims=True)
            if res.ndim == 0:
                res = np.array(res)
            output.impl_(res)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        from .creation import zeros_like

        return [zeros_like(primals[0])]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        from .creation import zeros_like

        return zeros_like(output)


def argmax(
    arg: Array,
    axes: int | None = None,
    keep_dims: bool = False,
) -> Array:
    """Find indices of maximum array elements over a given axis, matching JAX's API."""
    logical_axis: int | None
    ndim = len(arg.shape)

    # Validate the user-provided 'axes' argument
    if axes is None:
        logical_axis = None
    elif isinstance(axes, int):
        if not -ndim <= axes < ndim:
            raise ValueError(
                f"axis {axes} is out of bounds for array of dimension {ndim}"
            )
        logical_axis = axes
    elif isinstance(axes, (list, tuple)):
        if len(axes) > 1:
            raise NotImplementedError("nabla.argmax does not support a tuple of axes.")
        if not axes:
            raise ValueError("axis must be an integer or None, not an empty sequence.")
        axis_val = axes[0]
        if not isinstance(axis_val, int):
            raise TypeError(
                f"axis must be an integer, but got {type(axis_val)} inside sequence."
            )
        if not -ndim <= axis_val < ndim:
            raise ValueError(
                f"axis {axis_val} is out of bounds for array of dimension {ndim}"
            )
        logical_axis = axis_val
    else:
        raise TypeError(f"Invalid type for axes: {type(axes)}")

    if arg.shape == () or np.prod(arg.shape) == 1:
        from .creation import zeros_like

        return zeros_like(arg).astype(DType.int64)

    # Make the axis always a negative value
    if logical_axis is not None and logical_axis >= 0:
        logical_axis = logical_axis - ndim

    if logical_axis is not None and arg.stage_realization:
        # In JIT mode, move the reduced axis to the back before applying the op
        from .view import move_axis_from_back, move_axis_to_back

        arg = move_axis_to_back(arg, logical_axis)
        op = ArgMaxOp(arg.shape, -1)
        res = op.forward(arg)
        # Move the axis back to its original position
        res = move_axis_from_back(res, logical_axis)
    else:
        # Eager mode, or axes is None: use ArgMaxOp directly with the logical axis
        op = ArgMaxOp(arg.shape, logical_axis)
        res = op.forward(arg)

    # Handle keep_dims
    if not keep_dims:
        if logical_axis is None:
            return res.reshape(())
        else:
            # Squeeze the original logical axis relative to the tensor's logical shape.
            squeeze_axis = logical_axis if logical_axis >= 0 else ndim + logical_axis
            res = squeeze(res, [squeeze_axis])

    return res
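
# Usage sketch (illustrative; mirrors NumPy/JAX argmax semantics):
#
#   x : Array with shape (2, 3) holding [[1, 3, 3], [2, 0, 5]]
#   argmax(x)                         -> 5       (index into the flattened array)
#   argmax(x, axes=1)                 -> [1, 2]  (eager path returns the first
#                                                 maximum on ties, like np.argmax)
#   argmax(x, axes=1, keep_dims=True) -> shape (2, 1)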