# # ===----------------------------------------------------------------------=== #
# # Nabla 2025
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
# # ===----------------------------------------------------------------------=== #
# """View and shape manipulation operations."""
# import numpy as np
# from max.driver import Tensor
# from max.dtype import DType
# from max.graph import TensorValue, ops
# from ..core.array import Array, Shape
# from .operation import Operation, ViewOperation
# # Public API
# __all__ = [
# "transpose",
# "permute",
# "move_axis_to_front",
# "move_axis_from_front",
# "permute_batch_dims",
# "move_axis_to_front_of_batch_dims",
# "move_axis_from_front_of_batch_dims",
# "reshape",
# "broadcast_to",
# "broadcast_batch_dims",
# "squeeze",
# "unsqueeze",
# "squeeze_batch_dims",
# "unsqueeze_batch_dims",
# "shallow_copy",
# "array_slice",
# "pad",
# "concatenate",
# "stack",
# # "scatter",
# # "gather",
# ]
# class TransposeOp(ViewOperation):
# """Matrix/tensor transpose operation."""
# def __init__(self, axis_1: int = -2, axis_2: int = -1):
# super().__init__(f"transpose[permutation=({axis_1},{axis_2})]")
# self.axis_1 = axis_1
# self.axis_2 = axis_2
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compute output shape for transpose operation with compatible signature."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Transpose operation requires 1 input shape, got {len(input_shapes)}"
# )
# arg_shape = input_shapes[0]
# if not arg_shape:
# raise ValueError("Cannot transpose an empty shape")
# axis_1 = self.axis_1 if self.axis_1 >= 0 else len(arg_shape) + self.axis_1
# axis_2 = self.axis_2 if self.axis_2 >= 0 else len(arg_shape) + self.axis_2
# if axis_1 < 0 or axis_1 >= len(arg_shape):
# raise ValueError(f"axis_1 {axis_1} is out of bounds for shape {arg_shape}")
# if axis_2 < 0 or axis_2 >= len(arg_shape):
# raise ValueError(f"axis_2 {axis_2} is out of bounds for shape {arg_shape}")
# new_shape = list(arg_shape)
# new_shape[axis_1], new_shape[axis_2] = new_shape[axis_2], new_shape[axis_1]
# return tuple(new_shape)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# output.tensor_value = ops.transpose(args[0], self.axis_1, self.axis_2)
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# offset = len(args[0].batch_dims)
# axes = list(range(-offset - len(args[0].shape), 0))
# axes[self.axis_1], axes[self.axis_2] = axes[self.axis_2], axes[self.axis_1]
# np_result = np.transpose(args[0].to_numpy(), axes)
# output.impl_(np_result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# return [transpose(cotangent, self.axis_1, self.axis_2)]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# return transpose(tangents[0], self.axis_1, self.axis_2)
# def transpose(arg: Array, axis_1: int = -2, axis_2: int = -1) -> Array:
# """Transpose array along two axes."""
# axis_1 = axis_1 if axis_1 < 0 else -len(arg.shape) + axis_1
# axis_2 = axis_2 if axis_2 < 0 else -len(arg.shape) + axis_2
# if axis_1 == axis_2 or len(arg.shape) <= 1:
# return arg
# if axis_1 < -len(arg.shape) or axis_2 < -len(arg.shape):
# raise ValueError(
# f"Invalid axes {axis_1}, {axis_2} for shape {arg.shape}. "
# "Axes must be within the range of the array dimensions."
# )
# op = TransposeOp(axis_1, axis_2)
# return op.forward(arg)
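# # Illustrative sketch (mirrors the docstring examples above; assumes eager arrays built with nb.ones):
# # >>> x = nb.ones((2, 3, 4))
# # >>> transpose(x).shape # default swaps the last two axes -> (2, 4, 3)
# # >>> transpose(x, 0, 2).shape # swap first and last axes -> (4, 3, 2)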
# class TransposeBatchDimsOp(ViewOperation):
# """Transpose operation to swap two batch dimensions."""
# def __init__(self, axis_1: int = -2, axis_2: int = -1):
# """Initialize transpose batch dims operation.
# Args:
# axis_1: First batch dimension axis to swap (negative indices preferred)
# axis_2: Second batch dimension axis to swap (negative indices preferred)
# """
# super().__init__(f"transpose_batch_dims[permutation=({axis_1},{axis_2})]")
# self.axis_1 = axis_1
# self.axis_2 = axis_2
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Shape stays the same for batch dimension operations."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Transpose batch dims operation requires 1 input shape, got {len(input_shapes)}"
# )
# return input_shapes[0]
# def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
# """Compute output batch_dims after transposing two axes."""
# if len(input_batch_dimss) != 1:
# raise ValueError(
# f"Transpose batch dims operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
# )
# input_batch_dims = input_batch_dimss[0]
# if not input_batch_dims:
# raise ValueError(
# "Cannot transpose batch dims of an array with no batch dimensions"
# )
# # Convert negative indices to positive for validation and computation
# axis_1 = self.axis_1 + len(input_batch_dims) if self.axis_1 < 0 else self.axis_1
# axis_2 = self.axis_2 + len(input_batch_dims) if self.axis_2 < 0 else self.axis_2
# # Validate axes are within bounds
# if axis_1 < 0 or axis_1 >= len(input_batch_dims):
# raise ValueError(
# f"axis_1 {self.axis_1} is out of bounds for batch_dims {input_batch_dims}"
# )
# if axis_2 < 0 or axis_2 >= len(input_batch_dims):
# raise ValueError(
# f"axis_2 {self.axis_2} is out of bounds for batch_dims {input_batch_dims}"
# )
# # Create new batch_dims with axes swapped
# new_batch_dims = list(input_batch_dims)
# new_batch_dims[axis_1], new_batch_dims[axis_2] = (
# new_batch_dims[axis_2],
# new_batch_dims[axis_1],
# )
# return tuple(new_batch_dims)
# def forward(self, *args: Array) -> Array:
# """Override forward to handle single input."""
# if len(args) != 1:
# raise ValueError(
# f"Transpose batch dims operation requires 1 argument, got {len(args)}"
# )
# return super().forward(*args)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# """MAX graph implementation using ops.transpose."""
# axis_1 = self.axis_1 - len(output.shape)
# axis_2 = self.axis_2 - len(output.shape)
# output.tensor_value = ops.transpose(args[0], axis_1, axis_2)
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# """Eager execution using NumPy transpose."""
# input_array = args[0]
# # Get the full tensor including batch dimensions
# input_np = input_array.to_numpy()
# axis_1 = self.axis_1 - len(args[0].shape)
# axis_2 = self.axis_2 - len(args[0].shape)
# # Create axes list for full transpose
# total_dims = len(input_array.batch_dims) + len(input_array.shape)
# axes = list(range(total_dims))
# # Swap the two batch dimension axes
# axes[axis_1], axes[axis_2] = axes[axis_2], axes[axis_1]
# # Apply transpose
# np_result = np.transpose(input_np, axes)
# output.impl_(np_result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# """VJP rule: transpose is its own inverse."""
# return [transpose_batch_dims(cotangent, self.axis_1, self.axis_2)]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# """JVP rule: apply same transpose to tangents."""
# return transpose_batch_dims(tangents[0], self.axis_1, self.axis_2)
# def transpose_batch_dims(arg: Array, axis_1: int = -2, axis_2: int = -1) -> Array:
# """Transpose batch dimensions along two axes.
# This operation swaps two axes in the batch_dims of an Array, similar to how
# regular transpose works on shape dimensions. The shape dimensions remain unchanged.
# Args:
# arg: Input array with batch dimensions to transpose
# axis_1: First batch dimension axis to swap (default: -2)
# axis_2: Second batch dimension axis to swap (default: -1)
# Returns:
# Array with specified batch dimensions transposed
# Example:
# >>> import nabla as nb
# >>> # Array with batch_dims=(2, 3, 4) and shape=(5, 6)
# >>> x = nb.ones((5, 6))
# >>> x.batch_dims = (2, 3, 4) # Simulated for example
# >>> y = transpose_batch_dims(x, -3, -1) # Swap first and last batch dims
# >>> # Result has batch_dims=(4, 3, 2) and shape=(5, 6)
# """
# # Convert to negative indices for consistency with batch dimension handling
# axis_1 = axis_1 if axis_1 < 0 else -len(arg.batch_dims) + axis_1
# axis_2 = axis_2 if axis_2 < 0 else -len(arg.batch_dims) + axis_2
# op = TransposeBatchDimsOp(axis_1, axis_2)
# return op.forward(arg)
# def compute_iterative_transpose_swaps(axes: tuple[int, ...]) -> list[tuple[int, int]]:
# """Compute the sequence of axis swaps needed to implement a permutation.
# This function implements the algorithm from the Mojo code to determine what
# sequence of axis swaps (transposes) are needed to achieve a given permutation.
# Returns a list of (axis1, axis2) tuples that should be swapped in order.
# Args:
# axes: Tuple of axis indices specifying the desired permutation.
# All indices should be negative (e.g., -1, -2, -3, ...)
# Returns:
# List of (axis1, axis2) tuples representing the swaps to perform in order.
# Each tuple contains two negative axis indices to be swapped.
# The algorithm works as follows:
# 1. Initialize current_axis_order as [-num_dims, -num_dims+1, ..., -1]
# 2. For each target position x, find where target_axis currently is (y)
# 3. If x != y, record the swap (x_neg, y_neg) and update current_axis_order
# 4. Return the list of all recorded swaps
# """
# target_perm = list(axes)
# num_dims = len(target_perm)
# # Initialize current_axis_order as in Mojo: [-num_dims, -num_dims+1, ..., -1]
# current_axis_order = []
# for i in range(-num_dims, 0):
# current_axis_order.append(i)
# swaps = []
# # For each target position x, move the correct axis there
# for x in range(num_dims):
# target_axis = target_perm[x]
# # Find where target_axis currently is in the current ordering
# try:
# y = current_axis_order.index(target_axis)
# except ValueError as e:
# # target_axis not found in current_axis_order, this shouldn't happen
# raise ValueError(
# f"Target axis {target_axis} not found in current_axis_order {current_axis_order}"
# ) from e
# # If already in the right position, skip
# if x == y:
# continue
# # Convert to negative indices for the swap operation
# x_neg = x - num_dims
# y_neg = y - num_dims
# # Record the swap
# swaps.append((x_neg, y_neg))
# # Update current_axis_order to reflect the swap
# # Swap the elements at positions x and y
# current_axis_order[x], current_axis_order[y] = (
# current_axis_order[y],
# current_axis_order[x],
# )
# return swaps
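# # Worked example (sketch) of the algorithm above: for axes=(-1, -3, -2) the
# # current order starts as [-3, -2, -1]; position 0 needs -1, which sits at
# # position 2, so (-3, -1) is recorded and the order becomes [-1, -2, -3];
# # position 1 needs -3, which sits at position 2, so (-2, -1) is recorded and
# # the order becomes [-1, -3, -2]; position 2 already holds -2, so the result
# # is [(-3, -1), (-2, -1)].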
# class PermuteOp(ViewOperation):
# """Permute (reorder) the dimensions of a tensor according to given axes."""
# def __init__(self, axes: tuple[int, ...]):
# """Initialize permute operation.
# Args:
# axes: Tuple of axis indices specifying the new order.
# Must be a permutation of range(ndim).
# """
# super().__init__(f"permute[axes={axes}]")
# self.axes = tuple(axes)
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compute output shape after permutation."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Permute operation requires 1 input shape, got {len(input_shapes)}"
# )
# input_shape = input_shapes[0]
# # Validate axes - should now be all negative and same length as input
# if len(self.axes) != len(input_shape):
# raise ValueError(
# f"Number of axes {len(self.axes)} must match input dimensions {len(input_shape)}"
# )
# # Convert to positive indices for validation
# positive_axes = [ax + len(input_shape) for ax in self.axes]
# if sorted(positive_axes) != list(range(len(input_shape))):
# raise ValueError(
# f"Axes {self.axes} must be a permutation of negative indices corresponding to {list(range(len(input_shape)))}"
# )
# # Reorder dimensions according to axes (convert negative to positive for indexing)
# return tuple(input_shape[axis + len(input_shape)] for axis in self.axes)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# """Max computation: permute the tensor using iterative transpose."""
# # Get the sequence of swaps needed for this permutation
# swaps = compute_iterative_transpose_swaps(self.axes)
# # Apply each swap in sequence
# out_symbol = args[0]
# for axis1, axis2 in swaps:
# out_symbol = ops.transpose(out_symbol, axis1, axis2)
# output.tensor_value = out_symbol
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# """Eager computation: permute using numpy."""
# # Handle batch dimensions properly like transpose does
# offset = len(args[0].batch_dims)
# # Convert our negative axes (relative to array shape) to work with full numpy array
# numpy_axes = []
# for ax in self.axes:
# # ax is negative relative to args[0].shape, convert to positive
# array_pos_ax = ax + len(args[0].shape)
# # Now convert to position in full numpy array (including batch dims)
# numpy_pos_ax = offset + array_pos_ax
# numpy_axes.append(numpy_pos_ax)
# # Prepend batch dimension indices (they stay in their original positions)
# full_axes = list(range(offset)) + numpy_axes
# np_result = np.transpose(args[0].to_numpy(), full_axes)
# output.impl_(np_result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# """VJP rule: reverse the permutation."""
# # Create inverse permutation for negative indices
# inv_axes = [0] * len(self.axes)
# for i, axis in enumerate(self.axes):
# # Convert negative axis to positive index for inverse mapping
# pos_axis = axis + len(self.axes)
# inv_axes[pos_axis] = i
# # Convert back to negative indices
# inv_axes_negative = [-len(self.axes) + ax for ax in inv_axes]
# return [permute(cotangent, tuple(inv_axes_negative))]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# """JVP rule: apply same permutation to tangent."""
# return permute(tangents[0], self.axes)
# def permute(input_array: Array, axes: tuple[int, ...]) -> Array:
# """Permute (reorder) the dimensions of a tensor.
# Args:
# input_array: Input tensor
# axes: Tuple specifying the new order of dimensions
# Returns:
# Tensor with reordered dimensions
# Example:
# >>> x = nb.ones((2, 3, 4)) # shape (2, 3, 4)
# >>> y = permute(x, (2, 0, 1)) # shape (4, 2, 3)
# >>> # Dimension 2 -> position 0, dimension 0 -> position 1, dimension 1 -> position 2
# """
# # always store axes to be fully negative
# axes = tuple(-len(input_array.shape) + ax if ax >= 0 else ax for ax in axes)
# # but first we add potentially missing axes, which we treat as unpermuted
# axes_new = []
# for i in range(-len(input_array.shape), -len(axes)):
# axes_new.append(i)
# axes = tuple(axes_new + list(axes)) # prepend missing axes to the front
# op = PermuteOp(axes)
# return op.forward(input_array)
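# # Illustrative sketch of the partial-axes behavior above: missing leading axes
# # are prepended unchanged.
# # >>> x = nb.ones((2, 3, 4))
# # >>> permute(x, (-1, -2)).shape # treated as axes (-3, -1, -2) -> (2, 4, 3)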
# class PermuteBatchDimsOp(ViewOperation):
# """Permute (reorder) the batch dimensions of an array according to given axes."""
# def __init__(self, axes: tuple[int, ...]):
# """Initialize permute batch dims operation.
# Args:
# axes: Tuple of axis indices specifying the new order for batch_dims.
# Must be a permutation of range(-len(batch_dims), 0).
# All indices should be negative.
# """
# super().__init__(f"permute_batch_dims[axes={axes}]")
# self.axes = tuple(axes)
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Shape stays the same for batch dimension operations."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Permute batch dims operation requires 1 input shape, got {len(input_shapes)}"
# )
# return input_shapes[0]
# def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
# """Compute output batch_dims after permutation."""
# if len(input_batch_dimss) != 1:
# raise ValueError(
# f"Permute batch dims operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
# )
# input_batch_dims = input_batch_dimss[0]
# if not input_batch_dims:
# raise ValueError(
# "Cannot permute batch dims of an array with no batch dimensions"
# )
# # Validate axes - should be all negative and same length as input batch_dims
# if len(self.axes) != len(input_batch_dims):
# raise ValueError(
# f"Number of axes {len(self.axes)} must match input batch dimensions {len(input_batch_dims)}"
# )
# # Convert to positive indices for validation
# positive_axes = [ax + len(input_batch_dims) for ax in self.axes]
# if sorted(positive_axes) != list(range(len(input_batch_dims))):
# raise ValueError(
# f"Axes {self.axes} must be a permutation of negative indices corresponding to batch_dims range"
# )
# # Reorder batch dimensions according to axes (convert negative to positive for indexing)
# return tuple(
# input_batch_dims[axis + len(input_batch_dims)] for axis in self.axes
# )
# def forward(self, *args: Array) -> Array:
# """Override forward to handle single input."""
# if len(args) != 1:
# raise ValueError(
# f"Permute batch dims operation requires 1 argument, got {len(args)}"
# )
# return super().forward(*args)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# """MAX graph implementation using ops.transpose."""
# # Get the sequence of swaps needed for this permutation
# swaps = compute_iterative_transpose_swaps(self.axes)
# swaps = [
# (axis1 - len(output.shape), axis2 - len(output.shape))
# for axis1, axis2 in swaps
# ]
# # Apply each swap in sequence
# out_symbol = args[0]
# for axis1, axis2 in swaps:
# out_symbol = ops.transpose(out_symbol, axis1, axis2)
# output.tensor_value = out_symbol
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# """Eager execution using NumPy transpose."""
# input_array = args[0]
# # Get the full tensor including batch dimensions
# input_np = input_array.to_numpy()
# # Convert batch dimension axes to full tensor indices
# # Following the pattern from other batch operations
# numpy_axes = []
# for ax in self.axes:
# # ax is negative relative to batch_dims, convert to full tensor position
# batch_pos_ax = ax - len(input_array.shape)
# numpy_axes.append(batch_pos_ax)
# # Add shape dimension indices (they stay in their original relative positions)
# # They come after the batch dimensions in the permuted tensor
# shape_offset = len(input_array.batch_dims)
# shape_axes = list(range(shape_offset, shape_offset + len(input_array.shape)))
# full_axes = numpy_axes + shape_axes
# # Apply transpose
# np_result = np.transpose(input_np, full_axes)
# output.impl_(np_result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# # """VJP rule: reverse the permutation."""
# # Create inverse permutation for negative indices
# inv_axes = [0] * len(self.axes)
# for i, axis in enumerate(self.axes):
# # Convert negative axis to positive index for inverse mapping
# pos_axis = axis + len(self.axes)
# inv_axes[pos_axis] = i
# # Convert back to negative indices
# inv_axes_negative = [-len(self.axes) + ax for ax in inv_axes]
# return [permute_batch_dims(cotangent, tuple(inv_axes_negative))]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# """JVP rule: apply same permutation to tangent."""
# return permute_batch_dims(tangents[0], self.axes)
# def permute_batch_dims(input_array: Array, axes: tuple[int, ...]) -> Array:
# """Permute (reorder) the batch dimensions of an array.
# This operation reorders the batch_dims of an Array according to the given axes,
# similar to how regular permute works on shape dimensions. The shape dimensions
# remain unchanged.
# Args:
# input_array: Input array with batch dimensions to permute
# axes: Tuple specifying the new order of batch dimensions.
# All indices should be negative and form a permutation.
# Returns:
# Array with batch dimensions reordered according to axes
# Example:
# >>> import nabla as nb
# >>> # Array with batch_dims=(2, 3, 4) and shape=(5, 6)
# >>> x = nb.ones((5, 6))
# >>> x.batch_dims = (2, 3, 4) # Simulated for example
# >>> y = permute_batch_dims(x, (-1, -3, -2)) # Reorder as (4, 2, 3)
# >>> # Result has batch_dims=(4, 2, 3) and shape=(5, 6)
# """
# if len(axes) <= 1:
# return input_array # No permutation needed for single axis or empty
# # Convert to negative indices for consistency with batch dimension handling
# axes = tuple(-len(input_array.batch_dims) + ax if ax >= 0 else ax for ax in axes)
# # Handle case where fewer axes are provided - prepend missing axes to front
# if len(axes) < len(input_array.batch_dims):
# axes_new = []
# for i in range(-len(input_array.batch_dims), -len(axes)):
# axes_new.append(i)
# axes = tuple(axes_new) + axes
# op = PermuteBatchDimsOp(axes)
# return op.forward(input_array)
# def move_axis_to_front(input_array: Array, axis: int) -> Array:
# """Move specified axis to the front (position 0), shifting others right.
# Args:
# input_array: Input tensor
# axis: Axis to move to front
# Returns:
# Tensor with specified axis moved to front
# Example:
# >>> x = nb.ones((2, 3, 4)) # shape (2, 3, 4)
# >>> y = move_axis_to_front(x, 2) # shape (4, 2, 3)
# >>> # axis 2 moved to front, others shifted: [2, 0, 1]
# """
# ndim = len(input_array.shape)
# # Normalize negative axis
# if axis < 0:
# axis = ndim + axis
# if axis < 0 or axis >= ndim:
# raise ValueError(f"Axis {axis} out of bounds for array of dimension {ndim}")
# # Generate permutation: [axis, 0, 1, ..., axis-1, axis+1, ..., ndim-1]
# axes = [axis] + [i for i in range(ndim) if i != axis]
# return permute(input_array, tuple(axes))
# def move_axis_to_back(input_array: Array, axis: int) -> Array:
# """Move specified axis to the back (last position), shifting others left.
# Args:
# input_array: Input tensor
# axis: Axis to move to back
# Returns:
# Tensor with specified axis moved to back
# Example:
# >>> x = nb.ones((2, 3, 4)) # shape (2, 3, 4)
# >>> y = move_axis_to_back(x, 0) # shape (3, 4, 2)
# >>> # axis 0 moved to back, others shifted: [1, 2, 0]
# """
# ndim = len(input_array.shape)
# # Normalize negative axis
# if axis < 0:
# axis = ndim + axis
# if axis < 0 or axis >= ndim:
# raise ValueError(f"Axis {axis} out of bounds for array of dimension {ndim}")
# # Generate permutation: [0, 1, ..., axis-1, axis+1, ..., ndim-1, axis]
# axes = [i for i in range(ndim) if i != axis] + [axis]
# return permute(input_array, tuple(axes))
# def move_axis_from_front(input_array: Array, target_axis: int) -> Array:
# """Move front axis (position 0) to specified target position.
# Args:
# input_array: Input tensor (assumes front axis is the one to move)
# target_axis: Target position for the front axis
# Returns:
# Tensor with front axis moved to target position
# Example:
# >>> x = nb.ones((4, 2, 3)) # front axis has size 4
# >>> y = move_axis_from_front(x, 2) # shape (2, 3, 4)
# >>> # front axis moved to position 2: [1, 2, 0]
# """
# ndim = len(input_array.shape)
# # Normalize negative axis
# if target_axis < 0:
# target_axis = ndim + target_axis
# if target_axis < 0 or target_axis >= ndim:
# raise ValueError(
# f"Target axis {target_axis} out of bounds for array of dimension {ndim}"
# )
# if target_axis == 0:
# return input_array # Already at front
# # Generate permutation to move front to target_axis
# # [1, 2, ..., target_axis, 0, target_axis+1, ..., ndim-1]
# axes = list(range(1, target_axis + 1)) + [0] + list(range(target_axis + 1, ndim))
# return permute(input_array, tuple(axes))
# def move_axis_from_back(input_array: Array, target_axis: int) -> Array:
# """Move back axis (last position) to specified target position.
# Args:
# input_array: Input tensor (assumes back axis is the one to move)
# target_axis: Target position for the back axis
# Returns:
# Tensor with back axis moved to target position
# Example:
# >>> x = nb.ones((4, 2, 3)) # back axis has size 3
# >>> y = move_axis_from_back(x, 1) # shape (4, 3, 2)
# >>> # back axis moved to position 1: [0, 2, 1]
# """
# ndim = len(input_array.shape)
# # Normalize negative axis
# if target_axis < 0:
# target_axis = ndim + target_axis
# if target_axis < 0 or target_axis >= ndim:
# raise ValueError(
# f"Target axis {target_axis} out of bounds for array of dimension {ndim}"
# )
# if target_axis == ndim - 1:
# return input_array # Already at back
# # Generate permutation to move back to target_axis
# axes = list(range(0, target_axis)) + [ndim - 1] + list(range(target_axis, ndim - 1))
# return permute(input_array, tuple(axes))
# def move_axis_to_front_of_batch_dims(input_array: Array, axis: int) -> Array:
# """Move specified batch dimension to the front (position 0), shifting others right.
# Args:
# input_array: Input tensor with batch dimensions
# axis: Batch dimension to move to front (negative index)
# Returns:
# Tensor with specified batch dimension moved to front
# Example:
# >>> x = nb.ones((2, 3, 4)) # shape (2, 3, 4)
# >>> x.batch_dims = (2, 3) # Simulated for example
# >>> y = move_axis_to_front_of_batch_dims(x, -1) # Move last batch dim to front
# >>> # Result has batch_dims=(3, 2) and shape=(2, 3, 4)
# """
# ndim = len(input_array.batch_dims)
# # Normalize negative axis
# if axis >= 0:
# axis = -len(input_array.batch_dims) + axis
# if axis < -len(input_array.batch_dims) or axis >= 0:
# raise ValueError(
# f"Axis {axis} out of bounds for batch_dims of dimension {ndim}"
# )
# # Generate permutation: [axis, 0, 1, ..., axis-1, axis+1, ..., ndim-1]
# axes = [axis] + [i for i in range(-len(input_array.batch_dims), 0) if i != axis]
# return permute_batch_dims(input_array, tuple(axes))
# def move_axis_from_front_of_batch_dims(input_array: Array, target_axis: int) -> Array:
# """Move front batch dimension (position 0) to specified target position.
# Args:
# input_array: Input tensor with batch dimensions (assumes front batch dim is the one to move)
# target_axis: Target position for the front batch dimension (negative index)
# Returns:
# Tensor with front batch dimension moved to target position
# Example:
# >>> x = nb.ones((4, 2, 3)) # shape (4, 2, 3)
# >>> x.batch_dims = (2, 3) # Simulated for example
# >>> y = move_axis_from_front_of_batch_dims(x, -1) # Move front batch dim to last position
# >>> # Result has batch_dims=(3, 2) and shape=(4, 2, 3)
# """
# ndim = len(input_array.batch_dims)
# # Normalize negative axis
# if target_axis >= 0:
# target_axis = -len(input_array.batch_dims) + target_axis
# if target_axis < -len(input_array.batch_dims) or target_axis >= 0:
# raise ValueError(
# f"Target axis {target_axis} out of bounds for batch_dims of dimension {ndim}"
# )
# if target_axis == -len(input_array.batch_dims):
# return input_array # Already at front
# # Generate permutation to move front to target_axis
# axes = (
# list(range(-len(input_array.batch_dims) + 1, target_axis + 1))
# + [-len(input_array.batch_dims)]
# + list(range(target_axis + 1, 0))
# )
# return permute_batch_dims(input_array, tuple(axes))
# class ReshapeOp(ViewOperation):
# """Reshape operation."""
# def __init__(self, arg_shape: Shape, target_shape: Shape):
# super().__init__(f"reshape[new_sizes={target_shape}]")
# self.arg_shape = arg_shape
# self.target_shape = target_shape
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compatible signature."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Reshape operation requires 1 input shape, got {len(input_shapes)}"
# )
# return self.target_shape
# def forward(self, *args: Array) -> Array:
# """Override forward to validate size compatibility with compatible signature."""
# if len(args) != 1:
# raise ValueError(f"Reshape operation requires 1 argument, got {len(args)}")
# arg = args[0]
# old_size = np.prod(arg.shape) if arg.shape else 1
# new_size = np.prod(self.target_shape) if self.target_shape else 1
# if old_size != new_size:
# raise ValueError(
# f"Cannot reshape array of size {old_size} to shape {self.target_shape} of size {new_size}"
# )
# return super().forward(arg)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# output.tensor_value = ops.reshape(
# args[0], output.batch_dims + self.target_shape
# )
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# np_result = np.reshape(
# args[0].to_numpy(), output.batch_dims + self.target_shape
# )
# output.impl_(np_result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# return [reshape(cotangent, self.arg_shape)]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# return reshape(tangents[0], self.target_shape)
# def reshape(arg: Array, shape: Shape) -> Array:
# """Reshape array to given shape."""
# # Handle -1 dimension inference
# if -1 in shape:
# # Compute the inferred dimension
# total_size = np.prod(arg.shape) if arg.shape else 1
# known_size = 1
# unknown_idx = -1
# for i, dim in enumerate(shape):
# if dim == -1:
# if unknown_idx != -1:
# raise ValueError("Can only specify one unknown dimension with -1")
# unknown_idx = i
# else:
# known_size *= dim
# if unknown_idx == -1:
# # No -1 found, use shape as is
# target_shape = shape
# else:
# # Calculate the unknown dimension
# if known_size == 0:
# raise ValueError(
# "Cannot infer shape when known dimensions have zero size"
# )
# if total_size % known_size != 0:
# raise ValueError(
# f"Cannot reshape array of size {total_size} to shape {shape}"
# )
# inferred_dim = total_size // known_size
# target_shape = tuple(
# int(inferred_dim if dim == -1 else dim) for dim in shape
# )
# else:
# target_shape = tuple(int(dim) for dim in shape)
# op = ReshapeOp(arg.shape, target_shape)
# return op.forward(arg)
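# # Illustrative sketch of the -1 inference above:
# # >>> x = nb.ones((2, 3, 4)) # 24 elements
# # >>> reshape(x, (6, -1)).shape # inferred dimension 24 // 6 = 4 -> (6, 4)
# # >>> reshape(x, (5, -1)) # raises ValueError (24 is not divisible by 5)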
# class BroadcastToOp(ViewOperation):
# """Broadcast array to target shape."""
# def __init__(self, target_shape: Shape):
# super().__init__(f"broadcast[shape={target_shape}]")
# self.target_shape = target_shape
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compatible signature."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Broadcast operation requires 1 input shape, got {len(input_shapes)}"
# )
# return self.target_shape
# def forward(self, *args: Array) -> Array:
# """Override forward to handle case where no broadcasting needed with compatible signature."""
# if len(args) != 1:
# raise ValueError(
# f"Broadcast operation requires 1 argument, got {len(args)}"
# )
# arg = args[0]
# if arg.shape == self.target_shape:
# return arg
# return super().forward(*args)
# @staticmethod
# def get_broadcasted_axes(input_shape: Shape, target_shape: Shape) -> list[int]:
# """Get axes that were broadcasted (for VJP)."""
# if len(input_shape) > len(target_shape):
# raise ValueError(
# f"Input shape {input_shape} cannot be broadcast to {target_shape}"
# )
# broadcasted_axes = []
# padded_input = (1,) * (len(target_shape) - len(input_shape)) + input_shape
# for i in range(len(target_shape)):
# if padded_input[i] == 1 and target_shape[i] > 1:
# # Return negative index to reference from the right side
# # This ensures we sum over the correct dimension
# broadcasted_axes.append(i - len(target_shape))
# elif padded_input[i] != target_shape[i] and padded_input[i] != 1:
# raise ValueError(f"Cannot broadcast {input_shape} to {target_shape}")
# return broadcasted_axes
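# # Worked example (sketch): get_broadcasted_axes((3, 1), (2, 3, 4)) pads the
# # input shape to (1, 3, 1) and returns [-3, -1]; these are the axes the VJP
# # below sums over (with keep_dims=True) to collapse the cotangent back to the
# # broadcasted input's shape.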
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# output.tensor_value = ops.broadcast_to(
# args[0], output.batch_dims + self.target_shape
# )
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# np_result = np.broadcast_to(
# args[0].to_numpy(), shape=output.batch_dims + self.target_shape
# )
# output.impl_(np_result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# broadcasted_axes = self.get_broadcasted_axes(
# primals[0].shape, self.target_shape
# )
# from .reduce import sum as sum_op # Renamed to avoid shadowing built-in
# return [sum_op(cotangent, axes=broadcasted_axes, keep_dims=True)]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# return broadcast_to(tangents[0], self.target_shape)
# def broadcast_to(arg: Array, shape: Shape) -> Array:
# """Broadcast array to target shape."""
# if arg.shape == shape:
# return arg
# for _ in range(len(shape) - len(arg.shape)):
# arg = unsqueeze(arg, [0])
# op = BroadcastToOp(shape)
# return op.forward(arg)
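# # Illustrative sketch:
# # >>> x = nb.ones((3, 1))
# # >>> broadcast_to(x, (2, 3, 4)).shape # unsqueezed to (1, 3, 1), then broadcast -> (2, 3, 4)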
# class BroadcastBatchDimsOp(ViewOperation):
# """Broadcast array to target batch_dims."""
# def __init__(self, target_batch_dims: Shape):
# super().__init__(f"broadcast_batch_dims[shape={target_batch_dims}]")
# self.target_batch_dims = target_batch_dims
# def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
# """Compatible signature."""
# if len(input_batch_dimss) != 1:
# raise ValueError(
# f"Broadcast operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
# )
# return self.target_batch_dims
# def forward(self, *args: Array) -> Array:
# """Override forward to handle case where no broadcasting needed with compatible signature."""
# if len(args) != 1:
# raise ValueError(
# f"Broadcast operation requires 1 argument, got {len(args)}"
# )
# arg = args[0]
# if arg.batch_dims == self.target_batch_dims:
# return arg
# return super().forward(*args)
# @staticmethod
# def get_broadcasted_axes(
# input_batch_dims: Shape, target_batch_dims: Shape
# ) -> list[int]:
# """Get axes that were broadcasted (for VJP)."""
# if len(input_batch_dims) > len(target_batch_dims):
# raise ValueError(
# f"Input batch_dims {input_batch_dims} cannot be broadcast to {target_batch_dims}"
# )
# broadcasted_axes = []
# padded_input = (1,) * (
# len(target_batch_dims) - len(input_batch_dims)
# ) + input_batch_dims
# for i in range(len(target_batch_dims)):
# if padded_input[i] == 1 and i < len(target_batch_dims) - len(
# input_batch_dims
# ):
# # This dimension was added by padding (broadcasted from non-existent to size 1 or more)
# broadcasted_axes.append(i - len(target_batch_dims))
# elif padded_input[i] == 1 and target_batch_dims[i] > 1:
# # This dimension was broadcasted from size 1 to larger size
# broadcasted_axes.append(i - len(target_batch_dims))
# elif padded_input[i] != target_batch_dims[i] and padded_input[i] != 1:
# raise ValueError(
# f"Cannot broadcast {input_batch_dims} to {target_batch_dims}"
# )
# return broadcasted_axes
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# output.tensor_value = ops.broadcast_to(
# args[0], self.target_batch_dims + output.shape
# )
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# np_result = np.broadcast_to(
# args[0].to_numpy(), shape=self.target_batch_dims + output.shape
# )
# output.impl_(np_result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# from .reduce import sum_batch_dims
# broadcasted_axes = self.get_broadcasted_axes(
# primals[0].batch_dims, output.batch_dims
# )
# return [sum_batch_dims(cotangent, axes=broadcasted_axes, keep_dims=True)]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# return broadcast_batch_dims(tangents[0], self.target_batch_dims)
# def broadcast_batch_dims(arg: Array, batch_dims: Shape) -> Array:
# """Broadcast array to target batch_dims."""
# if arg.batch_dims == batch_dims:
# return arg
# for _ in range(len(batch_dims) - len(arg.batch_dims)):
# arg = unsqueeze_batch_dims(arg, [0])
# op = BroadcastBatchDimsOp(batch_dims)
# return op.forward(arg)
# class SqueezeOp(ViewOperation):
# """Squeeze operation to remove dimensions of size 1."""
# def __init__(self, axes: list[int] | None = None):
# super().__init__(f"squeeze[axes={axes}]")
# self.axes = sorted(axes) if axes is not None else []
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compatible signature."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Squeeze operation requires 1 input shape, got {len(input_shapes)}"
# )
# input_shape = input_shapes[0]
# new_shape = list(input_shape)
# for ax in self.axes:
# if ax < -len(new_shape) or ax >= len(new_shape):
# raise ValueError(f"Axis {ax} is out of bounds for squeeze operation")
# if input_shape[ax] == 1:
# new_shape[ax] = None
# else:
# raise ValueError(
# f"Cannot squeeze axis {ax} of size {input_shape[ax]} (must be 1)"
# )
# new_shape = [dim for dim in new_shape if dim is not None]
# return tuple(new_shape)
# def forward(self, *args: Array) -> Array:
# """Override forward to handle case where no squeezing needed with compatible signature."""
# if len(args) != 1:
# raise ValueError(f"Squeeze operation requires 1 argument, got {len(args)}")
# return super().forward(*args)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# res = args[0]
# # Use self.axes directly since it's already normalized to a list in __init__
# for ax in self.axes:
# res = ops.squeeze(res, ax)
# output.tensor_value = res
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# axis = tuple(self.axes) if self.axes else () # empty tuple squeezes nothing (None would squeeze all size-1 dims)
# np_result = np.squeeze(args[0].to_numpy(), axis=axis)
# output.impl_(np_result)
# def vjp_rule(
# self, _primals: list[Array], cotangent: Array, _output: Array
# ) -> list[Array]:
# return [unsqueeze(cotangent, self.axes)]
# def jvp_rule(
# self, _primals: list[Array], tangents: list[Array], _output: Array
# ) -> Array:
# return squeeze(tangents[0], self.axes)
# def squeeze(arg: Array, axes: list[int] | None = None) -> Array:
# """Squeeze array by removing dimensions of size 1."""
# if axes is None:
# return arg
# axes = [ax if ax < 0 else -len(arg.shape) + ax for ax in axes]
# op = SqueezeOp(axes)
# res = op.forward(arg)
# return res
# class UnsqueezeOp(ViewOperation):
# """Unsqueeze operation to add dimensions of size 1."""
# def __init__(self, axes: list[int] | None = None):
# super().__init__(f"unsqueeze[axes={axes}]")
# self.axes = sorted(axes) if axes is not None else []
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compatible signature."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Unsqueeze operation requires 1 input shape, got {len(input_shapes)}"
# )
# input_shape = input_shapes[0]
# new_shape = list(input_shape)
# for ax in self.axes:
# if ax < -len(new_shape) - 1:
# raise ValueError(f"Axis {ax} is out of bounds for unsqueeze operation")
# if ax + 1 <= -1:
# new_shape.insert(ax + 1, 1)
# else:
# new_shape.append(1)
# return tuple(new_shape)
# def forward(self, *args: Array) -> Array:
# """Override forward to handle case where no unsqueezing needed with compatible signature."""
# if len(args) != 1:
# raise ValueError(
# f"Unsqueeze operation requires 1 argument, got {len(args)}"
# )
# return super().forward(*args)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# res_value = args[0]
# for ax in self.axes:
# res_value = ops.unsqueeze(res_value, ax)
# output.tensor_value = res_value
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# np_result = np.expand_dims(args[0].to_numpy(), axis=self.axes)
# output.impl_(np_result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# return [squeeze(cotangent, self.axes)]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# return unsqueeze(tangents[0], self.axes)
# def unsqueeze(arg: Array, axes: list[int] | None = None) -> Array:
# """Unsqueeze array by adding dimensions of size 1."""
# if axes is None:
# return arg
# axes = [ax if ax < 0 else -len(arg.shape) - 1 + ax for ax in axes]
# op = UnsqueezeOp(axes)
# return op.forward(arg)
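# # Illustrative sketch pairing the two ops above:
# # >>> x = nb.ones((3, 4))
# # >>> y = unsqueeze(x, [0, -1]) # shape (1, 3, 4, 1)
# # >>> squeeze(y, [0, -1]).shape # back to (3, 4)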
# class ShallowCopyOp(ViewOperation):
# """Copy operation to create a new array with the same data."""
# def __init__(self, arg: Array):
# if not arg.name and arg.impl and arg.shape == () and arg.batch_dims == ():
# name = arg.to_numpy().__repr__() # Use the numpy repr as the name for unnamed scalar constants
# else:
# name = "shallow_copy"
# super().__init__(name)
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compatible signature."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Copy operation requires 1 input shape, got {len(input_shapes)}"
# )
# return input_shapes[0]
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# output.tensor_value = args[0]
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# output.impl = args[0].impl
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# return [cotangent]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# return tangents[0]
# def shallow_copy(arg: Array) -> Array:
# """Create a shallow copy of the array."""
# op = ShallowCopyOp(arg)
# return op.forward(arg)
# class ConcatenateOp(Operation):
# """Concatenate operation to join arrays along an existing axis."""
# def __init__(self, axis: int = 0):
# super().__init__(f"concatenate[axis={axis}]")
# self.axis = axis
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compute output shape for concatenate operation."""
# if len(input_shapes) == 0:
# raise ValueError("Concatenate operation requires at least 1 input")
# # All input shapes must be the same except along the concatenation axis
# first_shape = input_shapes[0]
# if not first_shape:
# raise ValueError("Cannot concatenate empty shapes")
# # Normalize axis
# axis = self.axis if self.axis >= 0 else len(first_shape) + self.axis
# if axis < 0 or axis >= len(first_shape):
# raise ValueError(
# f"Axis {self.axis} is out of bounds for array with {len(first_shape)} dimensions"
# )
# # Check that all shapes are compatible
# total_size_along_axis = 0
# for i, shape in enumerate(input_shapes):
# if len(shape) != len(first_shape):
# raise ValueError(
# f"All inputs must have the same number of dimensions for concatenate operation. "
# f"Input 0 has {len(first_shape)} dimensions, input {i} has {len(shape)} dimensions"
# )
# for j, (dim1, dim2) in enumerate(zip(first_shape, shape, strict=False)):
# if j != axis and dim1 != dim2:
# raise ValueError(
# f"All inputs must have the same shape except along axis {axis}. "
# f"Input 0 has shape {first_shape}, input {i} has shape {shape}"
# )
# total_size_along_axis += shape[axis]
# # Compute output shape
# output_shape = list(first_shape)
# output_shape[axis] = total_size_along_axis
# return tuple(output_shape)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# """MAX graph implementation using ops.concat."""
# # Normalize axis for MAX operations, considering batch_dims
# # full_output_shape = output.batch_dims + output.shape # TODO: Use if needed
# axis = self.axis if self.axis >= 0 else len(output.shape) + self.axis
# # Adjust axis to account for batch_dims in the actual tensor
# axis_in_tensor = axis + len(output.batch_dims)
# output.tensor_value = ops.concat(args, axis=axis_in_tensor)
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# """Eager execution using NumPy concatenate."""
# import numpy as np
# numpy_arrays = [arg.to_numpy() for arg in args]
# # Normalize axis for NumPy operations, considering batch_dims
# axis = self.axis if self.axis >= 0 else len(output.shape) + self.axis
# # Adjust axis to account for batch_dims in the actual tensor
# axis_in_tensor = axis + len(output.batch_dims)
# result = np.concatenate(numpy_arrays, axis=axis_in_tensor)
# output.impl_(result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# """Vector-Jacobian product rule for concatenate operation.
# The VJP of concatenate is slicing the cotangent back into pieces.
# """
# # Normalize axis
# axis = self.axis if self.axis >= 0 else len(cotangent.shape) + self.axis
# # Split the cotangent along the concatenated axis
# result = []
# start_idx = 0
# for primal in primals:
# size_along_axis = primal.shape[axis]
# end_idx = start_idx + size_along_axis
# # Create slice that selects this input's portion along the concatenated axis
# slices = [slice(None)] * len(cotangent.shape)
# slices[axis] = slice(start_idx, end_idx)
# # Slice the cotangent
# sliced = array_slice(cotangent, slices)
# result.append(sliced)
# start_idx = end_idx
# return result
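# # Worked example (sketch): if primals have shapes (2, 3) and (4, 3) and were
# # concatenated along axis 0, the cotangent has shape (6, 3) and is sliced back
# # into rows 0:2 for the first input and rows 2:6 for the second.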
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# """Jacobian-vector product rule for concatenate operation.
# The JVP of concatenate is concatenating the tangents along the same axis.
# """
# # Use the ConcatenateOp directly to avoid circular import
# op = ConcatenateOp(axis=self.axis)
# return op.forward(*tangents)
# def forward(self, *args: Array) -> Array:
# """Forward pass for concatenate operation with multiple inputs."""
# if len(args) == 0:
# raise ValueError("Concatenate operation requires at least 1 argument")
# # Move arrays to best device
# from .operation import move_to_best_device
# args = move_to_best_device(*args)
# # Validate inputs have compatible properties
# first_arg = args[0]
# for _i, arg in enumerate(args[1:], 1):
# if arg.dtype != first_arg.dtype:
# raise ValueError(
# f"All inputs must have the same dtype. Got {arg.dtype} vs {first_arg.dtype}"
# )
# if arg.device != first_arg.device:
# raise ValueError(
# f"All inputs must be on the same device. Got {arg.device} vs {first_arg.device}"
# )
# # Compute output properties
# input_shapes = [arg.shape for arg in args]
# output_shape = self.compute_output_shape(*input_shapes)
# # All inputs should have the same batch_dims
# output_batch_dims = first_arg.batch_dims
# for i, arg in enumerate(args[1:], 1):
# if arg.batch_dims != output_batch_dims:
# raise ValueError(
# f"All inputs must have the same batch_dims for concatenate operation. "
# f"Input 0 has batch_dims {output_batch_dims}, input {i} has batch_dims {arg.batch_dims}"
# )
# # Create result array
# res = Array(
# shape=output_shape,
# dtype=first_arg.dtype,
# device=first_arg.device,
# materialize=False,
# name=self.name,
# batch_dims=output_batch_dims,
# )
# # Set up computation
# res.set_maxpr(self.maxpr)
# res.add_arguments(*args)
# res.vjp_rule = self.vjp_rule
# res.jvp_rule = self.jvp_rule
# # Execute eager computation if needed
# if not res.stage_realization:
# self.eagerxpr(list(args), res)
# return res
# def concatenate(args: list[Array], axis: int = 0) -> Array:
# """Concatenate arrays along an existing axis.
# Args:
# args: List of arrays to concatenate
# axis: Axis along which to concatenate arrays (default: 0)
# Returns:
# Concatenated array
# """
# if not args:
# raise ValueError("Concatenate operation requires at least one array")
# op = ConcatenateOp(axis)
# return op.forward(*args)
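# # Illustrative sketch:
# # >>> a, b = nb.ones((2, 3)), nb.ones((4, 3))
# # >>> concatenate([a, b], axis=0).shape # -> (6, 3); all non-concatenated dimensions must match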
# class ArraySliceOp(ViewOperation):
# """Array slicing operation."""
# def __init__(self, slices: list[slice], squeeze_axes: list[int] | None = None):
# # Store original slices for reference
# self.original_slices = slices.copy()
# # Check if we have negative steps - if so, we'll need special handling
# self.has_negative_steps = any(s.step is not None and s.step < 0 for s in slices)
# # Convert slices to a more manageable format
# slice_strs = []
# for s in slices:
# start = s.start if s.start is not None else ""
# stop = s.stop if s.stop is not None else ""
# step = s.step if s.step is not None else ""
# if step and step != 1:
# slice_strs.append(f"{start}:{stop}:{step}")
# else:
# slice_strs.append(f"{start}:{stop}")
# squeeze_info = f"_squeeze{squeeze_axes}" if squeeze_axes else ""
# super().__init__(f"array_slice[{','.join(slice_strs)}]{squeeze_info}")
# self.slices = slices
# self.squeeze_axes = squeeze_axes or []
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compute output shape for array slice operation."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Array slice operation requires 1 input shape, got {len(input_shapes)}"
# )
# input_shape = input_shapes[0]
# output_shape = []
# # Process each dimension
# for i, dim_size in enumerate(input_shape):
# if i < len(self.slices):
# s = self.slices[i]
# step = s.step if s.step is not None else 1
# if step == 0:
# raise ValueError("Step cannot be zero")
# # Let Python's slice machinery resolve None/negative bounds and clamp
# # them to the dimension, then count the selected indices. This also
# # sizes negative-step slices (e.g. [::-1]) consistently with NumPy.
# output_size = len(range(*s.indices(dim_size)))
# # Skip this dimension if it should be squeezed (JAX-compatible behavior)
# if i not in self.squeeze_axes:
# output_shape.append(output_size)
# else:
# # No slice for this dimension, keep original size
# output_shape.append(dim_size)
# return tuple(output_shape)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# """MAX graph implementation using ops.slice_tensor."""
# # Check for negative steps - not supported in JIT mode yet
# if self.has_negative_steps:
# raise NotImplementedError(
# "Negative step slicing (e.g., [::-1]) is not yet supported in JIT-compiled functions "
# "due to MAX engine limitations. Use eager execution instead, or avoid negative steps "
# "in JIT-compiled code."
# )
# # Build slice indices for MAX ops.slice_tensor
# # Need to account for batch_dims - slicing only applies to shape dimensions
# slice_indices = []
# # Add full slices for batch dimensions
# for _ in range(len(output.batch_dims)):
# slice_indices.append(slice(None))
# # Add the actual slices for shape dimensions
# for i in range(len(self.slices)):
# s = self.slices[i]
# slice_indices.append(slice(s.start, s.stop, s.step))
# # Add full slices for any remaining dimensions
# for _ in range(len(self.slices), len(args[0].shape)):
# slice_indices.append(slice(None))
# # Apply the slicing
# result = ops.slice_tensor(args[0], slice_indices)
# # Apply squeezing for JAX-compatible behavior
# if self.squeeze_axes:
# # Adjust squeeze axes to account for batch dimensions
# squeeze_axes_adjusted = [
# ax + len(output.batch_dims) for ax in self.squeeze_axes
# ]
# for ax in sorted(
# squeeze_axes_adjusted, reverse=True
# ): # Squeeze in reverse order
# result = ops.squeeze(result, ax)
# output.tensor_value = result
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# """Eager execution using NumPy slicing."""
# input_array = args[0].to_numpy()
# # Build numpy slice tuple
# # Need to account for batch_dims - slicing only applies to shape dimensions
# numpy_slices = []
# # Add full slices for batch dimensions
# for _ in range(len(args[0].batch_dims)):
# numpy_slices.append(slice(None))
# # Add the actual slices for shape dimensions
# for i in range(len(args[0].shape)):
# if i < len(self.slices):
# numpy_slices.append(self.slices[i])
# else:
# numpy_slices.append(slice(None)) # Full slice for remaining dimensions
# result = input_array[tuple(numpy_slices)]
# # Apply squeezing for JAX-compatible behavior
# if self.squeeze_axes:
# # Adjust squeeze axes to account for batch dimensions
# squeeze_axes_adjusted = [
# ax + len(args[0].batch_dims) for ax in self.squeeze_axes
# ]
# result = np.squeeze(result, axis=tuple(squeeze_axes_adjusted))
# output.impl_(result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# """Vector-Jacobian product rule for array slice."""
# # If we squeezed dimensions, we need to unsqueeze the cotangent first
# if self.squeeze_axes:
# from ..ops.view import unsqueeze
# # Unsqueeze in the positions that were squeezed
# unsqueeze_axes = self.squeeze_axes.copy()
# cotangent_unsqueezed = unsqueeze(cotangent, unsqueeze_axes)
# else:
# cotangent_unsqueezed = cotangent
# return [pad(cotangent_unsqueezed, self.slices, primals[0].shape)]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# """Jacobian-vector product rule for array slice."""
# # Apply the same slicing and squeezing to tangents
# op = ArraySliceOp(self.slices, self.squeeze_axes)
# return op.forward(tangents[0])
# def array_slice(
# arg: Array, slices: list[slice], squeeze_axes: list[int] | None = None
# ) -> Array:
# """Slice an array along specified dimensions.
# Args:
# arg: Input array to slice
# slices: List of slice objects defining the slicing for each dimension
# squeeze_axes: List of axes that should be squeezed (for JAX compatibility)
# Returns:
# Sliced array
# """
# op = ArraySliceOp(slices, squeeze_axes)
# return op.forward(arg)
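# # Illustrative sketch:
# # >>> x = nb.ones((4, 5))
# # >>> array_slice(x, [slice(1, 3)]).shape # -> (2, 5); unsliced trailing dims are kept whole
# # >>> array_slice(x, [slice(None), slice(0, 5, 2)]).shape # -> (4, 3)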
# def split(arg: Array, sizes: list[int], axis: int = 0) -> list[Array]:
# """Split an array into multiple sub-arrays along a specified axis.
# Args:
# arg: Input array to split
# sizes: List of sizes for each split along the specified axis
# axis: Axis along which to split the array (default: 0)
# Returns:
# List of sub-arrays resulting from the split
# """
# if not sizes:
# raise ValueError("Sizes list must not be empty")
# if axis < 0:
# axis += len(arg.shape)
# if axis < 0 or axis >= len(arg.shape):
# raise ValueError(
# f"Axis {axis} is out of bounds for array with {len(arg.shape)} dimensions"
# )
# # Compute the total size along the specified axis
# total_size = sum(sizes)
# if total_size != arg.shape[axis]:
# raise ValueError(
# f"Total size {total_size} along axis {axis} does not match input shape {arg.shape[axis]}"
# )
# # Create slices for each split
# slices = []
# idx = 0
# for size in sizes:
# slices.append(slice(idx, idx + size))
# idx += size
# # Create the result arrays
# results = []
# for s in slices:
# slice_obj = [slice(None)] * len(arg.shape) # Full slice for all dimensions
# slice_obj[axis] = s # Set the slice for the specified axis
# results.append(array_slice(arg, slice_obj))
# return results
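# # Illustrative sketch:
# # >>> parts = split(nb.ones((6, 3)), [2, 4], axis=0)
# # >>> [p.shape for p in parts] # -> [(2, 3), (4, 3)], produced via array_slice along axis 0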
# class PadOp(Operation):
# """Inverse slice operation - places a smaller array into a larger zero-filled array."""
# def __init__(self, slices: list[slice], target_shape: Shape):
# # Convert slices to string representation for name
# slice_strs = []
# for s in slices:
# start = s.start if s.start is not None else ""
# stop = s.stop if s.stop is not None else ""
# step = s.step if s.step is not None else ""
# if step and step != 1:
# slice_strs.append(f"{start}:{stop}:{step}")
# else:
# slice_strs.append(f"{start}:{stop}")
# super().__init__(f"pad[{','.join(slice_strs)}]")
# self.slices = slices
# self.target_shape = target_shape
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Compute output shape for inverse slice operation."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Inverse slice operation requires 1 input shape, got {len(input_shapes)}"
# )
# # Validate that applying slices to target_shape would yield input_shape
# input_shape = input_shapes[0]
# # Simulate slicing target_shape with self.slices to verify consistency
# expected_shape = []
# for i, dim_size in enumerate(self.target_shape):
# if i < len(self.slices):
# s = self.slices[i]
# start = s.start if s.start is not None else 0
# stop = s.stop if s.stop is not None else dim_size
# step = s.step if s.step is not None else 1
# # Stepped slicing is supported here, so no special-case rejection is needed.
# # Handle negative indices
# if start < 0:
# start = max(0, dim_size + start)
# if stop < 0:
# stop = max(0, dim_size + stop)
# # Clamp to valid range
# start = max(0, min(start, dim_size))
# stop = max(start, min(stop, dim_size))
# # Calculate output size for this dimension, accounting for step
# if step == 1:
# output_size = stop - start
# else:
# # For stepped slicing: number of elements = ceil((stop - start) / step)
# output_size = (stop - start + step - 1) // step
# expected_shape.append(output_size)
# else:
# # No slice for this dimension, keep original size
# expected_shape.append(dim_size)
# expected_shape = tuple(expected_shape)
# if expected_shape != input_shape:
# raise ValueError(
# f"Slicing target_shape {self.target_shape} with {self.slices} "
# f"would produce shape {expected_shape}, but input has shape {input_shape}"
# )
# return self.target_shape
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# """MAX graph implementation using range->reshape->broadcast->slice->scatter approach."""
# import numpy as np
# input_tensor = args[0]
# # Step 1: Calculate total elements in output shape
# total_elements = int(np.prod(output.shape))
# # Step 2: Create flat index tensor using ops.range with int32 dtype
# flat_indices = ops.range(0, total_elements, 1, dtype=DType.int32)
# # Step 3: Reshape to output shape
# reshaped_indices = ops.reshape(flat_indices, output.shape)
# # Step 4: Broadcast to include batch dims if needed
# if output.batch_dims:
# # Need to broadcast to batch_dims + output.shape
# target_shape = list(output.batch_dims) + list(output.shape)
# broadcasted_indices = ops.broadcast_to(reshaped_indices, target_shape)
# else:
# broadcasted_indices = reshaped_indices
# # Step 5: Slice the index tensor using self.slices to get target indices
# slice_indices = []
# # Add full slices for batch dimensions
# for _ in range(len(output.batch_dims)):
# slice_indices.append(slice(None))
# # Add the actual slices for shape dimensions
# for s in self.slices:
# slice_indices.append(slice(s.start, s.stop, s.step))
# # Add full slices for any remaining dimensions
# for _ in range(len(self.slices), len(output.shape)):
# slice_indices.append(slice(None))
# # Slice to get the indices where input should go
# sliced_indices = ops.slice_tensor(broadcasted_indices, slice_indices)
# # Step 6: Flatten the sliced indices
# flattened_indices = ops.reshape(sliced_indices, [-1])
# # Step 7: Create flat zero tensor for scattering
# total_output_elements = int(
# np.prod(list(output.batch_dims) + list(output.shape))
# )
# from max.graph import DeviceRef
# zero_scalar = ops.constant(
# 0.0, dtype=output.dtype, device=DeviceRef.from_device(output.device)
# )
# flat_zeros = ops.broadcast_to(zero_scalar, [total_output_elements])
# # Step 8: Flatten input tensor
# input_flattened = ops.reshape(input_tensor, [-1])
# # Step 9: Use scatter to place input values at target indices
# # scatter(input, updates, indices, axis) - scatter along axis=0 (first axis) of flat tensor
# scattered_flat = ops.scatter(
# flat_zeros, input_flattened, flattened_indices, axis=0
# )
# # Step 10: Reshape result back to target shape
# final_shape = list(output.batch_dims) + list(output.shape)
# output.tensor_value = ops.reshape(scattered_flat, final_shape)
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# """Eager execution using NumPy."""
# small_array = args[0]
# # Create zero-filled target array
# target_shape = output.batch_dims + output.shape
# result_np = np.zeros(target_shape, dtype=small_array.to_numpy().dtype)
# # Build slice indices (accounting for batch_dims)
# slice_indices = []
# # Add full slices for batch dimensions
# for _ in range(len(output.batch_dims)):
# slice_indices.append(slice(None))
# # Add the actual slices for shape dimensions
# slice_indices.extend(self.slices)
# # Add full slices for any remaining dimensions
# for _i in range(len(self.slices), len(output.shape)):
# slice_indices.append(slice(None))
# # Place small array into the target location
# result_np[tuple(slice_indices)] = small_array.to_numpy()
# output.impl_(result_np)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# """VJP rule: slice the cotangent back to original size."""
# # The VJP of pad is just a regular slice!
# from nabla.ops.view import array_slice
# return [array_slice(cotangent, self.slices)]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# """JVP rule: apply pad to tangents."""
# return pad(tangents[0], self.slices, self.target_shape)
# def forward(self, *args: Array) -> Array:
# """Forward pass for inverse slice operation."""
# if len(args) != 1:
# raise ValueError(
# f"Inverse slice operation requires 1 argument, got {len(args)}"
# )
# input_array = args[0]
# # Compute output properties
# output_shape = self.compute_output_shape(input_array.shape)
# # Create result array
# res = Array(
# shape=output_shape,
# dtype=input_array.dtype,
# device=input_array.device,
# materialize=False,
# name=self.name,
# batch_dims=input_array.batch_dims,
# )
# # Set up computation
# res.set_maxpr(self.maxpr)
# res.add_arguments(input_array)
# res.vjp_rule = self.vjp_rule
# res.jvp_rule = self.jvp_rule
# # Execute eager computation if needed
# if not res.stage_realization:
# self.eagerxpr([input_array], res)
# return res
# def pad(arg: Array, slices: list[slice], target_shape: Shape) -> Array:
# """Place a smaller array into a larger zero-filled array at the location specified by slices.
# This is the inverse operation of array slicing - given slices, a small array, and target shape,
# it creates a larger array where the small array is placed at the sliced location
# and everything else is zero.
# Args:
# arg: Input array (the smaller array to be placed)
# slices: List of slice objects defining where to place the array
# target_shape: The shape of the output array
# Returns:
# Larger array with input placed at sliced location, zeros elsewhere
# """
# op = PadOp(slices, target_shape)
# return op.forward(arg)
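# # A small usage sketch of pad for clarity (assumes nb.ones, as in this module's
# # other doctest examples):
# # >>> import nabla as nb
# # >>> small = nb.ones((2, 2))
# # >>> big = pad(small, [slice(1, 3), slice(0, 2)], (4, 4))
# # >>> # big.shape == (4, 4); ones at rows 1:3 and cols 0:2, zeros elsewhere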
# class SqueezeBatchDimsOp(ViewOperation):
# """Squeeze operation to remove batch dimensions of size 1."""
# def __init__(self, axes: list[int] | None = None):
# super().__init__(f"squeeze_batch_dims[axes={axes}]")
# self.axes = sorted(axes) if axes is not None else []
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Shape stays the same for batch dimension operations."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Squeeze batch dims operation requires 1 input shape, got {len(input_shapes)}"
# )
# return input_shapes[0]
# def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
# """Compute output batch_dims for squeeze operation."""
# if len(input_batch_dimss) != 1:
# raise ValueError(
# f"Squeeze batch dims operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
# )
# input_batch_dims = input_batch_dimss[0]
# new_batch_dims = list(input_batch_dims)
# for ax in self.axes:
# if ax < -len(new_batch_dims) or ax >= len(new_batch_dims):
# raise ValueError(
# f"Axis {ax} is out of bounds for squeeze batch dims operation"
# )
# if input_batch_dims[ax] == 1:
# new_batch_dims[ax] = None
# else:
# raise ValueError(
# f"Cannot squeeze batch axis {ax} of size {input_batch_dims[ax]} (must be 1)"
# )
# new_batch_dims = [dim for dim in new_batch_dims if dim is not None]
# return tuple(new_batch_dims)
# def forward(self, *args: Array) -> Array:
# """Override forward to handle case where no squeezing needed."""
# if len(args) != 1:
# raise ValueError(
# f"Squeeze batch dims operation requires 1 argument, got {len(args)}"
# )
# return super().forward(*args)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# """MAX graph implementation using ops.squeeze."""
# axes = [ax - len(output.shape) for ax in self.axes]
# res = args[0]
# for ax in axes:
# res = ops.squeeze(res, ax)
# output.tensor_value = res
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# """Eager execution using NumPy squeeze."""
# axes = [ax - len(args[0].shape) for ax in self.axes]
# np_result = np.squeeze(args[0].to_numpy(), axis=tuple(axes))
# output.impl_(np_result)
# def vjp_rule(
# self, _primals: list[Array], cotangent: Array, _output: Array
# ) -> list[Array]:
# """VJP rule: unsqueeze the cotangent back to original batch dimensions."""
# return [unsqueeze_batch_dims(cotangent, self.axes)]
# def jvp_rule(
# self, _primals: list[Array], tangents: list[Array], _output: Array
# ) -> Array:
# """JVP rule: apply squeeze to tangents."""
# return squeeze_batch_dims(tangents[0], self.axes)
# def squeeze_batch_dims(arg: Array, axes: list[int] | None = None) -> Array:
# """Squeeze array by removing batch dimensions of size 1.
# Args:
# arg: Input array
# axes: List of batch dimension axes to squeeze. If None, returns array unchanged.
# Returns:
# Array with specified batch dimensions of size 1 removed
# """
# if axes is None:
# return arg
# # Convert to negative indices for consistency with batch dimension handling
# axes = [ax if ax < 0 else -len(arg.batch_dims) + ax for ax in axes]
# op = SqueezeBatchDimsOp(axes)
# return op.forward(arg)
# class UnsqueezeBatchDimsOp(ViewOperation):
# """Unsqueeze operation to add batch dimensions of size 1."""
# def __init__(self, axes: list[int] | None = None):
# super().__init__(f"unsqueeze_batch_dims[axes={axes}]")
# self.axes = sorted(axes) if axes is not None else []
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# """Shape stays the same for batch dimension operations."""
# if len(input_shapes) != 1:
# raise ValueError(
# f"Unsqueeze batch dims operation requires 1 input shape, got {len(input_shapes)}"
# )
# return input_shapes[0]
# def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
# """Compute output batch_dims for unsqueeze operation."""
# if len(input_batch_dimss) != 1:
# raise ValueError(
# f"Unsqueeze batch dims operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
# )
# input_batch_dims = input_batch_dimss[0]
# new_batch_dims = list(input_batch_dims)
# for ax in self.axes:
# if ax < -len(new_batch_dims) - 1:
# raise ValueError(
# f"Axis {ax} is out of bounds for unsqueeze batch dims operation"
# )
# if ax + 1 <= -1:
# new_batch_dims.insert(ax + 1, 1)
# else:
# new_batch_dims.append(1)
# return tuple(new_batch_dims)
# def forward(self, *args: Array) -> Array:
# """Override forward to handle case where no unsqueezing needed."""
# if len(args) != 1:
# raise ValueError(
# f"Unsqueeze batch dims operation requires 1 argument, got {len(args)}"
# )
# return super().forward(*args)
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# """MAX graph implementation using ops.unsqueeze."""
# res = args[0]
# # Use self.axes directly since it's already normalized to a list in __init__
# # Adjust axes for batch dimensions
# axes = [ax - len(output.shape) for ax in self.axes] if self.axes else []
# for ax in axes:
# res = ops.unsqueeze(res, ax)
# output.tensor_value = res
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# """Eager execution using NumPy expand_dims."""
# if self.axes:
# # Apply expand_dims for each axis sequentially
# np_result = args[0].to_numpy()
# axes = [ax - len(args[0].shape) for ax in self.axes]
# for ax in axes:
# np_result = np.expand_dims(np_result, axis=ax)
# else:
# np_result = args[0].to_numpy()
# output.impl_(np_result)
# def vjp_rule(
# self, primals: list[Array], cotangent: Array, output: Array
# ) -> list[Array]:
# """VJP rule: squeeze the cotangent back to original batch dimensions."""
# return [squeeze_batch_dims(cotangent, self.axes)]
# def jvp_rule(
# self, primals: list[Array], tangents: list[Array], output: Array
# ) -> Array:
# """JVP rule: apply unsqueeze to tangents."""
# return unsqueeze_batch_dims(tangents[0], self.axes)
# def unsqueeze_batch_dims(arg: Array, axes: list[int] | None = None) -> Array:
# """Unsqueeze array by adding batch dimensions of size 1.
# Args:
# arg: Input array
# axes: List of positions where to insert batch dimensions of size 1.
# If None, returns array unchanged.
# Returns:
# Array with batch dimensions of size 1 added at specified positions
# """
# if axes is None:
# return arg
# # Convert to negative indices for consistency with batch dimension handling
# axes = [ax if ax < 0 else -len(arg.batch_dims) - 1 + ax for ax in axes]
# op = UnsqueezeBatchDimsOp(axes)
# return op.forward(arg)
# # Create a stack function that first builds a list of arrays with a new axis (via unsqueeze) and then concatenates them along that axis
# def stack(arrays: list[Array], axis: int = 0) -> Array:
# """Stack arrays along a new axis.
# Args:
# arrays: List of arrays to stack
# axis: Axis along which to stack the arrays (default: 0)
# Returns:
# Stacked array
# """
# if not arrays:
# raise ValueError("Stack operation requires at least one array")
# # Unsqueeze each array to add a new dimension at the specified axis
# unsqueezed_arrays = [unsqueeze(array, [axis]) for array in arrays]
# # Use concatenate to stack them along the new axis
# return concatenate(unsqueezed_arrays, axis=axis)
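# # A small usage sketch of stack for clarity (assumes nb.ones, as in this module's
# # other doctest examples):
# # >>> import nabla as nb
# # >>> a, b = nb.ones((3, 4)), nb.ones((3, 4))
# # >>> stack([a, b], axis=0).shape  # (2, 3, 4)
# # >>> stack([a, b], axis=1).shape  # (3, 2, 4)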
# # class GatherOp(Operation):
# # def __init__(self, axis: int = -1):
# # """
# # Initialize take operation.
# # Args:
# # axis: The dimension which indices indexes from input.
# # If negative, indexes relative to the end of the input tensor.
# # """
# # self.axis = axis
# # def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# # """
# # Compute the output shape for take operation.
# # The output shape replaces the indexed dimension with the indices shape.
# # Args:
# # input_shapes: (input_shape, indices_shape)
# # Returns:
# # Output shape tuple
# # """
# # input_shape, indices_shape = input_shapes
# # # Normalize negative axis
# # axis = self.axis
# # if axis < 0:
# # axis += len(input_shape)
# # if axis < 0 or axis >= len(input_shape):
# # raise ValueError(f"Axis {self.axis} is out of bounds for input with {len(input_shape)} dimensions")
# # # Output shape: input_shape with axis dimension replaced by indices_shape
# # output_shape = (
# # input_shape[:axis] +
# # indices_shape +
# # input_shape[axis + 1:]
# # )
# # return output_shape
# # def compute_output_batch_dims(self, *input_batch_dims: tuple) -> tuple:
# # """
# # Compute output batch dims for gather operation.
# # For gather with vmap, the output batch dims should be the broadcasted
# # batch dimensions of both input array and indices.
# # Args:
# # input_batch_dims: (input_batch_dims, indices_batch_dims)
# # Returns:
# # Broadcasted batch dims of input array and indices
# # """
# # if len(input_batch_dims) != 2:
# # raise ValueError(f"Gather operation requires 2 input batch dims, got {len(input_batch_dims)}")
# # input_batch_dims_val, indices_batch_dims_val = input_batch_dims[0], input_batch_dims[1]
# # # Use the standard broadcasting logic for batch dimensions
# # from ..utils.shape_utils import get_broadcasted_shape
# # return get_broadcasted_shape(input_batch_dims_val, indices_batch_dims_val)
# # def maxpr(self, args: list[TensorValue], output: Array) -> None:
# # """
# # MAX graph implementation using max.graph.ops.gather.
# # Args:
# # args: [input_tensor, indices_tensor]
# # output: Output array to store result
# # """
# # input_tensor, indices_tensor = args
# # # Import MAX ops
# # from max.graph import ops
# # # Ensure indices are integers for MAX
# # if indices_tensor.type.dtype.name != 'int64':
# # indices_tensor = ops.cast(indices_tensor, ops.DType.int64)
# # # Convert logical axis to physical axis (accounting for batch dimensions)
# # # The axis parameter refers to the logical shape, but the actual tensor includes batch dims
# # batch_offset = len(output.batch_dims)
# # logical_rank = len(output.shape)
# # # Normalize the axis relative to logical shape
# # if self.axis < 0:
# # logical_axis = self.axis + logical_rank
# # else:
# # logical_axis = self.axis
# # # Convert to physical axis in the full tensor
# # physical_axis = logical_axis + batch_offset
# # # Use MAX's gather operation
# # result = ops.gather(input_tensor, indices_tensor, axis=physical_axis)
# # output.tensor_value = result
# # def eagerxpr(self, args: list[Array], output: Array) -> None:
# # # the axis is always negative, so no need to convert it
# # input_array = args[0].to_numpy()
# # indices_array = args[1].to_numpy()
# # if indices_array.dtype.kind not in {'i', 'u'}:
# # raise ValueError(
# # f"Indices array must be of integer type, got {indices_array.dtype}"
# # )
# # # Use numpy's advanced indexing to gather values
# # result_np = input_array[tuple(indices_array)]
# # output.impl_(result_np)
# # def vjp_rule(self, primals: list[Array], cotangent: Array, output: Array) -> list[Array]:
# # input_array, indices_array = primals
# # target_shape = input_array.shape
# # input_grad = scatter(
# # target_shape=target_shape,
# # indices=indices_array,
# # values=cotangent,
# # axis=self.axis
# # )
# # # Indices don't need gradients, but we need to return a zero array of the same shape
# # from ..ops.creation import zeros
# # indices_grad = zeros(indices_array.shape, dtype=input_array.dtype)
# # return [input_grad, indices_grad]
# # def jvp_rule(self, primals: list[Array], tangents: list[Array], output: Array) -> Array:
# # input_tangent, indices_tangent = tangents
# # return gather(input_tangent, indices=primals[1], axis=self.axis)
# # def compute_output_dtype(self, arg1: Array, arg2: Array) -> DType:
# # """Default: output dtype same as first input dtype."""
# # return arg1.dtype
# # def forward(self, *args: Array) -> Array:
# # if len(args) != 2:
# # raise ValueError(f"Scatter operation requires 2 arguments, got {len(args)}")
# # # Move arrays to best device (like BinaryOperation does)
# # from .operation import move_to_best_device
# # args = move_to_best_device(*args)
# # indices, values = args
# # # compute shape difference length wise
# # shape_diff = len(values.shape) - len(indices.shape)
# # if shape_diff < 0:
# # raise ValueError(
# # f"Indices shape {indices.shape} cannot be larger than values shape {values.shape}"
# # )
# # elif shape_diff > 0:
# # indices = broadcast_to(indices, values.shape[:shape_diff] + indices.shape)
# # output_shape = self.compute_output_shape(indices.shape, values.shape)
# # output_batch_dims = self.compute_output_batch_dims(
# # indices.batch_dims, values.batch_dims
# # )
# # output_dtype = self.compute_output_dtype(indices, values)
# # res = Array(
# # shape=output_shape,
# # dtype=output_dtype,
# # device=values.device,
# # materialize=False,
# # name=self.name,
# # batch_dims=output_batch_dims,
# # )
# # res.set_maxpr(self.maxpr)
# # res.add_arguments(indices, values)
# # res.vjp_rule = self.vjp_rule
# # res.jvp_rule = self.jvp_rule
# # res.custom_kernel_path = self.custom_kernel_path()
# # if not res.stage_realization:
# # self.eagerxpr([indices, values], res)
# # return res
# # def gather(input_array: Array, indices: Array, axis: int = -1) -> Array:
# # if axis >= 0:
# # # make negative
# # axis = axis - len(input_array.shape)
# # op = GatherOp(axis)
# # return op.forward(input_array, indices)
# # class ScatterOp(Operation):
# # def __init__(self, target_shape: tuple, axis: int = -1):
# # """
# # Initialize scatter operation.
# # Args:
# # target_shape: Shape of the output tensor
# # axis: The dimension along which to scatter indices
# # """
# # self.target_shape = target_shape
# # self.axis = axis
# # def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# # """
# # Compute the output shape for give operation.
# # Args:
# # input_shapes: (indices_shape, values_shape)
# # Returns:
# # target_shape (fixed by constructor)
# # """
# # # Convert Array objects to plain integers if needed (for JIT compatibility)
# # shape_list = []
# # for dim in self.target_shape:
# # if hasattr(dim, 'to_numpy'):
# # # It's an Array object, convert to scalar
# # shape_list.append(int(dim.to_numpy().item()))
# # else:
# # # It's already a plain integer
# # shape_list.append(int(dim))
# # return tuple(shape_list)
# # def compute_output_batch_dims(self, *input_batch_dims: tuple) -> tuple:
# # """
# # Compute output batch dims for scatter operation.
# # Args:
# # input_batch_dims: (indices_batch_dims, values_batch_dims)
# # Returns:
# # Broadcasted batch dims
# # """
# # if len(input_batch_dims) != 2:
# # raise ValueError(f"Scatter operation requires 2 input batch dims, got {len(input_batch_dims)}")
# # indices_batch_dims, values_batch_dims = input_batch_dims[0], input_batch_dims[1]
# # from ..utils.shape_utils import get_broadcasted_shape
# # return get_broadcasted_shape(indices_batch_dims, values_batch_dims)
# # def maxpr(self, args: list[TensorValue], output: Array) -> None:
# # """
# # MAX graph implementation using max.graph.ops.scatter.
# # For MAX scatter, we need to create a zero tensor and then scatter into it.
# # Args:
# # args: [indices_tensor, values_tensor]
# # output: Output array to store result
# # """
# # indices_tensor, values_tensor = args
# # from max.graph import ops
# # # Convert logical axis to physical axis (accounting for batch dimensions)
# # batch_offset = len(output.batch_dims)
# # logical_rank = len(self.target_shape)
# # # Normalize the axis relative to logical target shape
# # if self.axis < 0:
# # logical_axis = self.axis + logical_rank
# # else:
# # logical_axis = self.axis
# # # Convert to physical axis in the full tensor
# # physical_axis = logical_axis + batch_offset
# # # Create zero tensor with full shape (batch_dims + target_shape)
# # # Convert to plain integers in case we have Array objects
# # batch_dims_ints = tuple(int(d) for d in output.batch_dims)
# # target_shape_ints = tuple(
# # int(d.to_numpy()) if hasattr(d, 'to_numpy') else int(d)
# # for d in self.target_shape
# # )
# # full_target_shape = batch_dims_ints + target_shape_ints
# # zero_tensor = ops.broadcast_to(
# # ops.constant(0, dtype=values_tensor.dtype, device=values_tensor.device),
# # full_target_shape
# # )
# # # Use MAX's scatter_nd operation for flexible indexing
# # # scatter_nd expects indices with shape [num_updates, k] where k is the number of
# # # dimensions we're indexing into (for partial indexing, k < rank)
# # if physical_axis == 0:
# # # For axis=0, reshape indices from [N] to [N, 1]
# # # This means "N updates, each specifying 1 coordinate (along axis 0)"
# # indices_reshaped = ops.unsqueeze(indices_tensor, axis=-1)
# # result = ops.scatter_nd(zero_tensor, values_tensor, indices_reshaped)
# # else:
# # # For other axes, create proper index coordinates for scatter_nd
# # # We need to create indices that specify the full coordinates
# # # For now, let's try a simpler approach using scatter_nd with proper reshaping
# # # Create indices for scatter_nd: [num_updates, 1] format for single-axis indexing
# # indices_reshaped = ops.unsqueeze(indices_tensor, axis=-1)
# # # For non-zero axes, we need to use scatter, but we need compatible shapes
# # # Let's try scatter_nd by creating appropriate multi-dimensional indices
# # if physical_axis == 1:
# # # For axis=1, we're indexing into the second dimension
# # # We need to handle this case specifically
# # # For now, fall back to regular scatter but handle rank mismatch
# # try:
# # result = ops.scatter_nd(zero_tensor, values_tensor, indices_reshaped)
# # except Exception:
# # # If scatter_nd fails, try the old scatter approach
# # result = ops.scatter(zero_tensor, values_tensor, indices_tensor, axis=physical_axis)
# # else:
# # # For other axes, use scatter_nd with single-dimension indexing
# # result = ops.scatter_nd(zero_tensor, values_tensor, indices_reshaped)
# # output.tensor_value = result
# # def eagerxpr(self, args: list[Array], output: Array) -> None:
# # pass
# # def vjp_rule(self, primals: list[Array], cotangent: Array, output: Array) -> list[Array]:
# # """
# # Vector-Jacobian product rule for give operation.
# # Args:
# # primals: [indices, values] - the forward pass inputs
# # cotangent: Gradient flowing back from output
# # output: Forward pass output (for reference)
# # Returns:
# # [indices_grad, values_grad] where indices_grad is zero array
# # """
# # indices_array, values_array = primals
# # # Indices don't need gradients, but we need to return a zero array of the same shape
# # from ..ops.creation import zeros
# # indices_grad = zeros(indices_array.shape, dtype=values_array.dtype) # Use values dtype
# # # Values gradient: gather the cotangent at the same indices
# # values_grad = gather(cotangent, indices_array, axis=self.axis)
# # return [indices_grad, values_grad]
# # def jvp_rule(self, primals: list[Array], tangents: list[Array], output: Array) -> Array:
# # """
# # Jacobian-vector product rule for give operation.
# # Args:
# # primals: [indices, values] - the forward pass inputs
# # tangents: [indices_tangent, values_tangent] - the tangent vectors
# # output: Forward pass output (for reference)
# # Returns:
# # Output tangent
# # """
# # indices_tangent, values_tangent = tangents
# # # Indices tangents are ignored (indices are discrete)
# # # Apply the same scatter operation to values tangents
# # return scatter(self.target_shape, primals[0], values_tangent, axis=self.axis) # Use original indices
# # def forward(self, *args: Array) -> Array:
# # pass
# # def scatter(target_shape: tuple, indices: Array, values: Array, axis: int = -1) -> Array:
# # if axis >= 0:
# # # make negative
# # axis = axis - len(target_shape)
# # op = ScatterOp(target_shape, axis)
# # return op.forward(indices, values)
# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""View and shape manipulation operations."""
import numpy as np
from max.dtype import DType
from max.graph import TensorValue, ops
from ..core.array import Array, Shape
from .operation import Operation, ViewOperation
# Public API
__all__ = [
"transpose",
"permute",
"move_axis_to_front",
"move_axis_from_front",
"permute_batch_dims",
"move_axis_to_front_of_batch_dims",
"move_axis_from_front_of_batch_dims",
"reshape",
"broadcast_to",
"broadcast_batch_dims",
"squeeze",
"unsqueeze",
"squeeze_batch_dims",
"unsqueeze_batch_dims",
"shallow_copy",
"array_slice",
"pad",
"concatenate",
"stack",
]
class TransposeOp(ViewOperation):
"""Matrix/tensor transpose operation."""
def __init__(self, axis_1: int = -2, axis_2: int = -1):
super().__init__(f"transpose[permutation=({axis_1},{axis_2})]")
self.axis_1 = axis_1
self.axis_2 = axis_2
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compute output shape for transpose operation with compatible signature."""
if len(input_shapes) != 1:
raise ValueError(
f"Transpose operation requires 1 input shape, got {len(input_shapes)}"
)
arg_shape = input_shapes[0]
if not arg_shape:
# Transposing a scalar is a no-op.
return ()
# For rank 1, transpose is also a no-op.
if len(arg_shape) < 2:
return arg_shape
axis_1 = self.axis_1 if self.axis_1 >= 0 else len(arg_shape) + self.axis_1
axis_2 = self.axis_2 if self.axis_2 >= 0 else len(arg_shape) + self.axis_2
if axis_1 < 0 or axis_1 >= len(arg_shape):
raise ValueError(f"axis_1 {axis_1} is out of bounds for shape {arg_shape}")
if axis_2 < 0 or axis_2 >= len(arg_shape):
raise ValueError(f"axis_2 {axis_2} is out of bounds for shape {arg_shape}")
new_shape = list(arg_shape)
new_shape[axis_1], new_shape[axis_2] = new_shape[axis_2], new_shape[axis_1]
return tuple(new_shape)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
if len(args[0].shape) < 2:
output.tensor_value = args[0]
return
output.tensor_value = ops.transpose(args[0], self.axis_1, self.axis_2)
def eagerxpr(self, args: list[Array], output: Array) -> None:
if len(args[0].shape) < 2:
output._impl = args[0].impl
return
offset = len(args[0].batch_dims)
axes = list(range(-offset - len(args[0].shape), 0))
axes[self.axis_1], axes[self.axis_2] = axes[self.axis_2], axes[self.axis_1]
np_result = np.transpose(args[0].to_numpy(), axes)
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
return [transpose(cotangent, self.axis_1, self.axis_2)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
return transpose(tangents[0], self.axis_1, self.axis_2)
def transpose(arg: Array, axis_1: int = -2, axis_2: int = -1) -> Array:
"""Transpose array along two axes."""
if len(arg.shape) <= 1:
return arg
axis_1 = axis_1 if axis_1 < 0 else -len(arg.shape) + axis_1
axis_2 = axis_2 if axis_2 < 0 else -len(arg.shape) + axis_2
if axis_1 == axis_2:
return arg
if axis_1 < -len(arg.shape) or axis_2 < -len(arg.shape):
raise ValueError(
f"Invalid axes {axis_1}, {axis_2} for shape {arg.shape}. "
"Axes must be within the range of the array dimensions."
)
op = TransposeOp(axis_1, axis_2)
return op.forward(arg)
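# A small usage sketch of transpose for clarity (assumes nb.ones, as in this
# module's other doctest examples):
# >>> import nabla as nb
# >>> x = nb.ones((2, 3, 4))
# >>> transpose(x).shape        # swaps the last two axes -> (2, 4, 3)
# >>> transpose(x, 0, 2).shape  # swaps axes 0 and 2 -> (4, 3, 2)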
class TransposeBatchDimsOp(ViewOperation):
"""Transpose operation to swap two batch dimensions."""
def __init__(self, axis_1: int = -2, axis_2: int = -1):
"""Initialize transpose batch dims operation.
Args:
axis_1: First batch dimension axis to swap (negative indices preferred)
axis_2: Second batch dimension axis to swap (negative indices preferred)
"""
super().__init__(f"transpose_batch_dims[permutation=({axis_1},{axis_2})]")
self.axis_1 = axis_1
self.axis_2 = axis_2
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Shape stays the same for batch dimension operations."""
if len(input_shapes) != 1:
raise ValueError(
f"Transpose batch dims operation requires 1 input shape, got {len(input_shapes)}"
)
return input_shapes[0]
def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
"""Compute output batch_dims after transposing two axes."""
if len(input_batch_dimss) != 1:
raise ValueError(
f"Transpose batch dims operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
)
input_batch_dims = input_batch_dimss[0]
if not input_batch_dims:
raise ValueError(
"Cannot transpose batch dims of an array with no batch dimensions"
)
# Convert negative indices to positive for validation and computation
axis_1 = self.axis_1 + len(input_batch_dims) if self.axis_1 < 0 else self.axis_1
axis_2 = self.axis_2 + len(input_batch_dims) if self.axis_2 < 0 else self.axis_2
# Validate axes are within bounds
if axis_1 < 0 or axis_1 >= len(input_batch_dims):
raise ValueError(
f"axis_1 {self.axis_1} is out of bounds for batch_dims {input_batch_dims}"
)
if axis_2 < 0 or axis_2 >= len(input_batch_dims):
raise ValueError(
f"axis_2 {self.axis_2} is out of bounds for batch_dims {input_batch_dims}"
)
# Create new batch_dims with axes swapped
new_batch_dims = list(input_batch_dims)
new_batch_dims[axis_1], new_batch_dims[axis_2] = (
new_batch_dims[axis_2],
new_batch_dims[axis_1],
)
return tuple(new_batch_dims)
def forward(self, *args: Array) -> Array:
"""Override forward to handle single input."""
if len(args) != 1:
raise ValueError(
f"Transpose batch dims operation requires 1 argument, got {len(args)}"
)
return super().forward(*args)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
"""MAX graph implementation using ops.transpose."""
axis_1 = self.axis_1 - len(output.shape)
axis_2 = self.axis_2 - len(output.shape)
output.tensor_value = ops.transpose(args[0], axis_1, axis_2)
def eagerxpr(self, args: list[Array], output: Array) -> None:
"""Eager execution using NumPy transpose."""
input_array = args[0]
# Get the full tensor including batch dimensions
input_np = input_array.to_numpy()
axis_1 = self.axis_1 - len(args[0].shape)
axis_2 = self.axis_2 - len(args[0].shape)
# Create axes list for full transpose
total_dims = len(input_array.batch_dims) + len(input_array.shape)
axes = list(range(total_dims))
# Swap the two batch dimension axes
axes[axis_1], axes[axis_2] = axes[axis_2], axes[axis_1]
# Apply transpose
np_result = np.transpose(input_np, axes)
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
"""VJP rule: transpose is its own inverse."""
return [transpose_batch_dims(cotangent, self.axis_1, self.axis_2)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
"""JVP rule: apply same transpose to tangents."""
return transpose_batch_dims(tangents[0], self.axis_1, self.axis_2)
def transpose_batch_dims(arg: Array, axis_1: int = -2, axis_2: int = -1) -> Array:
"""Transpose batch dimensions along two axes.
This operation swaps two axes in the batch_dims of an Array, similar to how
regular transpose works on shape dimensions. The shape dimensions remain unchanged.
Args:
arg: Input array with batch dimensions to transpose
axis_1: First batch dimension axis to swap (default: -2)
axis_2: Second batch dimension axis to swap (default: -1)
Returns:
Array with specified batch dimensions transposed
Example:
>>> import nabla as nb
>>> # Array with batch_dims=(2, 3, 4) and shape=(5, 6)
>>> x = nb.ones((5, 6))
>>> x.batch_dims = (2, 3, 4) # Simulated for example
>>> y = transpose_batch_dims(x, -3, -1) # Swap first and last batch dims
>>> # Result has batch_dims=(4, 3, 2) and shape=(5, 6)
"""
# Convert to negative indices for consistency with batch dimension handling
axis_1 = axis_1 if axis_1 < 0 else -len(arg.batch_dims) + axis_1
axis_2 = axis_2 if axis_2 < 0 else -len(arg.batch_dims) + axis_2
op = TransposeBatchDimsOp(axis_1, axis_2)
return op.forward(arg)
def compute_iterative_transpose_swaps(axes: tuple[int, ...]) -> list[tuple[int, int]]:
"""Compute the sequence of axis swaps needed to implement a permutation.
This function implements the algorithm from the Mojo code to determine what
sequence of axis swaps (transposes) are needed to achieve a given permutation.
Returns a list of (axis1, axis2) tuples that should be swapped in order.
Args:
axes: Tuple of axis indices specifying the desired permutation.
All indices should be negative (e.g., -1, -2, -3, ...)
Returns:
List of (axis1, axis2) tuples representing the swaps to perform in order.
Each tuple contains two negative axis indices to be swapped.
The algorithm works as follows:
1. Initialize current_axis_order as [-num_dims, -num_dims+1, ..., -1]
2. For each target position x, find where target_axis currently is (y)
3. If x != y, record the swap (x_neg, y_neg) and update current_axis_order
4. Return the list of all recorded swaps
"""
target_perm = list(axes)
num_dims = len(target_perm)
# Initialize current_axis_order as in Mojo: [-num_dims, -num_dims+1, ..., -1]
current_axis_order = []
for i in range(-num_dims, 0):
current_axis_order.append(i)
swaps = []
# For each target position x, move the correct axis there
for x in range(num_dims):
target_axis = target_perm[x]
# Find where target_axis currently is in the current ordering
try:
y = current_axis_order.index(target_axis)
except ValueError as e:
# target_axis not found in current_axis_order, this shouldn't happen
raise ValueError(
f"Target axis {target_axis} not found in current_axis_order {current_axis_order}"
) from e
# If already in the right position, skip
if x == y:
continue
# Convert to negative indices for the swap operation
x_neg = x - num_dims
y_neg = y - num_dims
# Record the swap
swaps.append((x_neg, y_neg))
# Update current_axis_order to reflect the swap
# Swap the elements at positions x and y
current_axis_order[x], current_axis_order[y] = (
current_axis_order[y],
current_axis_order[x],
)
return swaps
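# Worked example of the algorithm described in the docstring above:
# compute_iterative_transpose_swaps((-1, -3, -2)) proceeds as follows:
#   start: current_axis_order = [-3, -2, -1]
#   x=0 wants -1 (currently at position 2) -> record swap (-3, -1); order becomes [-1, -2, -3]
#   x=1 wants -3 (currently at position 2) -> record swap (-2, -1); order becomes [-1, -3, -2]
#   x=2 wants -2, already in place        -> no swap
# so the returned swaps are [(-3, -1), (-2, -1)].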
class PermuteOp(ViewOperation):
"""Permute (reorder) the dimensions of a tensor according to given axes."""
def __init__(self, axes: tuple[int, ...]):
"""Initialize permute operation.
Args:
axes: Tuple of axis indices specifying the new order.
Must be a permutation of range(ndim).
"""
super().__init__(f"permute[axes={axes}]")
self.axes = tuple(axes)
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compute output shape after permutation."""
if len(input_shapes) != 1:
raise ValueError(
f"Permute operation requires 1 input shape, got {len(input_shapes)}"
)
input_shape = input_shapes[0]
# Validate axes - should now be all negative and same length as input
if len(self.axes) != len(input_shape):
raise ValueError(
f"Number of axes {len(self.axes)} must match input dimensions {len(input_shape)}"
)
# Convert to positive indices for validation
positive_axes = [ax + len(input_shape) for ax in self.axes]
if sorted(positive_axes) != list(range(len(input_shape))):
raise ValueError(
f"Axes {self.axes} must be a permutation of negative indices corresponding to {list(range(len(input_shape)))}"
)
# Reorder dimensions according to axes (convert negative to positive for indexing)
return tuple(input_shape[axis + len(input_shape)] for axis in self.axes)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
"""Max computation: permute the tensor using iterative transpose."""
# Get the sequence of swaps needed for this permutation
swaps = compute_iterative_transpose_swaps(self.axes)
# Apply each swap in sequence
out_symbol = args[0]
for axis1, axis2 in swaps:
out_symbol = ops.transpose(out_symbol, axis1, axis2)
output.tensor_value = out_symbol
def eagerxpr(self, args: list[Array], output: Array) -> None:
"""Eager computation: permute using numpy."""
# Handle batch dimensions properly like transpose does
offset = len(args[0].batch_dims)
# Convert our negative axes (relative to array shape) to work with full numpy array
numpy_axes = []
for ax in self.axes:
# ax is negative relative to args[0].shape, convert to positive
array_pos_ax = ax + len(args[0].shape)
# Now convert to position in full numpy array (including batch dims)
numpy_pos_ax = offset + array_pos_ax
numpy_axes.append(numpy_pos_ax)
# Prepend batch dimension indices (they stay in their original positions)
full_axes = list(range(offset)) + numpy_axes
np_result = np.transpose(args[0].to_numpy(), full_axes)
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
"""VJP rule: reverse the permutation."""
# Create inverse permutation for negative indices
inv_axes = [0] * len(self.axes)
for i, axis in enumerate(self.axes):
# Convert negative axis to positive index for inverse mapping
pos_axis = axis + len(self.axes)
inv_axes[pos_axis] = i
# Convert back to negative indices
inv_axes_negative = [-len(self.axes) + ax for ax in inv_axes]
return [permute(cotangent, tuple(inv_axes_negative))]
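# Worked example of the inverse construction above: for self.axes == (-1, -3, -2)
# (output dims take input axes c, a, b), the loop yields inv_axes == [1, 2, 0],
# which maps back to negative indices (-2, -1, -3); permuting the cotangent with
# (-2, -1, -3) undoes the forward permutation.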
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
"""JVP rule: apply same permutation to tangent."""
return permute(tangents[0], self.axes)
def permute(input_array: Array, axes: tuple[int, ...]) -> Array:
"""Permute (reorder) the dimensions of a tensor.
Args:
input_array: Input tensor
axes: Tuple specifying the new order of dimensions
Returns:
Tensor with reordered dimensions
Example:
>>> x = nb.ones((2, 3, 4)) # shape (2, 3, 4)
>>> y = permute(x, (2, 0, 1)) # shape (4, 2, 3)
>>> # Dimension 2 -> position 0, dimension 0 -> position 1, dimension 1 -> position 2
"""
# always store axes to be fully negative
axes = tuple(-len(input_array.shape) + ax if ax >= 0 else ax for ax in axes)
# then prepend any potentially missing leading axes, which we treat as unpermuted
axes_new = []
for i in range(-len(input_array.shape), -len(axes)):
axes_new.append(i)
axes = tuple(axes_new + list(axes)) # prepend missing axes to the front
op = PermuteOp(axes)
return op.forward(input_array)
class PermuteBatchDimsOp(ViewOperation):
"""Permute (reorder) the batch dimensions of an array according to given axes."""
def __init__(self, axes: tuple[int, ...]):
"""Initialize permute batch dims operation.
Args:
axes: Tuple of axis indices specifying the new order for batch_dims.
Must be a permutation of range(-len(batch_dims), 0).
All indices should be negative.
"""
super().__init__(f"permute_batch_dims[axes={axes}]")
self.axes = tuple(axes)
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Shape stays the same for batch dimension operations."""
if len(input_shapes) != 1:
raise ValueError(
f"Permute batch dims operation requires 1 input shape, got {len(input_shapes)}"
)
return input_shapes[0]
def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
"""Compute output batch_dims after permutation."""
if len(input_batch_dimss) != 1:
raise ValueError(
f"Permute batch dims operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
)
input_batch_dims = input_batch_dimss[0]
if not input_batch_dims:
raise ValueError(
"Cannot permute batch dims of an array with no batch dimensions"
)
# Validate axes - should be all negative and same length as input batch_dims
if len(self.axes) != len(input_batch_dims):
raise ValueError(
f"Number of axes {len(self.axes)} must match input batch dimensions {len(input_batch_dims)}"
)
# Convert to positive indices for validation
positive_axes = [ax + len(input_batch_dims) for ax in self.axes]
if sorted(positive_axes) != list(range(len(input_batch_dims))):
raise ValueError(
f"Axes {self.axes} must be a permutation of negative indices corresponding to batch_dims range"
)
# Reorder batch dimensions according to axes (convert negative to positive for indexing)
return tuple(
input_batch_dims[axis + len(input_batch_dims)] for axis in self.axes
)
def forward(self, *args: Array) -> Array:
"""Override forward to handle single input."""
if len(args) != 1:
raise ValueError(
f"Permute batch dims operation requires 1 argument, got {len(args)}"
)
return super().forward(*args)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
"""MAX graph implementation using ops.transpose."""
# Get the sequence of swaps needed for this permutation
swaps = compute_iterative_transpose_swaps(self.axes)
swaps = [
(axis1 - len(output.shape), axis2 - len(output.shape))
for axis1, axis2 in swaps
]
# Apply each swap in sequence
out_symbol = args[0]
for axis1, axis2 in swaps:
out_symbol = ops.transpose(out_symbol, axis1, axis2)
output.tensor_value = out_symbol
def eagerxpr(self, args: list[Array], output: Array) -> None:
"""Eager execution using NumPy transpose."""
input_array = args[0]
# Get the full tensor including batch dimensions
input_np = input_array.to_numpy()
# Convert batch dimension axes to full tensor indices
# Following the pattern from other batch operations
numpy_axes = []
for ax in self.axes:
# ax is negative relative to batch_dims, convert to full tensor position
batch_pos_ax = ax - len(input_array.shape)
numpy_axes.append(batch_pos_ax)
# Add shape dimension indices (they stay in their original relative positions)
# They come after the batch dimensions in the permuted tensor
shape_offset = len(input_array.batch_dims)
shape_axes = list(range(shape_offset, shape_offset + len(input_array.shape)))
full_axes = numpy_axes + shape_axes
# Apply transpose
np_result = np.transpose(input_np, full_axes)
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
# """VJP rule: reverse the permutation."""
# Create inverse permutation for negative indices
inv_axes = [0] * len(self.axes)
for i, axis in enumerate(self.axes):
# Convert negative axis to positive index for inverse mapping
pos_axis = axis + len(self.axes)
inv_axes[pos_axis] = i
# Convert back to negative indices
inv_axes_negative = [-len(self.axes) + ax for ax in inv_axes]
return [permute_batch_dims(cotangent, tuple(inv_axes_negative))]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
"""JVP rule: apply same permutation to tangent."""
return permute_batch_dims(tangents[0], self.axes)
def permute_batch_dims(input_array: Array, axes: tuple[int, ...]) -> Array:
"""Permute (reorder) the batch dimensions of an array.
This operation reorders the batch_dims of an Array according to the given axes,
similar to how regular permute works on shape dimensions. The shape dimensions
remain unchanged.
Args:
input_array: Input array with batch dimensions to permute
axes: Tuple specifying the new order of batch dimensions.
All indices should be negative and form a permutation.
Returns:
Array with batch dimensions reordered according to axes
Example:
>>> import nabla as nb
>>> # Array with batch_dims=(2, 3, 4) and shape=(5, 6)
>>> x = nb.ones((5, 6))
>>> x.batch_dims = (2, 3, 4) # Simulated for example
>>> y = permute_batch_dims(x, (-1, -3, -2)) # Reorder as (4, 2, 3)
>>> # Result has batch_dims=(4, 2, 3) and shape=(5, 6)
"""
if len(axes) <= 1:
return input_array # No permutation needed for single axis or empty
# Convert to negative indices for consistency with batch dimension handling
axes = tuple(-len(input_array.batch_dims) + ax if ax >= 0 else ax for ax in axes)
# Handle case where fewer axes are provided - prepend missing axes to front
if len(axes) < len(input_array.batch_dims):
axes_new = []
for i in range(-len(input_array.batch_dims), -len(axes)):
axes_new.append(i)
axes = tuple(axes_new) + axes
op = PermuteBatchDimsOp(axes)
return op.forward(input_array)
def move_axis_to_front(input_array: Array, axis: int) -> Array:
"""Move specified axis to the front (position 0), shifting others right.
Args:
input_array: Input tensor
axis: Axis to move to front
Returns:
Tensor with specified axis moved to front
Example:
>>> x = nb.ones((2, 3, 4)) # shape (2, 3, 4)
>>> y = move_axis_to_front(x, 2) # shape (4, 2, 3)
>>> # axis 2 moved to front, others shifted: [2, 0, 1]
"""
ndim = len(input_array.shape)
# Normalize negative axis
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise ValueError(f"Axis {axis} out of bounds for array of dimension {ndim}")
# Generate permutation: [axis, 0, 1, ..., axis-1, axis+1, ..., ndim-1]
axes = [axis] + [i for i in range(ndim) if i != axis]
return permute(input_array, tuple(axes))
def move_axis_to_back(input_array: Array, axis: int) -> Array:
"""Move specified axis to the back (last position), shifting others left.
Args:
input_array: Input tensor
axis: Axis to move to back
Returns:
Tensor with specified axis moved to back
Example:
>>> x = nb.ones((2, 3, 4)) # shape (2, 3, 4)
>>> y = move_axis_to_back(x, 0) # shape (3, 4, 2)
>>> # axis 0 moved to back, others shifted: [1, 2, 0]
"""
ndim = len(input_array.shape)
# Normalize negative axis
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise ValueError(f"Axis {axis} out of bounds for array of dimension {ndim}")
# Generate permutation: [0, 1, ..., axis-1, axis+1, ..., ndim-1, axis]
axes = [i for i in range(ndim) if i != axis] + [axis]
return permute(input_array, tuple(axes))
def move_axis_from_front(input_array: Array, target_axis: int) -> Array:
"""Move front axis (position 0) to specified target position.
Args:
input_array: Input tensor (assumes front axis is the one to move)
target_axis: Target position for the front axis
Returns:
Tensor with front axis moved to target position
Example:
>>> x = nb.ones((4, 2, 3)) # front axis has size 4
>>> y = move_axis_from_front(x, 2) # shape (2, 3, 4)
>>> # front axis moved to position 2: [1, 2, 0]
"""
ndim = len(input_array.shape)
# Normalize negative axis
if target_axis < 0:
target_axis = ndim + target_axis
if target_axis < 0 or target_axis >= ndim:
raise ValueError(
f"Target axis {target_axis} out of bounds for array of dimension {ndim}"
)
if target_axis == 0:
return input_array # Already at front
# Generate permutation to move front to target_axis
# [1, 2, ..., target_axis, 0, target_axis+1, ..., ndim-1]
axes = list(range(1, target_axis + 1)) + [0] + list(range(target_axis + 1, ndim))
return permute(input_array, tuple(axes))
def move_axis_from_back(input_array: Array, target_axis: int) -> Array:
"""Move back axis (last position) to specified target position.
Args:
input_array: Input tensor (assumes back axis is the one to move)
target_axis: Target position for the back axis
Returns:
Tensor with back axis moved to target position
Example:
>>> x = nb.ones((4, 2, 3)) # back axis has size 3
>>> y = move_axis_from_back(x, 1) # shape (2, 4, 3)
>>> # back axis moved to position 1: [0, 2, 1]
"""
ndim = len(input_array.shape)
# Normalize negative axis
if target_axis < 0:
target_axis = ndim + target_axis
if target_axis < 0 or target_axis >= ndim:
raise ValueError(
f"Target axis {target_axis} out of bounds for array of dimension {ndim}"
)
if target_axis == ndim - 1:
return input_array # Already at back
# Generate permutation to move back to target_axis
axes = list(range(0, target_axis)) + [ndim - 1] + list(range(target_axis, ndim - 1))
return permute(input_array, tuple(axes))
def move_axis_to_front_of_batch_dims(input_array: Array, axis: int) -> Array:
"""Move specified batch dimension to the front (position 0), shifting others right.
Args:
input_array: Input tensor with batch dimensions
axis: Batch dimension to move to front (negative index)
Returns:
Tensor with specified batch dimension moved to front
Example:
>>> x = nb.ones((2, 3, 4)) # shape (2, 3, 4)
>>> x.batch_dims = (1, 0) # Simulated for example
>>> y = move_axis_to_front_of_batch_dims(x, -1) # Move last batch dim to front
>>> # Result has batch_dims=(0, 1) and shape=(2, 3, 4)
"""
ndim = len(input_array.batch_dims)
# Normalize negative axis
if axis >= 0:
axis = -len(input_array.batch_dims) + axis
if axis < -len(input_array.batch_dims) or axis >= 0:
raise ValueError(
f"Axis {axis} out of bounds for batch_dims of dimension {ndim}"
)
# Generate permutation: [axis, 0, 1, ..., axis-1, axis+1, ..., ndim-1]
axes = [axis] + [i for i in range(-len(input_array.batch_dims), 0) if i != axis]
return permute_batch_dims(input_array, tuple(axes))
def move_axis_from_front_of_batch_dims(input_array: Array, target_axis: int) -> Array:
"""Move front batch dimension (position 0) to specified target position.
Args:
input_array: Input tensor with batch dimensions (assumes front batch dim is the one to move)
target_axis: Target position for the front batch dimension (negative index)
Returns:
Tensor with front batch dimension moved to target position
Example:
>>> x = nb.ones((4, 2, 3)) # shape (4, 2, 3)
>>> x.batch_dims = (0, 1) # Simulated for example
>>> y = move_axis_from_front_of_batch_dims(x, -1) # Move front batch dim to last position
>>> # Result has batch_dims=(1, 0) and shape=(4, 2, 3)
"""
ndim = len(input_array.batch_dims)
# Normalize negative axis
if target_axis >= 0:
target_axis = -len(input_array.batch_dims) + target_axis
if target_axis < -len(input_array.batch_dims) or target_axis >= 0:
raise ValueError(
f"Target axis {target_axis} out of bounds for batch_dims of dimension {ndim}"
)
if target_axis == 0:
return input_array # Already at front
# Generate permutation to move front to target_axis
axes = (
list(range(-len(input_array.batch_dims) + 1, target_axis + 1))
+ [0]
+ list(range(target_axis + 1, 0))
)
return permute_batch_dims(input_array, tuple(axes))
class ReshapeOp(ViewOperation):
"""Reshape operation."""
def __init__(self, arg_shape: Shape, target_shape: Shape):
super().__init__(f"reshape[new_sizes={target_shape}]")
self.arg_shape = arg_shape
self.target_shape = target_shape
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compatible signature."""
if len(input_shapes) != 1:
raise ValueError(
f"Reshape operation requires 1 input shape, got {len(input_shapes)}"
)
return self.target_shape
def forward(self, *args: Array) -> Array:
"""Override forward to validate size compatibility with compatible signature."""
if len(args) != 1:
raise ValueError(f"Reshape operation requires 1 argument, got {len(args)}")
arg = args[0]
old_size = np.prod(arg.shape) if arg.shape else 1
new_size = np.prod(self.target_shape) if self.target_shape else 1
if old_size != new_size:
raise ValueError(
f"Cannot reshape array of size {old_size} to shape {self.target_shape} of size {new_size}"
)
return super().forward(arg)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
output.tensor_value = ops.reshape(
args[0], output.batch_dims + self.target_shape
)
def eagerxpr(self, args: list[Array], output: Array) -> None:
np_result = np.reshape(
args[0].to_numpy(), output.batch_dims + self.target_shape
)
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
return [reshape(cotangent, self.arg_shape)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
return reshape(tangents[0], self.target_shape)
def reshape(arg: Array, shape: Shape) -> Array:
"""Reshape array to given shape."""
# Handle -1 dimension inference
if -1 in shape:
# Compute the inferred dimension
total_size = np.prod(arg.shape) if arg.shape else 1
known_size = 1
unknown_idx = -1
for i, dim in enumerate(shape):
if dim == -1:
if unknown_idx != -1:
raise ValueError("Can only specify one unknown dimension with -1")
unknown_idx = i
else:
known_size *= dim
if unknown_idx == -1:
# No -1 found, use shape as is
target_shape = shape
else:
# Calculate the unknown dimension
if known_size == 0:
raise ValueError(
"Cannot infer shape when known dimensions have zero size"
)
if total_size % known_size != 0:
raise ValueError(
f"Cannot reshape array of size {total_size} to shape {shape}"
)
inferred_dim = total_size // known_size
target_shape = tuple(
int(inferred_dim if dim == -1 else dim) for dim in shape
)
else:
target_shape = tuple(int(dim) for dim in shape)
op = ReshapeOp(arg.shape, target_shape)
return op.forward(arg)
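# A small usage sketch of reshape, including -1 dimension inference (assumes
# nb.ones, as in this module's other doctest examples):
# >>> import nabla as nb
# >>> x = nb.ones((2, 3, 4))
# >>> reshape(x, (6, 4)).shape   # (6, 4)
# >>> reshape(x, (-1, 4)).shape  # -1 is inferred as 24 // 4 = 6 -> (6, 4)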
class BroadcastToOp(ViewOperation):
"""Broadcast array to target shape."""
def __init__(self, target_shape: Shape):
super().__init__(f"broadcast[shape={target_shape}]")
self.target_shape = target_shape
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compatible signature."""
if len(input_shapes) != 1:
raise ValueError(
f"Broadcast operation requires 1 input shape, got {len(input_shapes)}"
)
return self.target_shape
def forward(self, *args: Array) -> Array:
"""Override forward to handle case where no broadcasting needed with compatible signature."""
if len(args) != 1:
raise ValueError(
f"Broadcast operation requires 1 argument, got {len(args)}"
)
arg = args[0]
if arg.shape == self.target_shape:
return arg
return super().forward(*args)
@staticmethod
def get_broadcasted_axes(input_shape: Shape, target_shape: Shape) -> list[int]:
"""Get axes that were broadcasted (for VJP)."""
if len(input_shape) > len(target_shape):
raise ValueError(
f"Input shape {input_shape} cannot be broadcast to {target_shape}"
)
broadcasted_axes = []
padded_input = (1,) * (len(target_shape) - len(input_shape)) + input_shape
for i in range(len(target_shape)):
if padded_input[i] == 1 and target_shape[i] > 1:
# Return negative index to reference from the right side
# This ensures we sum over the correct dimension
broadcasted_axes.append(i - len(target_shape))
elif padded_input[i] != target_shape[i] and padded_input[i] != 1:
raise ValueError(f"Cannot broadcast {input_shape} to {target_shape}")
return broadcasted_axes
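# Worked example of the rule above: broadcasting input shape (3, 1) to target
# (2, 3, 4) pads the input to (1, 3, 1); positions 0 and 2 were broadcast, so the
# returned axes are [-3, -1] (negative so the VJP sums over the correct dimensions).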
def maxpr(self, args: list[TensorValue], output: Array) -> None:
output.tensor_value = ops.broadcast_to(
args[0], output.batch_dims + self.target_shape
)
def eagerxpr(self, args: list[Array], output: Array) -> None:
np_result = np.broadcast_to(
args[0].to_numpy(), shape=output.batch_dims + self.target_shape
)
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
broadcasted_axes = self.get_broadcasted_axes(
primals[0].shape, self.target_shape
)
from .reduce import sum as sum_op # Renamed to avoid shadowing built-in
return [sum_op(cotangent, axes=broadcasted_axes, keep_dims=True)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
return broadcast_to(tangents[0], self.target_shape)
def broadcast_to(arg: Array, shape: Shape) -> Array:
"""Broadcast array to target shape."""
if arg.shape == shape:
return arg
for _ in range(len(shape) - len(arg.shape)):
arg = unsqueeze(arg, [0])
op = BroadcastToOp(shape)
return op.forward(arg)
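# A small usage sketch of broadcast_to for clarity (assumes nb.ones, as in this
# module's other doctest examples); missing leading dims are added via unsqueeze:
# >>> import nabla as nb
# >>> x = nb.ones((3, 1))
# >>> broadcast_to(x, (2, 3, 4)).shape  # (2, 3, 4)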
class BroadcastBatchDimsOp(ViewOperation):
"""Broadcast array to target batch_dims."""
def __init__(self, target_batch_dims: Shape):
super().__init__(f"broadcast_batch_dims[shape={target_batch_dims}]")
self.target_batch_dims = target_batch_dims
def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
"""Compatible signature."""
if len(input_batch_dimss) != 1:
raise ValueError(
f"Broadcast operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
)
return self.target_batch_dims
def forward(self, *args: Array) -> Array:
"""Override forward to handle case where no broadcasting needed with compatible signature."""
if len(args) != 1:
raise ValueError(
f"Broadcast operation requires 1 argument, got {len(args)}"
)
arg = args[0]
if arg.batch_dims == self.target_batch_dims:
return arg
return super().forward(*args)
@staticmethod
def get_broadcasted_axes(
input_batch_dims: Shape, target_batch_dims: Shape
) -> list[int]:
"""Get axes that were broadcasted (for VJP)."""
if len(input_batch_dims) > len(target_batch_dims):
raise ValueError(
f"Input batch_dims {input_batch_dims} cannot be broadcast to {target_batch_dims}"
)
broadcasted_axes = []
padded_input = (1,) * (
len(target_batch_dims) - len(input_batch_dims)
) + input_batch_dims
for i in range(len(target_batch_dims)):
if padded_input[i] == 1 and i < len(target_batch_dims) - len(
input_batch_dims
):
# This dimension was added by padding (broadcasted from non-existent to size 1 or more)
broadcasted_axes.append(i - len(target_batch_dims))
elif padded_input[i] == 1 and target_batch_dims[i] > 1:
# This dimension was broadcasted from size 1 to larger size
broadcasted_axes.append(i - len(target_batch_dims))
elif padded_input[i] != target_batch_dims[i] and padded_input[i] != 1:
raise ValueError(
f"Cannot broadcast {input_batch_dims} to {target_batch_dims}"
)
return broadcasted_axes
def maxpr(self, args: list[TensorValue], output: Array) -> None:
output.tensor_value = ops.broadcast_to(
args[0], self.target_batch_dims + output.shape
)
def eagerxpr(self, args: list[Array], output: Array) -> None:
np_result = np.broadcast_to(
args[0].to_numpy(), shape=self.target_batch_dims + output.shape
)
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
from .reduce import sum_batch_dims
broadcasted_axes = self.get_broadcasted_axes(
primals[0].batch_dims, output.batch_dims
)
return [sum_batch_dims(cotangent, axes=broadcasted_axes, keep_dims=True)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
return broadcast_batch_dims(tangents[0], self.target_batch_dims)
def broadcast_batch_dims(arg: Array, batch_dims: Shape) -> Array:
"""Broadcast array to target batch_dims."""
if arg.batch_dims == batch_dims:
return arg
for _ in range(len(batch_dims) - len(arg.batch_dims)):
arg = unsqueeze_batch_dims(arg, [0])
op = BroadcastBatchDimsOp(batch_dims)
return op.forward(arg)
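# Hedged sketch: behaves like broadcast_to but acts on the (internal) batch
# dimensions used by vmap-style transformations. Assuming `x` is an Array with
# x.batch_dims == () and x.shape == (3,):
#
#     y = broadcast_batch_dims(x, (4,))   # y.batch_dims == (4,), y.shape == (3,)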
class SqueezeOp(ViewOperation):
"""Squeeze operation to remove dimensions of size 1."""
def __init__(self, axes: list[int] | None = None):
super().__init__(f"squeeze[axes={axes}]")
self.axes = sorted(axes) if axes is not None else []
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compatible signature."""
if len(input_shapes) != 1:
raise ValueError(
f"Squeeze operation requires 1 input shape, got {len(input_shapes)}"
)
input_shape = input_shapes[0]
new_shape = list(input_shape)
for ax in self.axes:
if ax < -len(new_shape) or ax >= len(new_shape):
raise ValueError(f"Axis {ax} is out of bounds for squeeze operation")
if input_shape[ax] == 1:
new_shape[ax] = None
else:
raise ValueError(
f"Cannot squeeze axis {ax} of size {input_shape[ax]} (must be 1)"
)
new_shape = [dim for dim in new_shape if dim is not None]
return tuple(new_shape)
def forward(self, *args: Array) -> Array:
"""Override forward to handle case where no squeezing needed with compatible signature."""
if len(args) != 1:
raise ValueError(f"Squeeze operation requires 1 argument, got {len(args)}")
return super().forward(*args)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
res = args[0]
# Use self.axes directly since it's already normalized to a list in __init__
for ax in self.axes:
res = ops.squeeze(res, ax)
output.tensor_value = res
def eagerxpr(self, args: list[Array], output: Array) -> None:
axis = tuple(self.axes) if self.axes else None
np_result = np.squeeze(args[0].to_numpy(), axis=axis)
output.impl_(np_result)
def vjp_rule(
self, _primals: list[Array], cotangent: Array, _output: Array
) -> list[Array]:
return [unsqueeze(cotangent, self.axes)]
def jvp_rule(
self, _primals: list[Array], tangents: list[Array], _output: Array
) -> Array:
return squeeze(tangents[0], self.axes)
def squeeze(arg: Array, axes: list[int] | None = None) -> Array:
"""Squeeze array by removing dimensions of size 1."""
if axes is None:
return arg
axes = [ax if ax < 0 else -len(arg.shape) + ax for ax in axes]
op = SqueezeOp(axes)
res = op.forward(arg)
return res
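# Minimal usage sketch (hedged), assuming `x` is an Array with
# x.shape == (1, 3, 1):
#
#     squeeze(x, [0, 2])   # shape (3,)
#     squeeze(x)           # axes=None returns x unchanged (unlike np.squeeze,
#                          # which would drop every size-1 axis)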
class UnsqueezeOp(ViewOperation):
"""Unsqueeze operation to add dimensions of size 1."""
def __init__(self, axes: list[int] | None = None):
super().__init__(f"unsqueeze[axes={axes}]")
self.axes = sorted(axes) if axes is not None else []
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compatible signature."""
if len(input_shapes) != 1:
raise ValueError(
f"Unsqueeze operation requires 1 input shape, got {len(input_shapes)}"
)
input_shape = input_shapes[0]
new_shape = list(input_shape)
for ax in self.axes:
if ax < -len(new_shape) - 1:
raise ValueError(f"Axis {ax} is out of bounds for unsqueeze operation")
if ax + 1 <= -1:
new_shape.insert(ax + 1, 1)
else:
new_shape.append(1)
return tuple(new_shape)
def forward(self, *args: Array) -> Array:
"""Override forward to handle case where no unsqueezing needed with compatible signature."""
if len(args) != 1:
raise ValueError(
f"Unsqueeze operation requires 1 argument, got {len(args)}"
)
return super().forward(*args)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
res_value = args[0]
for ax in self.axes:
res_value = ops.unsqueeze(res_value, ax)
output.tensor_value = res_value
def eagerxpr(self, args: list[Array], output: Array) -> None:
np_result = np.expand_dims(args[0].to_numpy(), axis=self.axes)
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
return [squeeze(cotangent, self.axes)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
return unsqueeze(tangents[0], self.axes)
def unsqueeze(arg: Array, axes: list[int] | None = None) -> Array:
"""Unsqueeze array by adding dimensions of size 1."""
if axes is None:
return arg
axes = [ax if ax < 0 else -len(arg.shape) - 1 + ax for ax in axes]
op = UnsqueezeOp(axes)
return op.forward(arg)
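# Minimal usage sketch (hedged), assuming `x` is an Array with x.shape == (3,):
#
#     unsqueeze(x, [0])    # shape (1, 3)
#     unsqueeze(x, [-1])   # shape (3, 1)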
class ShallowCopyOp(ViewOperation):
"""Copy operation to create a new array with the same data."""
def __init__(self, arg: Array):
# if not arg.name and arg._impl and arg.shape == () and arg.batch_dims == ():
# name = arg.to_numpy().__repr__() # Use numpy repr for empty arrays
# else:
name = "shallow_copy"
super().__init__(name)
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compatible signature."""
if len(input_shapes) != 1:
raise ValueError(
f"Copy operation requires 1 input shape, got {len(input_shapes)}"
)
return input_shapes[0]
def maxpr(self, args: list[TensorValue], output: Array) -> None:
output.tensor_value = args[0]
def eagerxpr(self, args: list[Array], output: Array) -> None:
output.impl_(args[0]._impl)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
return [cotangent]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
return tangents[0]
def shallow_copy(arg: Array) -> Array:
"""Create a shallow copy of the array."""
op = ShallowCopyOp(arg)
return op.forward(arg)
class ConcatenateOp(Operation):
"""Concatenate operation to join arrays along an existing axis."""
def __init__(self, axis: int = 0):
super().__init__(f"concatenate[axis={axis}]")
self.axis = axis
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compute output shape for concatenate operation."""
if len(input_shapes) == 0:
raise ValueError("Concatenate operation requires at least 1 input")
# All input shapes must be the same except along the concatenation axis
first_shape = input_shapes[0]
if not first_shape:
raise ValueError("Cannot concatenate empty shapes")
# Normalize axis
axis = self.axis if self.axis >= 0 else len(first_shape) + self.axis
if axis < 0 or axis >= len(first_shape):
raise ValueError(
f"Axis {self.axis} is out of bounds for array with {len(first_shape)} dimensions"
)
# Check that all shapes are compatible
total_size_along_axis = 0
for i, shape in enumerate(input_shapes):
if len(shape) != len(first_shape):
raise ValueError(
f"All inputs must have the same number of dimensions for concatenate operation. "
f"Input 0 has {len(first_shape)} dimensions, input {i} has {len(shape)} dimensions"
)
for j, (dim1, dim2) in enumerate(zip(first_shape, shape, strict=False)):
if j != axis and dim1 != dim2:
raise ValueError(
f"All inputs must have the same shape except along axis {axis}. "
f"Input 0 has shape {first_shape}, input {i} has shape {shape}"
)
total_size_along_axis += shape[axis]
# Compute output shape
output_shape = list(first_shape)
output_shape[axis] = total_size_along_axis
return tuple(output_shape)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
"""MAX graph implementation using ops.concat."""
# Normalize axis for MAX operations, considering batch_dims
# full_output_shape = output.batch_dims + output.shape # TODO: Use if needed
axis = self.axis if self.axis >= 0 else len(output.shape) + self.axis
# Adjust axis to account for batch_dims in the actual tensor
axis_in_tensor = axis + len(output.batch_dims)
output.tensor_value = ops.concat(args, axis=axis_in_tensor)
def eagerxpr(self, args: list[Array], output: Array) -> None:
"""Eager execution using NumPy concatenate."""
import numpy as np
numpy_arrays = [arg.to_numpy() for arg in args]
# Normalize axis for NumPy operations, considering batch_dims
axis = self.axis if self.axis >= 0 else len(output.shape) + self.axis
# Adjust axis to account for batch_dims in the actual tensor
axis_in_tensor = axis + len(output.batch_dims)
result = np.concatenate(numpy_arrays, axis=axis_in_tensor)
output.impl_(result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
"""Vector-Jacobian product rule for concatenate operation.
The VJP of concatenate is slicing the cotangent back into pieces.
"""
# Normalize axis
axis = self.axis if self.axis >= 0 else len(cotangent.shape) + self.axis
# Split the cotangent along the concatenated axis
result = []
start_idx = 0
for primal in primals:
size_along_axis = primal.shape[axis]
end_idx = start_idx + size_along_axis
# Create slice that selects this input's portion along the concatenated axis
slices = [slice(None)] * len(cotangent.shape)
slices[axis] = slice(start_idx, end_idx)
# Slice the cotangent
sliced = array_slice(cotangent, slices)
result.append(sliced)
start_idx = end_idx
return result
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
"""Jacobian-vector product rule for concatenate operation.
The JVP of concatenate is concatenating the tangents along the same axis.
"""
# Use the ConcatenateOp directly to avoid circular import
op = ConcatenateOp(axis=self.axis)
return op.forward(*tangents)
def forward(self, *args: Array) -> Array:
"""Forward pass for concatenate operation with multiple inputs."""
if len(args) == 0:
raise ValueError("Concatenate operation requires at least 1 argument")
# Move arrays to best device
from .operation import move_to_best_device
args = move_to_best_device(*args)
# Validate inputs have compatible properties
first_arg = args[0]
for _i, arg in enumerate(args[1:], 1):
if arg.dtype != first_arg.dtype:
raise ValueError(
f"All inputs must have the same dtype. Got {arg.dtype} vs {first_arg.dtype}"
)
if arg.device != first_arg.device:
raise ValueError(
f"All inputs must be on the same device. Got {arg.device} vs {first_arg.device}"
)
# Compute output properties
input_shapes = [arg.shape for arg in args]
output_shape = self.compute_output_shape(*input_shapes)
# All inputs should have the same batch_dims
output_batch_dims = first_arg.batch_dims
for i, arg in enumerate(args[1:], 1):
if arg.batch_dims != output_batch_dims:
raise ValueError(
f"All inputs must have the same batch_dims for concatenate operation. "
f"Input 0 has batch_dims {output_batch_dims}, input {i} has batch_dims {arg.batch_dims}"
)
# Create result array
res = Array(
shape=output_shape,
dtype=first_arg.dtype,
device=first_arg.device,
materialize=False,
name=self.name,
batch_dims=output_batch_dims,
)
# Set up computation
res.set_maxpr(self.maxpr)
res.add_arguments(*args)
res.vjp_rule = self.vjp_rule
res.jvp_rule = self.jvp_rule
# Execute eager computation if needed
if not res.stage_realization:
self.eagerxpr(list(args), res)
return res
def concatenate(args: list[Array], axis: int = 0) -> Array:
"""Concatenate arrays along an existing axis.
Args:
args: List of arrays to concatenate
axis: Axis along which to concatenate arrays (default: 0)
Returns:
Concatenated array
"""
if not args:
raise ValueError("Concatenate operation requires at least one array")
op = ConcatenateOp(axis)
return op.forward(*args)
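# Minimal usage sketch (hedged), assuming `a` and `b` are Arrays with the same
# dtype/device and shapes (2, 3) and (4, 3):
#
#     c = concatenate([a, b], axis=0)   # c.shape == (6, 3)
#
# The VJP slices the cotangent back into per-input pieces of shapes (2, 3) and
# (4, 3) along the same axis.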
class ArraySliceOp(ViewOperation):
"""Array slicing operation."""
def __init__(self, slices: list[slice], squeeze_axes: list[int] | None = None):
# Store original slices for reference
self.original_slices = slices.copy()
# Check if we have negative steps - if so, we'll need special handling
self.has_negative_steps = any(s.step is not None and s.step < 0 for s in slices)
# Convert slices to a more manageable format
slice_strs = []
for s in slices:
start = s.start if s.start is not None else ""
stop = s.stop if s.stop is not None else ""
step = s.step if s.step is not None else ""
if step and step != 1:
slice_strs.append(f"{start}:{stop}:{step}")
else:
slice_strs.append(f"{start}:{stop}")
squeeze_info = f"_squeeze{squeeze_axes}" if squeeze_axes else ""
super().__init__(f"array_slice[{','.join(slice_strs)}]{squeeze_info}")
self.slices = slices
self.squeeze_axes = squeeze_axes or []
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compute output shape for array slice operation."""
if len(input_shapes) != 1:
raise ValueError(
f"Array slice operation requires 1 input shape, got {len(input_shapes)}"
)
input_shape = input_shapes[0]
output_shape = []
if len(self.slices) > len(input_shape):
raise IndexError(
f"too many indices for array: array is {len(input_shape)}-dimensional, but {len(self.slices)} were indexed"
)
# Process each dimension
for i, dim_size in enumerate(input_shape):
if i < len(self.slices):
s = self.slices[i]
if s.step == 0:
raise ValueError("Step cannot be zero")
# slice.indices normalizes None defaults, negative indices and
# negative steps exactly the way NumPy basic slicing does, so the
# declared output shape matches the eager NumPy result (including [::-1]).
start, stop, step = s.indices(dim_size)
output_size = len(range(start, stop, step))
# Skip this dimension if it should be squeezed (JAX-compatible behavior)
if i not in self.squeeze_axes:
output_shape.append(output_size)
else:
# No slice for this dimension, keep original size
output_shape.append(dim_size)
return tuple(output_shape)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
"""MAX graph implementation using ops.slice_tensor."""
input_tensor = args[0]
# Check for negative steps - not supported in JIT mode yet
if self.has_negative_steps:
raise NotImplementedError(
"Negative step slicing (e.g., [::-1]) is not yet supported in JIT-compiled functions "
"due to MAX engine limitations. Use eager execution instead, or avoid negative steps "
"in JIT-compiled code."
)
# Build slice indices for MAX ops.slice_tensor
slice_indices = []
# Add full slices for batch dimensions
for _ in range(len(output.batch_dims)):
slice_indices.append(slice(None))
# Add the user-provided slices
slice_indices.extend(self.slices)
# Pad with full slices up to the total rank of the input tensor.
num_physical_dims = len(input_tensor.shape)
while len(slice_indices) < num_physical_dims:
slice_indices.append(slice(None))
# Truncate if too long (can happen in weird vmap cases)
slice_indices = slice_indices[:num_physical_dims]
# Apply the slicing
result = ops.slice_tensor(input_tensor, slice_indices)
# Apply squeezing for JAX-compatible behavior
if self.squeeze_axes:
# Adjust squeeze axes to account for batch dimensions
squeeze_axes_adjusted = [
ax + len(output.batch_dims) for ax in self.squeeze_axes
]
for ax in sorted(
squeeze_axes_adjusted, reverse=True
): # Squeeze in reverse order
result = ops.squeeze(result, ax)
output.tensor_value = result
def eagerxpr(self, args: list[Array], output: Array) -> None:
"""Eager execution using NumPy slicing."""
input_array = args[0].to_numpy()
# Build numpy slice tuple
# Need to account for batch_dims - slicing only applies to shape dimensions
numpy_slices = []
# Add full slices for batch dimensions
for _ in range(len(args[0].batch_dims)):
numpy_slices.append(slice(None))
# Add the actual slices for shape dimensions
for i in range(len(args[0].shape)):
if i < len(self.slices):
numpy_slices.append(self.slices[i])
else:
numpy_slices.append(slice(None)) # Full slice for remaining dimensions
result = input_array[tuple(numpy_slices)]
# Apply squeezing for JAX-compatible behavior
if self.squeeze_axes:
# Adjust squeeze axes to account for batch dimensions
squeeze_axes_adjusted = [
ax + len(args[0].batch_dims) for ax in self.squeeze_axes
]
result = np.squeeze(result, axis=tuple(squeeze_axes_adjusted))
output.impl_(result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
"""Vector-Jacobian product rule for array slice."""
# If we squeezed dimensions, we need to unsqueeze the cotangent first
if self.squeeze_axes:
from ..ops.view import unsqueeze
# Unsqueeze in the positions that were squeezed
unsqueeze_axes = self.squeeze_axes.copy()
cotangent_unsqueezed = unsqueeze(cotangent, unsqueeze_axes)
else:
cotangent_unsqueezed = cotangent
return [pad(cotangent_unsqueezed, self.slices, primals[0].shape)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
"""Jacobian-vector product rule for array slice."""
# Apply the same slicing and squeezing to tangents
op = ArraySliceOp(self.slices, self.squeeze_axes)
return op.forward(tangents[0])
def array_slice(
arg: Array, slices: list[slice], squeeze_axes: list[int] | None = None
) -> Array:
"""Slice an array along specified dimensions.
Args:
arg: Input array to slice
slices: List of slice objects defining the slicing for each dimension
squeeze_axes: List of axes that should be squeezed (for JAX compatibility)
Returns:
Sliced array
"""
op = ArraySliceOp(slices, squeeze_axes)
return op.forward(arg)
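# Minimal usage sketch (hedged), assuming `x` is an Array with x.shape == (4, 5):
#
#     array_slice(x, [slice(1, 3)])                  # shape (2, 5)
#     array_slice(x, [slice(None), slice(0, 5, 2)])  # shape (4, 3)
#
# Note: negative-step slices (e.g. slice(None, None, -1)) are only supported in
# eager execution; maxpr above raises NotImplementedError under JIT.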
def split(arg: Array, sizes: list[int], axis: int = 0) -> list[Array]:
"""Split an array into multiple sub-arrays along a specified axis.
Args:
arg: Input array to split
sizes: List of sizes for each split along the specified axis
axis: Axis along which to split the array (default: 0)
Returns:
List of sub-arrays resulting from the split
"""
if not sizes:
raise ValueError("Sizes list must not be empty")
if axis < 0:
axis += len(arg.shape)
if axis < 0 or axis >= len(arg.shape):
raise ValueError(
f"Axis {axis} is out of bounds for array with {len(arg.shape)} dimensions"
)
# Compute the total size along the specified axis
total_size = sum(sizes)
if total_size != arg.shape[axis]:
raise ValueError(
f"Total size {total_size} along axis {axis} does not match input shape {arg.shape[axis]}"
)
# Create slices for each split
slices = []
idx = 0
for size in sizes:
slices.append(slice(idx, idx + size))
idx += size
# Create the result arrays
results = []
for s in slices:
slice_obj = [slice(None)] * len(arg.shape) # Full slice for all dimensions
slice_obj[axis] = s # Set the slice for the specified axis
results.append(array_slice(arg, slice_obj))
return results
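# Minimal usage sketch (hedged), assuming `x` is an Array with x.shape == (6, 3):
#
#     parts = split(x, [2, 4], axis=0)   # shapes (2, 3) and (4, 3)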
class PadOp(Operation):
"""Inverse slice operation - places a smaller array into a larger zero-filled array."""
def __init__(self, slices: list[slice], target_shape: Shape):
# Convert slices to string representation for name
slice_strs = []
for s in slices:
start = s.start if s.start is not None else ""
stop = s.stop if s.stop is not None else ""
step = s.step if s.step is not None else ""
if step and step != 1:
slice_strs.append(f"{start}:{stop}:{step}")
else:
slice_strs.append(f"{start}:{stop}")
super().__init__(f"pad[{','.join(slice_strs)}]")
self.slices = slices
self.target_shape = target_shape
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compute output shape for inverse slice operation."""
if len(input_shapes) != 1:
raise ValueError(
f"Inverse slice operation requires 1 input shape, got {len(input_shapes)}"
)
# Validate that applying slices to target_shape would yield input_shape
input_shape = input_shapes[0]
# Simulate slicing target_shape with self.slices to verify consistency
expected_shape = []
for i, dim_size in enumerate(self.target_shape):
if i < len(self.slices):
s = self.slices[i]
if s.step == 0:
raise ValueError("Step cannot be zero")
# Normalize start/stop/step (None defaults, negative indices and
# negative steps) with slice.indices so the consistency check below
# agrees with NumPy slicing semantics, including stepped slices.
start, stop, step = s.indices(dim_size)
output_size = len(range(start, stop, step))
expected_shape.append(output_size)
else:
# No slice for this dimension, keep original size
expected_shape.append(dim_size)
expected_shape = tuple(expected_shape)
if expected_shape != input_shape:
raise ValueError(
f"Slicing target_shape {self.target_shape} with {self.slices} "
f"would produce shape {expected_shape}, but input has shape {input_shape}"
)
return self.target_shape
def maxpr(self, args: list[TensorValue], output: Array) -> None:
"""MAX graph implementation using range->reshape->broadcast->slice->scatter approach."""
import numpy as np
input_tensor = args[0]
# Step 1: Calculate total elements in output shape
total_elements = int(np.prod(output.shape))
# Step 2: Create flat index tensor using ops.range with int32 dtype
flat_indices = ops.range(0, total_elements, 1, dtype=DType.int32)
# Step 3: Reshape to output shape
reshaped_indices = ops.reshape(flat_indices, output.shape)
# Step 4: Broadcast to include batch dims if needed
if output.batch_dims:
# Need to broadcast to batch_dims + output.shape
target_shape = list(output.batch_dims) + list(output.shape)
broadcasted_indices = ops.broadcast_to(reshaped_indices, target_shape)
else:
broadcasted_indices = reshaped_indices
# Step 5: Slice the index tensor using self.slices to get target indices
slice_indices = []
# Add full slices for batch dimensions
for _ in range(len(output.batch_dims)):
slice_indices.append(slice(None))
# Add the actual slices for shape dimensions
for s in self.slices:
slice_indices.append(slice(s.start, s.stop, s.step))
# Add full slices for any remaining dimensions
for _ in range(len(self.slices), len(output.shape)):
slice_indices.append(slice(None))
# Slice to get the indices where input should go
sliced_indices = ops.slice_tensor(broadcasted_indices, slice_indices)
# Step 6: Flatten the sliced indices
flattened_indices = ops.reshape(sliced_indices, [-1])
# Step 7: Create flat zero tensor for scattering
total_output_elements = int(
np.prod(list(output.batch_dims) + list(output.shape))
)
from max.graph import DeviceRef
zero_scalar = ops.constant(
0.0, dtype=output.dtype, device=DeviceRef.from_device(output.device)
)
flat_zeros = ops.broadcast_to(zero_scalar, [total_output_elements])
# Step 8: Flatten input tensor
input_flattened = ops.reshape(input_tensor, [-1])
# Step 9: Use scatter to place input values at target indices
# scatter(input, updates, indices, axis) - scatter along axis=0 (first axis) of flat tensor
scattered_flat = ops.scatter(
flat_zeros, input_flattened, flattened_indices, axis=0
)
# Step 10: Reshape result back to target shape
final_shape = list(output.batch_dims) + list(output.shape)
output.tensor_value = ops.reshape(scattered_flat, final_shape)
def eagerxpr(self, args: list[Array], output: Array) -> None:
"""Eager execution using NumPy."""
small_array = args[0]
# Create zero-filled target array
target_shape = output.batch_dims + output.shape
result_np = np.zeros(target_shape, dtype=small_array.to_numpy().dtype)
# Build slice indices (accounting for batch_dims)
slice_indices = []
# Add full slices for batch dimensions
for _ in range(len(output.batch_dims)):
slice_indices.append(slice(None))
# Add the actual slices for shape dimensions
slice_indices.extend(self.slices)
# Add full slices for any remaining dimensions
for _i in range(len(self.slices), len(output.shape)):
slice_indices.append(slice(None))
# Place small array into the target location
result_np[tuple(slice_indices)] = small_array.to_numpy()
output.impl_(result_np)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
"""VJP rule: slice the cotangent back to original size."""
# The VJP of pad is just a regular slice!
from nabla.ops.view import array_slice
return [array_slice(cotangent, self.slices)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
"""JVP rule: apply pad to tangents."""
return pad(tangents[0], self.slices, self.target_shape)
def forward(self, *args: Array) -> Array:
"""Forward pass for inverse slice operation."""
if len(args) != 1:
raise ValueError(
f"Inverse slice operation requires 1 argument, got {len(args)}"
)
input_array = args[0]
# Compute output properties
output_shape = self.compute_output_shape(input_array.shape)
# Create result array
res = Array(
shape=output_shape,
dtype=input_array.dtype,
device=input_array.device,
materialize=False,
name=self.name,
batch_dims=input_array.batch_dims,
)
# Set up computation
res.set_maxpr(self.maxpr)
res.add_arguments(input_array)
res.vjp_rule = self.vjp_rule
res.jvp_rule = self.jvp_rule
# Execute eager computation if needed
if not res.stage_realization:
self.eagerxpr([input_array], res)
return res
def pad(arg: Array, slices: list[slice], target_shape: Shape) -> Array:
"""Place a smaller array into a larger zero-filled array at the location specified by slices.
This is the inverse operation of array slicing - given slices, a small array, and target shape,
it creates a larger array where the small array is placed at the sliced location
and everything else is zero.
Args:
arg: Input array (the smaller array to be placed)
slices: List of slice objects defining where to place the array
target_shape: The shape of the output array
Returns:
Larger array with input placed at sliced location, zeros elsewhere
"""
op = PadOp(slices, target_shape)
return op.forward(arg)
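# Minimal usage sketch (hedged), assuming `y` is an Array with y.shape == (2, 5):
#
#     z = pad(y, [slice(1, 3)], (4, 5))   # z.shape == (4, 5)
#
# Rows 1..2 of `z` hold `y`; everything else is zero. This is the VJP
# counterpart of array_slice with the same slices.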
class SqueezeBatchDimsOp(ViewOperation):
"""Squeeze operation to remove batch dimensions of size 1."""
def __init__(self, axes: list[int] | None = None):
super().__init__(f"squeeze_batch_dims[axes={axes}]")
self.axes = sorted(axes) if axes is not None else []
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Shape stays the same for batch dimension operations."""
if len(input_shapes) != 1:
raise ValueError(
f"Squeeze batch dims operation requires 1 input shape, got {len(input_shapes)}"
)
return input_shapes[0]
def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
"""Compute output batch_dims for squeeze operation."""
if len(input_batch_dimss) != 1:
raise ValueError(
f"Squeeze batch dims operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
)
input_batch_dims = input_batch_dimss[0]
new_batch_dims = list(input_batch_dims)
for ax in self.axes:
if ax < -len(new_batch_dims) or ax >= len(new_batch_dims):
raise ValueError(
f"Axis {ax} is out of bounds for squeeze batch dims operation"
)
if input_batch_dims[ax] == 1:
new_batch_dims[ax] = None
else:
raise ValueError(
f"Cannot squeeze batch axis {ax} of size {input_batch_dims[ax]} (must be 1)"
)
new_batch_dims = [dim for dim in new_batch_dims if dim is not None]
return tuple(new_batch_dims)
def forward(self, *args: Array) -> Array:
"""Override forward to handle case where no squeezing needed."""
if len(args) != 1:
raise ValueError(
f"Squeeze batch dims operation requires 1 argument, got {len(args)}"
)
return super().forward(*args)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
"""MAX graph implementation using ops.squeeze."""
axes = [ax - len(output.shape) for ax in self.axes]
res = args[0]
for ax in axes:
res = ops.squeeze(res, ax)
output.tensor_value = res
def eagerxpr(self, args: list[Array], output: Array) -> None:
"""Eager execution using NumPy squeeze."""
axes = [ax - len(args[0].shape) for ax in self.axes]
np_result = np.squeeze(args[0].to_numpy(), axis=tuple(axes))
output.impl_(np_result)
def vjp_rule(
self, _primals: list[Array], cotangent: Array, _output: Array
) -> list[Array]:
"""VJP rule: unsqueeze the cotangent back to original batch dimensions."""
return [unsqueeze_batch_dims(cotangent, self.axes)]
def jvp_rule(
self, _primals: list[Array], tangents: list[Array], _output: Array
) -> Array:
"""JVP rule: apply squeeze to tangents."""
return squeeze_batch_dims(tangents[0], self.axes)
def squeeze_batch_dims(arg: Array, axes: list[int] | None = None) -> Array:
"""Squeeze array by removing batch dimensions of size 1.
Args:
arg: Input array
axes: List of batch dimension axes to squeeze. If None, returns array unchanged.
Returns:
Array with specified batch dimensions of size 1 removed
"""
if axes is None:
return arg
# Convert to negative indices for consistency with batch dimension handling
axes = [ax if ax < 0 else -len(arg.batch_dims) + ax for ax in axes]
op = SqueezeBatchDimsOp(axes)
return op.forward(arg)
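# Hedged sketch: analogous to squeeze, but acting on batch_dims. Assuming `x`
# is an Array with x.batch_dims == (1, 4) and x.shape == (3,):
#
#     squeeze_batch_dims(x, [0])   # batch_dims (4,), shape (3,) unchanged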
class UnsqueezeBatchDimsOp(ViewOperation):
"""Unsqueeze operation to add batch dimensions of size 1."""
def __init__(self, axes: list[int] | None = None):
super().__init__(f"unsqueeze_batch_dims[axes={axes}]")
self.axes = sorted(axes) if axes is not None else []
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Shape stays the same for batch dimension operations."""
if len(input_shapes) != 1:
raise ValueError(
f"Unsqueeze batch dims operation requires 1 input shape, got {len(input_shapes)}"
)
return input_shapes[0]
def compute_output_batch_dims(self, *input_batch_dimss: tuple) -> tuple:
"""Compute output batch_dims for unsqueeze operation."""
if len(input_batch_dimss) != 1:
raise ValueError(
f"Unsqueeze batch dims operation requires 1 input batch_dims, got {len(input_batch_dimss)}"
)
input_batch_dims = input_batch_dimss[0]
new_batch_dims = list(input_batch_dims)
for ax in self.axes:
if ax < -len(new_batch_dims) - 1:
raise ValueError(
f"Axis {ax} is out of bounds for unsqueeze batch dims operation"
)
if ax + 1 <= -1:
new_batch_dims.insert(ax + 1, 1)
else:
new_batch_dims.append(1)
return tuple(new_batch_dims)
def forward(self, *args: Array) -> Array:
"""Override forward to handle case where no unsqueezing needed."""
if len(args) != 1:
raise ValueError(
f"Unsqueeze batch dims operation requires 1 argument, got {len(args)}"
)
return super().forward(*args)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
"""MAX graph implementation using ops.unsqueeze."""
res = args[0]
# Use self.axes directly since it's already normalized to a list in __init__
# Adjust axes for batch dimensions
axes = [ax - len(output.shape) for ax in self.axes] if self.axes else []
for ax in axes:
res = ops.unsqueeze(res, ax)
output.tensor_value = res
def eagerxpr(self, args: list[Array], output: Array) -> None:
"""Eager execution using NumPy expand_dims."""
if self.axes:
# Apply expand_dims for each axis sequentially
np_result = args[0].to_numpy()
axes = [ax - len(args[0].shape) for ax in self.axes]
for ax in axes:
np_result = np.expand_dims(np_result, axis=ax)
else:
np_result = args[0].to_numpy()
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
"""VJP rule: squeeze the cotangent back to original batch dimensions."""
return [squeeze_batch_dims(cotangent, self.axes)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
"""JVP rule: apply unsqueeze to tangents."""
return unsqueeze_batch_dims(tangents[0], self.axes)
def unsqueeze_batch_dims(arg: Array, axes: list[int] | None = None) -> Array:
"""Unsqueeze array by adding batch dimensions of size 1.
Args:
arg: Input array
axes: List of positions where to insert batch dimensions of size 1.
If None, returns array unchanged.
Returns:
Array with batch dimensions of size 1 added at specified positions
"""
if axes is None:
return arg
# Convert to negative indices for consistency with batch dimension handling
axes = [ax if ax < 0 else -len(arg.batch_dims) - 1 + ax for ax in axes]
op = UnsqueezeBatchDimsOp(axes)
return op.forward(arg)
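# Hedged sketch: the inverse of squeeze_batch_dims. Assuming `x` is an Array
# with x.batch_dims == (4,) and x.shape == (3,):
#
#     unsqueeze_batch_dims(x, [0])   # batch_dims (1, 4), shape (3,) unchanged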
# stack first unsqueezes each array to add a new axis, then concatenates the results along that axis.
def stack(arrays: list[Array], axis: int = 0) -> Array:
"""Stack arrays along a new axis.
Args:
arrays: List of arrays to stack
axis: Axis along which to stack the arrays (default: 0)
Returns:
Stacked array
"""
if not arrays:
raise ValueError("Stack operation requires at least one array")
# Unsqueeze each array to add a new dimension at the specified axis
unsqueezed_arrays = [unsqueeze(array, [axis]) for array in arrays]
# Use concatenate to stack them along the new axis
return concatenate(unsqueezed_arrays, axis=axis)
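# Minimal usage sketch (hedged), assuming `a` and `b` are Arrays, each with
# shape (3,):
#
#     s = stack([a, b], axis=0)   # s.shape == (2, 3)
#
# Each input is first unsqueezed to (1, 3) and the results are concatenated
# along the new axis.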