# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Core Array class with improved organization."""
from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
from typing import Optional, Union
import numpy as np
from max.driver import CPU, Device, Tensor
from max.dtype import DType
from max.graph import TensorValue, TensorValueLike, Value
# Shape of an array: one integer extent per dimension.
Shape = tuple[int, ...]
# Signature of the callable stored in ``Array.maxpr``: it receives a list of
# TensorValue objects plus an Array and returns nothing.
# NOTE(review): presumably (argument graph values, output array) - confirm.
MaxprCallable = Callable[[list[TensorValue], "Array"], None]
# Signature of the callable stored in ``Array.vjp_rule``: takes a list of
# Arrays and two further Arrays, returns a list of Arrays.
# NOTE(review): presumably (primals, cotangent, output) -> cotangents - confirm.
VJPRule = Callable[[list["Array"], "Array", "Array"], list["Array"]]
# Signature of the callable stored in ``Array.jvp_rule``: takes two lists of
# Arrays and one further Array, returns a single Array.
# NOTE(review): presumably (primals, tangents, output) -> tangent - confirm.
JVPRule = Callable[[list["Array"], list["Array"], "Array"], "Array"]
# One CPU device object created at import time; used as the default ``device``
# of ``Array.__init__`` and as the device of arrays built by ``from_numpy``.
_DEFAULT_CPU = CPU()
[docs]
class Array:
    """Core tensor-like array class with automatic differentiation support.

    The declarations below only annotate the instance attributes for static
    tooling; they do not assign any values at class level.
    """

    # Class-level type annotations for better Pylance support
    # Extents of the array's own dimensions.
    shape: Shape
    # Additional dimensions that are prepended to ``shape`` when storage is
    # allocated (the constructor allocates ``batch_dims + shape``).
    batch_dims: Shape
    # Element type.
    dtype: DType
    # Device the array is associated with.
    device: Device
    # Optional human-readable label.
    name: str
    # Argument Arrays recorded by ``add_arguments``.
    args: list[Array]
    # NOTE(review): presumably a graph-traversal marker - confirm with users.
    visited: bool
    # NOTE(review): presumably the graph value produced while lowering - confirm.
    tensor_value: Optional[Union[Value, TensorValue, TensorValueLike]]
    # Callable installed through ``set_maxpr``.
    maxpr: Optional[MaxprCallable]
    # Optional differentiation rules (see the VJPRule / JVPRule aliases).
    vjp_rule: Optional[VJPRule]
    jvp_rule: Optional[JVPRule]
    # Set to True by ``add_arguments`` when any argument is traced.
    traced: bool
    # Optional Arrays attached for differentiation.
    tangent: Optional[Array]
    cotangent: Optional[Array]
    # Set to True by ``add_arguments`` when any argument has the flag set.
    stage_realization: bool
    # Optional filesystem paths.
    # NOTE(review): presumably locations of kernel sources - confirm.
    kernel_impl_path: Optional[Path]
    custom_kernel_path: Optional[Path]
    # Backing data: a NumPy array, a MAX Tensor, or None while unrealized.
    _impl: Optional[Union[np.ndarray, Tensor]]
[docs]
def __init__(
    self,
    shape: Shape,
    dtype: DType = DType.float32,
    device: Device = _DEFAULT_CPU,
    materialize: bool = False,
    name: str = "",
    batch_dims: Shape = (),
) -> None:
    """Create an Array description, optionally backed by fresh storage.

    Args:
        shape: Extents of the array's dimensions.
        dtype: Element type.
        device: Device the array is associated with.
        materialize: When True, allocate a Tensor of shape
            ``batch_dims + shape`` on ``device``; otherwise the array
            starts without backing data.
        name: Optional label.
        batch_dims: Extra leading dimensions prepended to ``shape`` when
            storage is allocated.
    """
    # Static description supplied by the caller.
    self.shape, self.batch_dims = shape, batch_dims
    self.dtype, self.device = dtype, device
    self.name = name

    # Graph bookkeeping: nothing recorded yet.
    self.args = []
    self.visited = False
    self.traced = False
    self.stage_realization = False
    self.tensor_value = None
    self.maxpr = None

    # Differentiation state.
    self.vjp_rule = None
    self.jvp_rule = None
    self.tangent = None
    self.cotangent = None

    # Optional kernel locations.
    self.kernel_impl_path = None
    self.custom_kernel_path = None

    # Backing data exists only when explicitly requested.
    self._impl = (
        Tensor(dtype, batch_dims + shape, device=device) if materialize else None
    )
@property
def impl(self) -> Optional[Tensor]:
    """Return the backing data as a MAX Tensor.

    A Tensor is returned as-is, a NumPy array is converted with
    ``Tensor.from_numpy`` on every access, and anything else (including a
    missing backing store) yields None.
    """
    backing = self._impl
    if isinstance(backing, Tensor):
        return backing
    if isinstance(backing, np.ndarray):
        return Tensor.from_numpy(backing)
    return None
[docs]
def impl_(self, value: Optional[Union[np.ndarray, Tensor]]) -> None:
    """Replace the backing data with ``value`` (NumPy array, Tensor or None).

    No validation or conversion is performed.
    """
    self._impl = value
@property
def size(self) -> int:
    """Number of elements described by ``shape`` (1 for an empty shape).

    Only ``shape`` is considered; ``batch_dims`` does not contribute.
    """
    element_count = 1
    # An empty shape skips the loop entirely and leaves the count at 1.
    for extent in self.shape:
        element_count *= extent
    return element_count
[docs]
@classmethod
def from_impl(cls, impl: Tensor, name: str = "") -> Array:
    """Wrap an existing MAX Tensor in an Array without copying it.

    Args:
        impl: Tensor that becomes the backing data of the new Array.
        name: Optional label for the new Array.

    Returns:
        A new Array whose shape, dtype and device are taken from ``impl``.

    Raises:
        TypeError: If ``impl`` is not a Tensor.
        ValueError: If ``impl.shape`` is None.
    """
    if not isinstance(impl, Tensor):
        raise TypeError(f"Data must be a MAX Tensor, got {type(impl)}")
    if impl.shape is None:
        raise ValueError("Cannot create Array from None shape Tensor")
    # Do not materialize: the constructor would allocate a fresh Tensor of
    # the same size only for it to be replaced by ``impl`` on the next line.
    instance = cls(
        shape=impl.shape, dtype=impl.dtype, device=impl.device, materialize=False
    )
    # ``impl`` is already known to be a Tensor, so it is stored directly; a
    # truthiness test here could only ever discard a valid tensor.
    instance._impl = impl
    instance.name = name
    return instance
[docs]
def copy_from(self, other: Array) -> None:
    """Replace this Array's backing data with a copy of ``other``'s.

    Args:
        other: Source Array; must have the same shape and dtype.

    Raises:
        ValueError: If the shapes or dtypes differ.
    """
    shape_differs = self.shape != other.shape
    if shape_differs or self.dtype != other.dtype:
        raise ValueError("Shape or dtype mismatch for copy")
    source = other._impl
    # A source without data leaves this Array without data as well.
    self._impl = None if source is None else source.copy()
[docs]
def add_arguments(self, *arg_nodes: Array) -> None:
    """Inherit tracing flags from the arguments and record them if needed.

    Each argument that is traced (or has ``stage_realization`` set) turns
    the same flag on for this Array. The arguments are appended to
    ``self.args`` only when this Array ends up traced or staged.

    Raises:
        TypeError: If any argument is not an Array. Flags inherited from
            earlier arguments remain set in that case.
    """
    for node in arg_nodes:
        if not isinstance(node, Array):
            raise TypeError(f"Argument must be an Array, got {type(node)}")
        if node.traced:
            self.traced = True
        if node.stage_realization:
            self.stage_realization = True

    if not (self.traced or self.stage_realization):
        return
    self.args.extend(arg_nodes)
[docs]
def realize(self) -> None:
    """Make sure this Array has backing data, computing it if necessary.

    Does nothing when data is already present; otherwise hands the Array
    to ``realize_`` from the sibling ``graph_execution`` module.

    Raises:
        ValueError: If no data is present after ``realize_`` returns.
    """
    if self._impl is None:
        from .graph_execution import realize_ as _run_realization

        _run_realization([self])
        if self._impl is None:
            raise ValueError("Data is None after realization")
[docs]
def to_numpy(self) -> np.ndarray:
    """Realize the Array and return its data as a NumPy array.

    A NumPy-backed Array returns its stored array object itself (no
    copy); a Tensor-backed Array returns ``Tensor.to_numpy()``.

    Raises:
        ValueError: If there is still no data after realization.
        TypeError: If the backing data is neither ndarray nor Tensor.
    """
    self.realize()  # data must exist before it can be converted
    backing = self._impl
    if backing is None:
        raise ValueError("Cannot get NumPy array from None impl")
    if isinstance(backing, np.ndarray):
        return backing
    if isinstance(backing, Tensor):
        return backing.to_numpy()
    raise TypeError(
        f"Cannot convert Array with impl type {type(self._impl)} to NumPy array"
    )
[docs]
@classmethod
def from_numpy(cls, np_array: np.ndarray) -> Array:
    """Build an Array on the default CPU device backed by ``np_array``.

    The NumPy array is stored as-is (no copy and no conversion to a
    Tensor). The Array's name is taken from a ``name`` attribute of the
    input when it has one, otherwise it is empty.

    Raises:
        TypeError: If ``np_array`` is not a ``numpy.ndarray``.
    """
    if not isinstance(np_array, np.ndarray):
        raise TypeError(f"Expected numpy.ndarray, got {type(np_array)}")

    label = getattr(np_array, "name", "")
    element_type = DType.from_numpy(np_array.dtype)
    wrapped = cls(
        shape=np_array.shape,
        dtype=element_type,
        device=_DEFAULT_CPU,
        name=label,
    )
    # NOTE: an earlier, now disabled, workaround stored scalar boolean
    # inputs as float32 because the MAX library's tensor.view(DType.bool)
    # failed for scalar tensors. Inputs are currently stored unchanged.
    wrapped._impl = np_array
    wrapped.device = _DEFAULT_CPU
    return wrapped
[docs]
def get_arguments(self) -> list[Array]:
    """Return a new list holding the recorded argument Arrays."""
    return [*self.args]
[docs]
def set_maxpr(self, fn: MaxprCallable) -> None:
    """Install ``fn`` as this Array's ``maxpr`` callable."""
    self.maxpr = fn
def __repr__(self) -> str:
    """Return the array's values (when available) followed by shape/dtype.

    The Array is not realized for printing: when no backing data exists a
    placeholder marked ``unrealized`` is produced instead.
    """
    from ..utils.formatting import format_shape_and_dtype

    # Read the property once. For a NumPy-backed Array every access to
    # ``self.impl`` builds a new Tensor, so reading it twice converted the
    # data twice.
    tensor = self.impl
    if tensor is not None:
        return str(tensor.to(CPU()).to_numpy()) + ":" + format_shape_and_dtype(self)
    return (
        f"Array(shape={self.shape}, dtype={self.dtype}, unrealized):"
        + format_shape_and_dtype(self)
    )
[docs]
def to(self, device: Device) -> Array:
    """Return this Array's data on ``device``.

    A Tensor-backed Array is transferred directly and wrapped in a new
    Array carrying the same name. An Array without backing data is routed
    through the ``transfer_to`` op instead.

    Args:
        device: Target device.

    Raises:
        TypeError: If ``device`` is not a Device, or if the backing data
            exists but is not a Tensor (for example a NumPy array).
    """
    if not isinstance(device, Device):
        raise TypeError(f"Expected Device, got {type(device)}")

    backing = self._impl
    # Compare against None explicitly: relying on the truth value of a
    # tensor could send a perfectly valid tensor down the unrealized path.
    if backing is None:
        from ..ops.unary import transfer_to

        return transfer_to(self, device)

    if not isinstance(backing, Tensor):
        raise TypeError(
            f"Cannot transfer Array with impl type {type(self._impl)} to device {device}"
        )
    # NOTE(review): from_impl derives the shape from the tensor alone, so
    # batch_dims of this Array are not carried over - confirm intended.
    return Array.from_impl(backing.to(device), name=self.name)
# Operator overloading methods
def __add__(self, other) -> Array:
    """Return ``self + other`` via the ``add`` op."""
    from ..ops.binary import add as _add

    return _add(self, other)
def __mul__(self, other) -> Array:
    """Return ``self * other`` via the ``mul`` op."""
    from ..ops.binary import mul as _mul

    return _mul(self, other)
def __sub__(self, other) -> Array:
    """Return ``self - other`` via the ``sub`` op."""
    from ..ops.binary import sub as _sub

    return _sub(self, other)
def __pow__(self, power) -> Array:
    """Return ``self ** power`` via the ``pow`` op."""
    from ..ops.binary import pow as _pow

    return _pow(self, power)
def __truediv__(self, other) -> Array:
    """Return ``self / other`` via the ``div`` op."""
    from ..ops.binary import div as _div

    return _div(self, other)
def __floordiv__(self, other) -> Array:
    """Return ``self // other`` via the ``floordiv`` op."""
    from ..ops.binary import floordiv as _floordiv

    return _floordiv(self, other)
def __matmul__(self, other) -> Array:
    """Return ``self @ other`` via the ``matmul`` op."""
    from ..ops.linalg import matmul as _matmul

    return _matmul(self, other)
def __neg__(self) -> Array:
    """Return ``-self`` via the ``negate`` op."""
    from ..ops.unary import negate as _negate

    return _negate(self)
def __mod__(self, other) -> Array:
    """Return ``self % other`` via the ``mod`` op."""
    from ..ops.binary import mod as _mod

    return _mod(self, other)
# Comparison operators
def __lt__(self, other) -> Array:
    """Return ``self < other``, computed as ``not (self >= other)``."""
    from ..ops.binary import greater_equal
    from ..ops.unary import logical_not

    not_less = greater_equal(self, other)
    return logical_not(not_less)
def __le__(self, other) -> Array:
    """Return ``self <= other``, computed as ``other >= self``."""
    from ..ops.binary import greater_equal as _ge

    return _ge(other, self)
def __gt__(self, other) -> Array:
    """Return ``self > other``, computed as ``not (other >= self)``."""
    from ..ops.binary import greater_equal
    from ..ops.unary import logical_not

    not_greater = greater_equal(other, self)
    return logical_not(not_greater)
def __ge__(self, other) -> Array:
    """Return ``self >= other`` via the ``greater_equal`` op."""
    from ..ops.binary import greater_equal as _ge

    return _ge(self, other)
# Hash and equality for making Arrays usable as dictionary keys
def __hash__(self) -> int:
    """Hash by object identity.

    Each Array object hashes to its own ``id``, which lets Arrays serve
    as dictionary keys (for example in optimizers). Two Arrays count as
    the same key only when they are the very same object.
    """
    return id(self)
# Reverse operators for when Array is on the right-hand side
def __radd__(self, other) -> Array:
    """Return ``other + self`` (Array on the right-hand side)."""
    from ..ops.binary import add as _add

    return _add(other, self)
def __rmul__(self, other) -> Array:
    """Return ``other * self`` (Array on the right-hand side)."""
    from ..ops.binary import mul as _mul

    return _mul(other, self)
def __rsub__(self, other) -> Array:
    """Return ``other - self`` (Array on the right-hand side)."""
    from ..ops.binary import sub as _sub

    return _sub(other, self)
def __rtruediv__(self, other) -> Array:
    """Return ``other / self`` (Array on the right-hand side)."""
    from ..ops.binary import div as _div

    return _div(other, self)
def __rpow__(self, other) -> Array:
    """Return ``other ** self`` (Array on the right-hand side)."""
    from ..ops.binary import pow as _pow

    return _pow(other, self)
def __getitem__(self, key) -> Array:
    """Array slicing using standard Python syntax.

    Supports both basic indexing (slices, integers, ellipsis) and advanced
    indexing (Array indices).

    Examples::

        arr[1:3]        # Slice first dimension
        arr[:, 2:5]     # Slice second dimension
        arr[1:3, 2:5]   # Slice multiple dimensions
        arr[-2:]        # Negative indices
        arr[..., :2]    # Ellipsis (all dimensions up to last)

        # Advanced indexing with Array indices:
        indices = nb.array([0, 2, 1])
        arr[indices]      # Gather elements along first axis
        arr[indices, :]   # Gather rows

    Args:
        key: An int, slice, ellipsis, Array, or a tuple of those.

    Returns:
        The selected sub-array. Dimensions indexed with an integer are
        removed from the result.

    Raises:
        TypeError: If ``key`` (or one of its elements) has an unsupported
            type.
        IndexError: If more indices are given than the array has
            dimensions, or an integer index is out of bounds.
    """
    # Advanced indexing: a lone Array index gathers along axis 0.
    if isinstance(key, Array):
        from ..ops.indexing import gather

        return gather(self, key, axis=0)
    if isinstance(key, tuple) and any(isinstance(k, Array) for k in key):
        return self._handle_mixed_advanced_indexing(key)

    # Basic indexing: normalise the key to a tuple.
    if isinstance(key, slice | int | type(...)):
        key = (key,)
    elif not isinstance(key, tuple):
        raise TypeError(
            f"Array indices must be integers, slices, ellipsis, Arrays, or tuples, got {type(key)}"
        )

    rank = len(self.shape)

    # Expand the first ellipsis into as many full slices as are needed to
    # reach the array's rank. Identity is used so that no element of the
    # key is ever compared with ``==``.
    ellipsis_idx = next((i for i, k in enumerate(key) if k is ...), None)
    if ellipsis_idx is not None:
        explicit = sum(1 for k in key if k is not ...)
        missing_dims = max(rank - explicit, 0)
        key = (
            key[:ellipsis_idx]
            + (slice(None),) * missing_dims
            + key[ellipsis_idx + 1 :]
        )

    # This single check also covers indexing into a scalar array.
    if len(key) > rank:
        raise IndexError(
            f"Too many indices for array: expected {rank}, got {len(key)}"
        )

    # Integers become length-1 slices whose axis is squeezed afterwards.
    slices = []
    squeeze_axes = []
    for axis, k in enumerate(key):
        if isinstance(k, int):
            extent = self.shape[axis]
            position = k + extent if k < 0 else k
            # Without this check an out-of-range index would silently turn
            # into a slice that selects the wrong element or nothing.
            if not 0 <= position < extent:
                raise IndexError(
                    f"Index {k} is out of bounds for axis {axis} with size {extent}"
                )
            slices.append(slice(position, position + 1))
            squeeze_axes.append(axis)
        elif isinstance(k, slice):
            slices.append(k)
        else:
            raise TypeError(
                f"Array index {axis} must be an integer or slice, got {type(k)}"
            )

    from ..ops.view import ArraySliceOp

    op = ArraySliceOp(slices, squeeze_axes)
    return op.forward(self)
[docs]
def astype(self, dtype: DType) -> Array:
    """Return this array converted to ``dtype``.

    Args:
        dtype: Target data type.

    Returns:
        ``self`` unchanged when it already has the requested dtype,
        otherwise the result of the ``cast`` op.
    """
    if self.dtype != dtype:
        from ..ops.unary import cast as _cast

        return _cast(self, dtype)
    # Already the requested type: no new array is created.
    return self
[docs]
def sum(self, axes=None, keep_dims=False) -> Array:
    """Sum array elements over the given axes.

    Args:
        axes: Axis or axes to reduce. May be an int, a list of ints, or
            None to sum over everything.
        keep_dims: If True, reduced axes are kept with size 1.

    Returns:
        Array holding the sums.

    Examples::

        arr.sum()             # Sum all elements
        arr.sum(axes=0)       # Sum along first axis
        arr.sum(axes=[0, 1])  # Sum along first two axes
    """
    from ..ops.reduce import sum as _reduce_sum

    return _reduce_sum(self, axes=axes, keep_dims=keep_dims)
[docs]
def reshape(self, shape: Shape) -> Array:
    """Return the array's data arranged in a new shape.

    Args:
        shape: Requested shape.

    Returns:
        Array with the new shape.

    Examples::

        arr.reshape((2, 3))   # Reshape to 2x3
        arr.reshape((-1,))    # Flatten to 1D (note: -1 not yet supported)
    """
    from ..ops.view import reshape as _reshape

    return _reshape(arg=self, shape=shape)
[docs]
def permute(self, axes: tuple[int, ...]) -> Array:
    """Reorder the dimensions of the array.

    Args:
        axes: New order of the dimensions.

    Returns:
        Array whose dimensions follow ``axes``.

    Examples::

        arr.permute([1, 0])  # shape (2, 3) becomes shape (3, 2)
    """
    from ..ops.view import permute as _permute

    return _permute(self, axes)
[docs]
def transpose(self, axes: tuple[int, ...]) -> Array:
    """Reorder the dimensions of the array (same operation as ``permute``).

    Args:
        axes: New order of the dimensions.

    Returns:
        Array whose dimensions follow ``axes``.

    Examples::

        arr.transpose([1, 0])  # shape (2, 3) becomes shape (3, 2)
    """
    from ..ops.view import permute as _permute

    return _permute(self, axes)
[docs]
def at(self, key, value) -> Array:
    """Return a new array equal to this one with ``value`` written at ``key``.

    The update is expressed as ``self + pad(value - self[key])``, so the
    original array is left untouched.

    Args:
        key: An int, slice, ellipsis, or a tuple of those.
        value: Replacement data. Non-Array inputs are converted through
            NumPy using this array's dtype; Arrays of another dtype are
            converted the same way. A value whose shape differs from the
            selected region is reshaped (same element count) or broadcast.

    Returns:
        New Array with the selected region replaced.

    Raises:
        TypeError: If ``key`` or one of its elements has an unsupported type.
        IndexError: If too many indices are given or an integer index is
            out of bounds.
        ValueError: If ``value`` cannot be reshaped or broadcast to the
            selected region.
    """
    from ..ops.binary import add, sub
    from ..ops.view import pad

    # Bring ``value`` to an Array of this array's dtype.
    if not isinstance(value, Array):
        value = Array.from_numpy(np.array(value, dtype=self.dtype.to_numpy()))
    elif value.dtype != self.dtype:
        value = Array.from_numpy(value.to_numpy().astype(self.dtype.to_numpy()))

    # Normalise the key to a tuple.
    if isinstance(key, slice | int | type(...)):
        key = (key,)
    elif not isinstance(key, tuple):
        raise TypeError(
            f"Array indices must be integers, slices, ellipsis, or tuples, got {type(key)}"
        )

    rank = len(self.shape)

    # Expand the first ellipsis into full slices (same rule as __getitem__).
    ellipsis_idx = next((i for i, k in enumerate(key) if k is ...), None)
    if ellipsis_idx is not None:
        explicit = sum(1 for k in key if k is not ...)
        missing_dims = max(rank - explicit, 0)
        key = (
            key[:ellipsis_idx]
            + (slice(None),) * missing_dims
            + key[ellipsis_idx + 1 :]
        )

    if len(key) > rank:
        raise IndexError(
            f"Too many indices for array: expected {rank}, got {len(key)}"
        )

    # Integers become length-1 slices; those axes are kept (not squeezed)
    # because the padding step below needs one slice per indexed axis.
    slices = []
    for axis, k in enumerate(key):
        if isinstance(k, int):
            extent = self.shape[axis]
            position = k + extent if k < 0 else k
            # Reject out-of-range indices instead of letting them turn
            # into a slice that addresses the wrong element.
            if not 0 <= position < extent:
                raise IndexError(
                    f"Index {k} is out of bounds for axis {axis} with size {extent}"
                )
            slices.append(slice(position, position + 1))
        elif isinstance(k, slice):
            slices.append(k)
        else:
            raise TypeError(f"Unsupported key type: {type(k)}")

    # 1. Region that is about to be replaced.
    sliced_part = self[tuple(slices)]

    # 2. Make ``value`` match that region's shape.
    if value.shape != sliced_part.shape:
        value_np = value.to_numpy()
        same_count = value_np.size == np.prod(sliced_part.shape)
        try:
            if same_count:
                fitted = value_np.reshape(sliced_part.shape)
            else:
                fitted = np.broadcast_to(value_np, sliced_part.shape)
        except ValueError as err:
            # Only shape incompatibilities are translated; anything else
            # (including KeyboardInterrupt) propagates unchanged.
            raise ValueError(
                f"Cannot broadcast value shape {value.shape} to sliced shape {sliced_part.shape}"
            ) from err
        value = Array.from_numpy(fitted)

    # 3. Difference between new and old contents of the region.
    diff = sub(value, sliced_part)
    # 4. Embed the difference in an array of the full shape.
    padded_diff = pad(diff, slices, self.shape)
    # 5. Adding it to the original yields the updated array.
    return add(self, padded_diff)
# Identity-based equality (not element-wise); see nb.equal / nb.not_equal
def __eq__(self, other) -> bool:
    """Identity comparison, so that equality agrees with hashing.

    True only when ``other`` is an Array and is this very object.
    Element-wise comparison is available explicitly as nb.equal(a, b).
    """
    if not isinstance(other, Array):
        return False
    return other is self
def __ne__(self, other) -> bool:
    """Identity inequality: the negation of ``__eq__``.

    True whenever ``other`` is not this very object. Element-wise
    comparison is available explicitly as nb.not_equal(a, b).
    """
    same_object = self.__eq__(other)
    return not same_object
[docs]
def set(self, key, value) -> Array:
    """Return a new array with ``value`` written at ``key``.

    Functional update: the original Array is left unchanged. This simply
    forwards to ``at``.

    Args:
        key: Index specification (int, slice, tuple of indices/slices,
            ellipsis).
        value: Value(s) to place at the specified location.

    Returns:
        New Array with updated values.

    Examples:
        new_arr = arr.set(1, 99.0)            # Set single element
        new_arr = arr.set((1, 2), 99.0)       # Set element at (1,2)
        new_arr = arr.set(slice(1, 3), 99.0)  # Set slice
        new_arr = arr.set(..., 99.0)          # Set with ellipsis
    """
    updated = self.at(key, value)
    return updated
def _handle_mixed_advanced_indexing(self, key: tuple) -> Array:
    """Index with a tuple that mixes one Array index with slices/integers.

    Only the common case is implemented: a single Array index in the
    first position, optionally followed by slices or integers.

    Args:
        key: Tuple containing a mix of Array indices, slices and integers.

    Returns:
        Array result of the indexing.

    Raises:
        NotImplementedError: If the tuple holds more than one Array index,
            or the Array index is not in the first position.
        ValueError: If the tuple holds no Array index at all.
    """
    from ..ops.indexing import gather

    array_positions = [i for i, k in enumerate(key) if isinstance(k, Array)]
    if len(array_positions) > 1:
        raise NotImplementedError(
            "Multiple Array indices not yet supported. "
            "Use gather/scatter operations directly for complex indexing."
        )
    if not array_positions:
        # Callers are expected to route only Array-containing keys here.
        raise ValueError("Expected Array index in mixed indexing")

    array_index_pos = array_positions[0]
    if array_index_pos != 0:
        raise NotImplementedError(
            f"Array index at position {array_index_pos} not yet supported. "
            "Use gather operation directly or put Array index first."
        )

    # arr[indices, rest...]: gather along axis 0 first.
    gathered = gather(self, key[0], axis=0)
    remaining_key = key[1:]
    if not remaining_key:
        return gathered
    # The remaining entries address the dimensions after the gathered one,
    # so a full slice is put in front to cover the gathered dimension.
    return gathered[(slice(None),) + remaining_key]
def __setitem__(self, key, value) -> None:
    """Array assignment using standard Python syntax.

    Only assignment through Array indices is implemented; plain slices
    and integers raise ``NotImplementedError``.

    Examples::

        # Advanced indexing with Array indices:
        indices = nb.array([0, 2, 1])
        arr[indices] = value  # Scatter values to specified indices

    Raises:
        NotImplementedError: For basic (slice/integer) assignment.
    """
    # Non-Array values are first turned into Arrays.
    if not isinstance(value, Array):
        from ..ops.creation import array as _to_array

        value = _to_array(value)

    if isinstance(key, Array):
        # A lone Array index scatters along axis 0.
        self._setitem_with_array_index(key, value, axis=0)
        return

    has_array_index = isinstance(key, tuple) and any(
        isinstance(k, Array) for k in key
    )
    if has_array_index:
        self._setitem_mixed_advanced_indexing(key, value)
        return

    raise NotImplementedError(
        "Basic slice assignment not yet implemented. "
        "Use Array indices for scatter operations."
    )
def _setitem_with_array_index(
    self, indices: Array, values: Array, axis: int = 0
) -> None:
    """Helper for ``__setitem__`` with Array indices.

    Builds a new array with ``scatter`` and then adopts its backing data,
    which simulates an in-place update.

    Args:
        indices: Array of indices where to place values.
        values: Array of values to place.
        axis: Axis along which to scatter.

    Raises:
        ValueError: If the scattered array has no data after realization.
    """
    from ..ops.indexing import scatter

    # NOTE(review): scatter receives only this array's shape, not its
    # data - whether the existing elements survive the assignment depends
    # on scatter's semantics; confirm.
    new_array = scatter(
        target_shape=self.shape, indices=indices, values=values, axis=axis
    )
    # The result may not have been computed yet. Adopting its ``_impl``
    # without realizing first would store None and silently discard this
    # array's data.
    new_array.realize()
    self._impl = new_array._impl
def _setitem_mixed_advanced_indexing(self, key: tuple, value: Array) -> None:
    """Helper for assignment through a tuple key containing an Array index.

    Only ``arr[(indices,)] = value`` is implemented: exactly one Array
    index, in first position, with nothing following it.

    Args:
        key: Tuple containing a mix of Array indices, slices and integers.
        value: Array to assign.

    Raises:
        NotImplementedError: For more than one Array index, an Array index
            that is not first, or additional entries after the Array index.
    """
    array_positions = [i for i, k in enumerate(key) if isinstance(k, Array)]
    if len(array_positions) > 1:
        raise NotImplementedError("Multiple Array indices not yet supported")

    first_position = array_positions[0] if array_positions else None
    if first_position != 0:
        raise NotImplementedError(
            "Array index must be in first position for assignment"
        )

    leading_index, trailing = key[0], key[1:]
    if trailing:
        # Partial assignment such as arr[indices, :, 1:3] = value.
        raise NotImplementedError(
            "Mixed Array index with slices in assignment not yet supported"
        )
    # Simple case: equivalent to arr[indices] = value.
    self._setitem_with_array_index(leading_index, value, axis=0)