# Source code for nabla.ops.operation

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Base operation classes for a clean OOP design."""

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional

from max.dtype import DType
from max.graph import TensorValue

from ..core.array import Array


class Operation(ABC):
    """Abstract base class for all operations.

    Concrete operations implement the forward construction plus the graph
    (``maxpr``), eager (``eagerxpr``), and autodiff (``vjp_rule`` /
    ``jvp_rule``) behaviors.
    """

    def __init__(self, name: str):
        # Human-readable operation name, attached to result Arrays.
        self.name = name

    @abstractmethod
    def forward(self, *args: Array) -> Array:
        """Forward pass - creates the result Array."""

    @abstractmethod
    def compute_output_shape(self, *input_shapes: tuple) -> tuple:
        """Compute the output shape given input shapes."""

    @abstractmethod
    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """MAX graph computation."""

    @abstractmethod
    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Eager computation using NumPy."""

    @abstractmethod
    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Vector-Jacobian product rule for reverse-mode autodiff."""

    @abstractmethod
    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Jacobian-vector product rule for forward-mode autodiff."""

    def custom_kernel_path(self) -> Optional[Path]:
        """Optional: path to custom kernel implementation.

        Returns None by default; subclasses backed by a custom kernel
        override this to point at the kernel source.
        """
        return None
class UnaryOperation(Operation):
    """Base class for unary operations.

    Wires a single input Array into a lazily-materialized output Array;
    subclasses supply the kernels and autodiff rules.
    """

    def forward(self, *args: Array) -> Array:
        """Forward pass for unary operations."""
        if len(args) != 1:
            raise ValueError(f"Unary operation requires 1 argument, got {len(args)}")
        (operand,) = args

        # Resolve output metadata from the single operand.
        out_shape = self.compute_output_shape(operand.shape)
        out_batch_dims = self.compute_output_batch_dims(operand.batch_dims)
        out_dtype = self.compute_output_dtype(operand)

        result = Array(
            shape=out_shape,
            dtype=out_dtype,
            device=operand.device,
            materialize=False,
            name=self.name,
            batch_dims=out_batch_dims,
        )
        result.set_maxpr(self.maxpr)
        result.add_arguments(operand)
        result.vjp_rule = self.vjp_rule
        result.jvp_rule = self.jvp_rule
        result.custom_kernel_path = self.custom_kernel_path()

        # Eager mode: compute immediately instead of staging a graph node.
        if not result.stage_realization:
            self.eagerxpr([operand], result)
        return result

    def compute_output_shape(self, *input_shapes: tuple) -> tuple:
        """Default: output shape same as input shape."""
        if len(input_shapes) != 1:
            raise ValueError(
                f"Unary operation requires 1 input shape, got {len(input_shapes)}"
            )
        return input_shapes[0]

    def compute_output_dtype(self, arg: Array) -> DType:
        """Default: output dtype same as input dtype."""
        return arg.dtype

    def compute_output_batch_dims(
        self, input_batch_dims: tuple[int, ...]
    ) -> tuple[int, ...]:
        """Default: output batch dims same as input batch dims."""
        return input_batch_dims
def move_to_best_device(*args: Array) -> tuple[Array, ...]:
    """Co-locate all input arrays on a single "best" device.

    Selection rules:
      1. If any accelerator (non-host device) already holds data, pick the
         accelerator with the highest score: its local element count plus a
         down-weighted (x0.1) credit for data on peer-accessible
         accelerators, so local residency dominates the choice.
      2. Otherwise pick the host device holding the most data.

    Arrays already on the chosen device are returned untouched.
    """
    if len(args) <= 1:
        return args

    import numpy as np

    # Total element counts per device, and the subset of non-host devices.
    data_per_device: dict = {}
    accelerators = set()
    for array in args:
        dev = array.device
        data_per_device[dev] = data_per_device.get(dev, 0) + np.prod(array.shape)
        if not dev.is_host:
            accelerators.add(dev)

    if not accelerators:
        # Host-only workload: the device holding the most data wins.
        target = max(data_per_device, key=data_per_device.get)
    elif len(accelerators) == 1:
        # Single accelerator: trivially the best choice.
        target = next(iter(accelerators))
    else:
        # Multi-accelerator: score each candidate by local data plus a small
        # bonus for remote data it can reach via peer access.
        scores = {
            cand: data_per_device[cand]
            + sum(
                data_per_device[other]
                for other in accelerators
                if other != cand and cand.can_access(other)
            )
            * 0.1
            for cand in accelerators
        }
        target = max(scores, key=scores.get)

    # Move only what is not already resident on the target device.
    return tuple(a if a.device == target else a.to(target) for a in args)
class BinaryOperation(Operation):
    """Base class for binary operations.

    Handles device co-location, dtype/device validation, rank alignment and
    shape/batch-dim broadcasting before delegating to subclass kernels.
    """

    def forward(self, *args: Array) -> Array:
        """Forward pass for binary operations."""
        if len(args) != 2:
            raise ValueError(f"Binary operation requires 2 arguments, got {len(args)}")

        # Co-locate operands on the best device before any computation.
        args = move_to_best_device(*args)
        lhs, rhs = args

        from ..ops.view import broadcast_batch_dims, broadcast_to, unsqueeze

        self._validate_inputs(lhs, rhs)

        output_shape = self.compute_output_shape(lhs.shape, rhs.shape)
        output_batch_dims = self.compute_output_batch_dims(
            lhs.batch_dims, rhs.batch_dims
        )
        output_dtype = self.compute_output_dtype(lhs, rhs)

        # Pad the lower-rank operand with leading singleton axes until both
        # operands reach the output rank.
        # TODO: this unsqueezing adds a little overhead; can shapes be
        # equalized more cheaply?
        for _ in range(len(output_shape) - len(lhs.shape)):
            lhs = unsqueeze(lhs, [-len(lhs.shape) - 1])
        for _ in range(len(output_shape) - len(rhs.shape)):
            rhs = unsqueeze(rhs, [-len(rhs.shape) - 1])

        # Only traced arrays are materially broadcast; untraced ones are
        # broadcast implicitly by the kernels.
        if lhs.traced:
            lhs = broadcast_to(lhs, output_shape)
            lhs = broadcast_batch_dims(lhs, output_batch_dims)
        if rhs.traced:
            rhs = broadcast_to(rhs, output_shape)
            rhs = broadcast_batch_dims(rhs, output_batch_dims)

        res = Array(
            shape=output_shape,
            dtype=output_dtype,
            device=lhs.device,
            materialize=False,
            name=self.name,
            batch_dims=output_batch_dims,
        )
        res.set_maxpr(self.maxpr)
        res.add_arguments(lhs, rhs)
        res.vjp_rule = self.vjp_rule
        res.jvp_rule = self.jvp_rule
        res.custom_kernel_path = self.custom_kernel_path()

        # Eager mode: compute immediately instead of staging a graph node.
        if not res.stage_realization:
            self.eagerxpr([lhs, rhs], res)
        return res

    def compute_output_shape(self, *input_shapes: tuple) -> tuple:
        """Compute broadcasted output shape."""
        if len(input_shapes) != 2:
            raise ValueError(
                f"Binary operation requires 2 input shapes, got {len(input_shapes)}"
            )
        from ..utils.shape_utils import get_broadcasted_shape

        return get_broadcasted_shape(*input_shapes)

    def compute_output_dtype(self, arg1: Array, arg2: Array) -> DType:
        """Default: output dtype same as first input dtype."""
        return arg1.dtype

    def _validate_inputs(self, arg1: Array, arg2: Array) -> None:
        """Validate binary operation inputs.

        Raises:
            TypeError: if either argument is not an Array.
            ValueError: if dtypes or devices differ.
        """
        if not (isinstance(arg1, Array) and isinstance(arg2, Array)):
            raise TypeError("Both arguments must be Array instances")
        if arg1.dtype != arg2.dtype:
            raise ValueError(f"Dtypes {arg1.dtype} and {arg2.dtype} are incompatible")
        if arg1.device != arg2.device:
            raise ValueError(
                f"Devices {arg1.device} and {arg2.device} are incompatible"
            )

    def compute_output_batch_dims(self, *input_batch_dims: tuple) -> tuple:
        """Default: output batch dims same as input batch dims."""
        if len(input_batch_dims) != 2:
            raise ValueError(
                f"Binary operation requires 2 input batch dims, got {len(input_batch_dims)}"
            )
        from ..utils.shape_utils import get_broadcasted_shape

        return get_broadcasted_shape(*input_batch_dims)
class ReductionOperation(UnaryOperation):
    """Base class for reduction operations.

    Args:
        name: Operation name forwarded to the Operation base.
        axes: Axis or axes to reduce over; ``None`` reduces over all axes.
        keep_dims: Stored for subclasses; shape computation here always keeps
            reduced axes as size-1 dimensions (removal is handled by squeeze).
    """

    def __init__(
        self,
        name: str,
        axes: int | list[int] | tuple[int, ...] | None = None,
        keep_dims: bool = False,
    ):
        super().__init__(name)
        self.axes = axes
        self.keep_dims = keep_dims

    def compute_output_shape(self, *input_shapes: tuple) -> tuple:
        """Compute output shape for reduction."""
        if len(input_shapes) != 1:
            raise ValueError(
                f"Reduction operation requires 1 input shape, got {len(input_shapes)}"
            )
        return self._compute_reduction_shape(input_shapes[0], self.axes)

    def compute_output_batch_dims(self, *input_batch_dims: tuple) -> tuple:
        """Compute output batch dims for reduction."""
        if len(input_batch_dims) != 1:
            raise ValueError(
                f"Reduction operation requires 1 input batch dims, got {len(input_batch_dims)}"
            )
        # For regular reductions, batch_dims pass through unchanged.
        # Only SumBatchDimsOp overrides this to actually reduce batch dims.
        return input_batch_dims[0]

    @staticmethod
    def _compute_reduction_shape(
        input_shape: tuple,
        axes: int | list[int] | tuple[int, ...] | None,
    ) -> tuple:
        """Compute the output shape for a reduction operation.

        Always preserves dimensions (sets reduced axes to size 1). Dimension
        removal should be handled separately by squeeze operations.

        Raises:
            ValueError: if any axis is out of bounds for ``input_shape``.
        """
        rank = len(input_shape)
        if axes is None:
            # Reduce all axes - every dimension collapses to size 1.
            return (1,) * rank

        if isinstance(axes, int):
            axes = [axes]

        normalized_axes = set()
        for axis in axes:
            resolved = axis + rank if axis < 0 else axis
            if not 0 <= resolved < rank:
                # Bug fix: report the caller's ORIGINAL axis value. The old
                # code normalized negative axes in place first, so e.g.
                # axis=-5 on a rank-3 shape was reported as "Axis -2".
                raise ValueError(
                    f"Axis {axis} is out of bounds for shape {input_shape}"
                )
            normalized_axes.add(resolved)

        # Reduced axes become size 1; all other dimensions pass through.
        return tuple(
            1 if i in normalized_axes else dim for i, dim in enumerate(input_shape)
        )
class ViewOperation(UnaryOperation):
    """Base class for view operations (reshape, transpose, etc.)."""

    def __init__(self, name: str):
        super().__init__(name)

    def compute_output_batch_dims(self, *input_batch_dims: tuple) -> tuple:
        """Default: output batch dims same as input batch dims."""
        if len(input_batch_dims) == 1:
            return input_batch_dims[0]
        raise ValueError(
            f"View operation requires 1 input batch dims, got {len(input_batch_dims)}"
        )