# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Linear algebra operations."""
import numpy as np
from max.graph import TensorValue, ops
from ..core.array import Array
from ..utils.shape_utils import get_broadcasted_shape
from .operation import BinaryOperation
# Public API
__all__ = ["matmul"]
class MatMulOp(BinaryOperation):
"""Matrix multiplication operation with batching support."""
def __init__(self):
super().__init__("dot_general")
def forward(self, *args: Array) -> Array:
"""Forward pass for binary operations."""
if len(args) != 2:
raise ValueError(f"Binary operation requires 2 arguments, got {len(args)}")
# Move arrays to best device
from .operation import move_to_best_device
args = move_to_best_device(*args)
arg1, arg2 = args[0], args[1]
from ..ops.view import broadcast_batch_dims, broadcast_to, reshape
arg1_has_rank_1 = len(arg1.shape) == 1
arg2_has_rank_1 = len(arg2.shape) == 1
if arg1_has_rank_1:
arg1 = reshape(arg1, (1, arg1.shape[0]))
if arg2_has_rank_1:
arg2 = reshape(arg2, (arg2.shape[0], 1))
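        # NumPy-style promotion: (k,) @ (k, n) becomes (1, k) @ (k, n); the
        # promoted axes are stripped from the result after the matmul below.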
self._validate_inputs(arg1, arg2)
output_shape = self.compute_output_shape(arg1.shape, arg2.shape)
output_batch_dims = self.compute_output_batch_dims(
arg1.batch_dims, arg2.batch_dims
)
output_dtype = self.compute_output_dtype(arg1, arg2)
if arg1.traced:
arg1 = broadcast_to(arg1, output_shape[:-2] + arg1.shape[-2:])
arg1 = broadcast_batch_dims(arg1, output_batch_dims)
if arg2.traced:
arg2 = broadcast_to(arg2, output_shape[:-2] + arg2.shape[-2:])
arg2 = broadcast_batch_dims(arg2, output_batch_dims)
res = Array(
shape=output_shape,
dtype=output_dtype,
device=arg1.device,
materialize=False,
name=self.name,
batch_dims=output_batch_dims,
)
res.set_maxpr(self.maxpr)
res.add_arguments(arg1, arg2)
res.vjp_rule = self.vjp_rule
res.jvp_rule = self.jvp_rule
res.custom_kernel_path = self.custom_kernel_path()
if not res.stage_realization:
self.eagerxpr([arg1, arg2], res)
        # If the inputs were rank 1, undo the promotion on the output
        if arg1_has_rank_1 and arg2_has_rank_1:
            # (..., 1, 1) -> (...): drop both promoted axes
            res = reshape(res, output_shape[:-2])
        elif arg1_has_rank_1:
            # (..., 1, n) -> (..., n)
            res = reshape(res, output_shape[:-2] + (res.shape[-1],))
        elif arg2_has_rank_1:
            # (..., m, 1) -> (..., m)
            res = reshape(res, output_shape[:-2] + (res.shape[-2],))
return res
def compute_output_shape(self, *input_shapes: tuple) -> tuple:
"""Compute output shape for matrix multiplication with compatible signature."""
if len(input_shapes) != 2:
raise ValueError(
f"Matrix multiplication requires 2 input shapes, got {len(input_shapes)}"
)
shape1, shape2 = input_shapes[0], input_shapes[1]
if shape1[-1] != shape2[-2]:
raise ValueError(
f"Shapes {shape1} and {shape2} are not compatible for matrix multiplication"
)
return get_broadcasted_shape(
shape1,
shape2,
ignore_axes=[-2, -1],
replace_ignored_dims=[shape1[-2], shape2[-1]],
)
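    # Shape example: shape1=(7, 2, 3, 4) and shape2=(4, 5) -> batch dims
    # broadcast to (7, 2), the matrix axes give (3, 5), so the output is (7, 2, 3, 5).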
def _validate_inputs(self, arg1: Array, arg2: Array) -> None:
"""Validate matrix multiplication inputs."""
if not isinstance(arg1, Array) or not isinstance(arg2, Array):
raise TypeError("Both arguments must be Array instances")
if arg1.dtype != arg2.dtype:
raise ValueError(f"Dtypes {arg1.dtype} and {arg2.dtype} are incompatible")
if arg1.device != arg2.device:
raise ValueError(
f"Devices {arg1.device} and {arg2.device} are incompatible"
)
if arg1.shape[-1] != arg2.shape[-2]:
raise ValueError(
f"Shapes {arg1.shape} and {arg2.shape} are not compatible for matrix multiplication"
)
def maxpr(self, args: list[TensorValue], output: Array) -> None:
x_val, y_val = args[0], args[1]
x_shape = x_val.shape
y_shape = y_val.shape
output_shape = output.batch_dims + output.shape
if len(output_shape) <= 4:
output.tensor_value = ops.matmul(args[0], args[1])
else:
if x_shape[:-2] != y_shape[:-2]:
raise ValueError(
f"Shapes {x_shape} and {y_shape} are not compatible for matrix multiplication "
f"(batch dimensions mismatch: {x_shape[:-2]} vs {y_shape[:-2]})"
)
            # Collapse all batch dimensions into one, run a single rank-3
            # batched matmul, then restore the original batch dimensions.
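            # Example: (2, 3, 5, 4, 6) @ (2, 3, 5, 6, 7) collapses to
            # (30, 4, 6) @ (30, 6, 7) -> (30, 4, 7), reshaped to (2, 3, 5, 4, 7).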
batch_dims_x = [int(dim) for dim in x_shape[:-2]]
batch_dims_y = [int(dim) for dim in y_shape[:-2]]
new_shape_x = (
np.prod(batch_dims_x).item(),
int(x_shape[-2]),
int(x_shape[-1]),
)
new_shape_y = (
np.prod(batch_dims_y).item(),
int(y_shape[-2]),
int(y_shape[-1]),
)
x_val_b = ops.reshape(x_val, new_shape_x)
y_val_b = ops.reshape(y_val, new_shape_y)
matmul_result = ops.matmul(x_val_b, y_val_b)
reshaped_result = ops.reshape(
matmul_result,
tuple(args[0].shape[:-2])
+ (matmul_result.shape[-2], matmul_result.shape[-1]),
)
output.tensor_value = reshaped_result
def eagerxpr(self, args: list[Array], output: Array) -> None:
arg0_numpy = args[0].to_numpy()
arg1_numpy = args[1].to_numpy()
np_result = np.matmul(arg0_numpy, arg1_numpy)
output.impl_(np_result)
def vjp_rule(
self, primals: list[Array], cotangent: Array, output: Array
) -> list[Array]:
x, y = primals
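        # For C = X @ Y, the cotangent pullback is dX = dC @ Y^T and dY = X^T @ dC.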
from .view import transpose
return [matmul(cotangent, transpose(y)), matmul(transpose(x), cotangent)]
def jvp_rule(
self, primals: list[Array], tangents: list[Array], output: Array
) -> Array:
x, y = primals
tx, ty = tangents
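        # Product rule for the bilinear matmul: d(X @ Y) = dX @ Y + X @ dY.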
from .binary import add
return add(matmul(x, ty), matmul(tx, y))
# Global operation instance for efficiency
_matmul_op = MatMulOp()
def matmul(arg0, arg1) -> Array:
"""Matrix multiplication with broadcasting support."""
from .binary import _ensure_array
arg0 = _ensure_array(arg0)
arg1 = _ensure_array(arg1)
return _matmul_op.forward(arg0, arg1)
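# Usage sketch (hedged: assumes _ensure_array accepts NumPy inputs, as used above):
#     import numpy as np
#     a = np.ones((2, 3, 4), dtype=np.float32)
#     b = np.ones((4, 5), dtype=np.float32)
#     c = matmul(a, b)  # b broadcasts over the batch dim -> shape (2, 3, 5)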
# # --- Convolution operations using im2col and col2im ---
# # NOTE: the eager paths below rely on torch for the reference convolutions.
# import torch
# import torch.nn.functional as F
# # Global operation instances
# _conv2d_op_cache = {}
# _conv2d_transpose_op_cache = {}
# # --- Helper functions for normalization ---
# def _normalize_tuple(value, n, name):
# if isinstance(value, int):
# return (value,) * n
# elif isinstance(value, tuple | list):
# if len(value) == n:
# return tuple(value)
# else:
# raise ValueError(
# f"{name} must be an int or a tuple of {n} ints, got {value}"
# )
# else:
# raise TypeError(
# f"{name} must be an int or a tuple, got {type(value)} for {name}"
# )
# def _normalize_padding_arg(padding_arg, name="padding"):
# if isinstance(padding_arg, int): # single int for all sides
# return ((padding_arg, padding_arg), (padding_arg, padding_arg))
# if isinstance(padding_arg, tuple | list):
# if len(padding_arg) == 2:
# if all(isinstance(x, int) for x in padding_arg): # (symmetric_H, symmetric_W)
# ph, pw = padding_arg
# return ((ph, ph), (pw, pw))
# elif all(isinstance(x, tuple | list) and len(x) == 2 and all(isinstance(y, int) for y in x) for x in padding_arg):
# # ((H_top, H_bottom), (W_left, W_right))
# return tuple(map(tuple, padding_arg))
# elif len(padding_arg) == 4 and all(isinstance(x, int) for x in padding_arg):
# # (H_top, H_bottom, W_left, W_right)
# pt, pb, pl, pr = padding_arg
# return ((pt, pb), (pl, pr))
# raise ValueError(
# f"{name} format is not recognized. Use int, (ph,pw), (pt,pb,pl,pr), or ((pt,pb),(pl,pr)). Got {padding_arg}"
# )
# def flip(x: Array, axis: int | tuple[int, ...]) -> Array:
# """
# Reverses the order of elements in an array along the given axes.
# This is an implementation of np.flip using fundamental slicing.
# """
# if isinstance(axis, int):
# axes_to_flip = (axis,)
# else:
# axes_to_flip = axis
# # Create a list of slice(None) objects, one for each dimension
# slicer = [slice(None)] * len(x.shape)
# # For each axis to be flipped, set the corresponding slice to ::-1
# for ax in axes_to_flip:
# slicer[ax] = slice(None, None, -1)
# # Use tuple slicing on the array. The Nabla Array class's __getitem__
# # must support this to be Python-idiomatic.
# return x[tuple(slicer)]
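# # Example: flip(x, axis=0) on [[1, 2], [3, 4]] yields [[3, 4], [1, 2]],
# # matching np.flip(x, axis=0) via the equivalent basic slice x[::-1, :].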
# def _conv2d_filter_gradient(
# x: Array, dy: Array, stride: tuple, dilation: tuple, padding: tuple, groups: int
# ) -> Array:
# """
# Computes `grad_W = conv(permute(x), permute(dy))` for a standard conv2d.
# Returns a filter gradient in HWIO layout.
# """
# from ..ops import view
# # Permute input x (NHWC) to be the data for the new conv: (Cin, H, W, N)
# x_perm = view.transpose(x, (3, 1, 2, 0))
# # Permute grad_output dy (NH'W'Cout) to be the filter for the new conv: (H', W', N, Cout)
# dy_perm = view.transpose(dy, (1, 2, 0, 3))
# # The new convolution's stride is the original's dilation, and vice versa.
# # This is a standard identity for this gradient formulation.
# grad_filter_permuted = conv2d(
# x_perm, dy_perm, stride=dilation, dilation=stride, padding=padding, groups=groups
# )
# # The output is (Cin, kH, kW, Cout). Permute back to the standard HWIO filter layout.
# return view.transpose(grad_filter_permuted, (1, 2, 0, 3))
# class Conv2DOp(BinaryOperation):
# """2D Convolution operation.
# Data Layout: NHWC (batch, height, width, in_channels)
# Filter Layout: HWIO (height, width, in_channels/groups, out_channels)
# """
# def __init__(self, stride, dilation, padding, groups):
# super().__init__("conv2d")
# self.stride = stride
# self.dilation = dilation
# self.padding = padding
# self.groups = groups
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# input_shape, filter_shape = input_shapes
# n, h_in, w_in, c_in = input_shape
# k_h, k_w, f_cin_div_g, f_cout = filter_shape
# if c_in != f_cin_div_g * self.groups:
# raise ValueError(
# f"Input channels ({c_in}) must match filter's effective input channels "
# f"({f_cin_div_g} * {self.groups} groups = {f_cin_div_g * self.groups}). "
# f"Input shape: {input_shape}, Filter shape: {filter_shape}"
# )
# (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right) = self.padding
# dil_h, dil_w = self.dilation
# s_h, s_w = self.stride
# h_out = (h_in + pad_h_top + pad_h_bottom - dil_h * (k_h - 1) - 1) // s_h + 1
# w_out = (w_in + pad_w_left + pad_w_right - dil_w * (k_w - 1) - 1) // s_w + 1
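# # Worked example: h_in=32, k_h=3, s_h=2, dil_h=1, padding=((1, 1), ...)
# # -> h_out = (32 + 1 + 1 - 1*(3 - 1) - 1) // 2 + 1 = 31 // 2 + 1 = 16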
# c_out = f_cout
# if h_out <= 0 or w_out <= 0:
# raise ValueError(f"Computed non-positive output dimensions for Conv2D: {(n, h_out, w_out, c_out)}")
# return (n, h_out, w_out, c_out)
# def forward(self, *args: Array) -> Array:
# # Standard forward pass logic
# from .operation import move_to_best_device
# input_arr, filter_arr = move_to_best_device(*args)
# self._validate_inputs(input_arr, filter_arr)
# output_shape = self.compute_output_shape(input_arr.shape, filter_arr.shape)
# res = Array(
# shape=output_shape, dtype=self.compute_output_dtype(input_arr, filter_arr),
# device=input_arr.device, materialize=False, name=self.name,
# batch_dims=input_arr.batch_dims,
# )
# res.set_maxpr(self.maxpr)
# res.add_arguments(input_arr, filter_arr)
# res.vjp_rule = self.vjp_rule
# res.jvp_rule = self.jvp_rule
# if not res.stage_realization:
# self.eagerxpr([input_arr, filter_arr], res)
# return res
# def _validate_inputs(self, input_arr: Array, filter_arr: Array) -> None:
# if len(input_arr.shape) != 4 or len(filter_arr.shape) != 4:
# raise ValueError("Conv2D requires 4D input and filter tensors.")
# if input_arr.device != filter_arr.device:
# raise ValueError(f"Devices {input_arr.device} and {filter_arr.device} are incompatible")
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# input_val, filter_val = args
# (pt, pb), (pl, pr) = self.padding
# output.tensor_value = ops.conv2d(
# x=input_val, filter=filter_val, stride=self.stride,
# dilation=self.dilation, padding=(pt, pb, pl, pr), groups=self.groups
# )
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# input_arr, filter_arr = args
# input_torch = torch.from_numpy(np.transpose(input_arr.to_numpy(), (0, 3, 1, 2)))
# filter_torch = torch.from_numpy(np.transpose(filter_arr.to_numpy(), (3, 2, 0, 1)))
# (pad_h, _), (pad_w, _) = self.padding
# result_torch = F.conv2d(
# input=input_torch, weight=filter_torch, bias=None, stride=self.stride,
# padding=(pad_h, pad_w), dilation=self.dilation, groups=self.groups
# )
# result_nhwc = np.transpose(result_torch.numpy(), (0, 2, 3, 1))
# output.impl_(result_nhwc)
# def vjp_rule(self, primals: list[Array], cotangent: Array, output: Array) -> list[Array]:
# """VJP of Y = conv(X, W)"""
# input_arr, filter_arr = primals # filter_arr is HWIO
# # 1. grad_input = conv_transpose(dY, W_flipped_180)
# flipped_filter = flip(filter_arr, axis=(0, 1))
# # Filter for conv_transpose must be HWOI. Swap channels of our HWIO filter.
# filter_for_grad_input = flipped_filter.transpose((0, 1, 3, 2))
# # Calculate output_padding to restore original input shape
# h_in, w_in = input_arr.shape[1:3]
# h_out, w_out = cotangent.shape[1:3]
# k_h, k_w = filter_arr.shape[0:2]
# (pt, pb), (pl, pr) = self.padding
# sh, sw = self.stride
# dh, dw = self.dilation
# out_pad_h = h_in - ((h_out - 1) * sh - (pt + pb) + (k_h - 1) * dh + 1)
# out_pad_w = w_in - ((w_out - 1) * sw - (pl + pr) + (k_w - 1) * dw + 1)
# grad_input = conv2d_transpose(
# cotangent, filter_for_grad_input, stride=self.stride, dilation=self.dilation,
# padding=self.padding, output_padding=(max(0,out_pad_h), max(0,out_pad_w)), groups=self.groups
# )
# # 2. grad_filter = conv(permute(X), permute(dY))
# grad_filter = _conv2d_filter_gradient(
# input_arr, cotangent, self.stride, self.dilation, self.padding, self.groups
# )
# return [grad_input, grad_filter]
# def jvp_rule(self, primals: list[Array], tangents: list[Array], output: Array) -> Array:
# input_arr, filter_arr = primals
# input_tangent, filter_tangent = tangents
# from .binary import add
# res1 = conv2d(input_tangent, filter_arr, stride=self.stride, dilation=self.dilation, padding=self.padding, groups=self.groups)
# res2 = conv2d(input_arr, filter_tangent, stride=self.stride, dilation=self.dilation, padding=self.padding, groups=self.groups)
# return add(res1, res2)
# class Conv2DTransposeOp(BinaryOperation):
# """2D Transposed Convolution operation.
# Data Layout: NHWC (batch, height, width, in_channels)
# Filter Layout: HWOI (height, width, out_channels, in_channels/groups)
# """
# def __init__(self, stride, dilation, padding, output_padding, groups):
# super().__init__("conv2d_transpose")
# self.stride = stride
# self.dilation = dilation
# self.padding = padding
# self.output_padding = output_padding
# self.groups = groups
# def compute_output_shape(self, *input_shapes: tuple) -> tuple:
# input_shape, filter_shape = input_shapes
# n, h_in, w_in, c_in = input_shape
# k_h, k_w, f_cout, f_cin_div_g = filter_shape
# if c_in != f_cin_div_g * self.groups:
# raise ValueError(
# f"Input channels ({c_in}) must match filter's effective input channels "
# f"({f_cin_div_g} * {self.groups} groups = {f_cin_div_g * self.groups}). "
# f"This is the 'I' in HWOI. "
# f"Input shape: {input_shape}, Filter shape: {filter_shape}"
# )
# (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right) = self.padding
# out_pad_h, out_pad_w = self.output_padding
# dil_h, dil_w = self.dilation
# s_h, s_w = self.stride
# h_out = (h_in - 1) * s_h - (pad_h_top + pad_h_bottom) + dil_h * (k_h - 1) + 1 + out_pad_h
# w_out = (w_in - 1) * s_w - (pad_w_left + pad_w_right) + dil_w * (k_w - 1) + 1 + out_pad_w
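# # Worked example (inverse of the conv case): h_in=16, s_h=2, k_h=3, dil_h=1,
# # padding=(1, 1), out_pad_h=1 -> h_out = 15*2 - 2 + 2 + 1 + 1 = 32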
# c_out = f_cout
# if h_out <= 0 or w_out <= 0:
# raise ValueError(f"Computed non-positive output dimensions for Conv2DTranspose: {(n, h_out, w_out, c_out)}")
# return (n, h_out, w_out, c_out)
# def forward(self, *args: Array) -> Array:
# from .operation import move_to_best_device
# input_arr, filter_arr = move_to_best_device(*args)
# self._validate_inputs(input_arr, filter_arr)
# output_shape = self.compute_output_shape(input_arr.shape, filter_arr.shape)
# res = Array(
# shape=output_shape, dtype=self.compute_output_dtype(input_arr, filter_arr),
# device=input_arr.device, materialize=False, name=self.name,
# batch_dims=input_arr.batch_dims,
# )
# res.set_maxpr(self.maxpr)
# res.add_arguments(input_arr, filter_arr)
# res.vjp_rule = self.vjp_rule
# res.jvp_rule = self.jvp_rule
# if not res.stage_realization:
# self.eagerxpr([input_arr, filter_arr], res)
# return res
# def _validate_inputs(self, input_arr: Array, filter_arr: Array) -> None:
# if len(input_arr.shape) != 4 or len(filter_arr.shape) != 4:
# raise ValueError("Conv2DTranspose requires 4D input and filter tensors.")
# if input_arr.device != filter_arr.device:
# raise ValueError(f"Devices {input_arr.device} and {filter_arr.device} are incompatible")
# def maxpr(self, args: list[TensorValue], output: Array) -> None:
# input_val, filter_val = args
# (pt, pb), (pl, pr) = self.padding
# if self.groups > 1:
# from ..ops.view import split, concatenate
# input_chunks = split(input_val, self.groups, axis=3)
# filter_chunks = split(filter_val, self.groups, axis=3)
# output_chunks = []
# for i in range(self.groups):
# chunk_out = ops.conv2d_transpose(
# input_chunks[i], filter_chunks[i], stride=self.stride,
# dilation=self.dilation, padding=(pt, pb, pl, pr),
# output_paddings=self.output_padding
# )
# output_chunks.append(chunk_out)
# output.tensor_value = concatenate(output_chunks, axis=3)
# else:
# output.tensor_value = ops.conv2d_transpose(
# input_val, filter_val, stride=self.stride, dilation=self.dilation,
# padding=(pt, pb, pl, pr), output_paddings=self.output_padding
# )
# def eagerxpr(self, args: list[Array], output: Array) -> None:
# input_arr, filter_arr = args
# input_torch = torch.from_numpy(np.transpose(input_arr.to_numpy(), (0, 3, 1, 2)))
# filter_torch = torch.from_numpy(np.transpose(filter_arr.to_numpy(), (3, 2, 0, 1)))
# (pad_h, _), (pad_w, _) = self.padding
# result_torch = F.conv_transpose2d(
# input=input_torch, weight=filter_torch, bias=None, stride=self.stride,
# padding=(pad_h, pad_w), output_padding=self.output_padding,
# groups=self.groups, dilation=self.dilation
# )
# result_nhwc = np.transpose(result_torch.numpy(), (0, 2, 3, 1))
# output.impl_(result_nhwc)
# def vjp_rule(self, primals: list[Array], cotangent: Array, output: Array) -> list[Array]:
# """VJP of Y = conv_transpose(X, W)"""
# input_arr, filter_arr = primals # filter_arr is HWOI
# # 1. grad_input = conv(dY, W_flipped_180)
# flipped_filter = flip(filter_arr, axis=(0, 1))
# # Filter for conv2d must be HWIO. Swap channels of our HWOI filter.
# filter_for_grad_input = flipped_filter.transpose((0, 1, 3, 2))
# grad_input = conv2d(
# cotangent, filter_for_grad_input, stride=self.stride,
# dilation=self.dilation, padding=self.padding, groups=self.groups
# )
# # 2. grad_filter = conv(permute(dY), permute(X))
# # Note the swapped arguments compared to the conv2d VJP.
# grad_filter_HWIO = _conv2d_filter_gradient(
# cotangent, input_arr, self.stride, self.dilation, self.padding, self.groups
# )
# # The helper returns HWIO. The gradient must match the primal filter's HWOI layout.
# grad_filter = grad_filter_HWIO.transpose((0, 1, 3, 2))
# return [grad_input, grad_filter]
# def jvp_rule(self, primals: list[Array], tangents: list[Array], output: Array) -> Array:
# input_arr, filter_arr = primals
# input_tangent, filter_tangent = tangents
# from .binary import add
# res1 = conv2d_transpose(
# input_tangent, filter_arr, stride=self.stride, dilation=self.dilation,
# padding=self.padding, output_padding=self.output_padding, groups=self.groups)
# res2 = conv2d_transpose(
# input_arr, filter_tangent, stride=self.stride, dilation=self.dilation,
# padding=self.padding, output_padding=self.output_padding, groups=self.groups)
# return add(res1, res2)
# def conv2d(
# input_arr: Array, filter_arr: Array, stride=(1, 1),
# dilation=(1, 1), padding=0, groups=1
# ) -> Array:
# """Applies a 2D convolution."""
# norm_stride = _normalize_tuple(stride, 2, "stride")
# norm_dilation = _normalize_tuple(dilation, 2, "dilation")
# norm_padding = _normalize_padding_arg(padding, "padding")
# cache_key = (norm_stride, norm_dilation, norm_padding, groups)
# if cache_key not in _conv2d_op_cache:
# _conv2d_op_cache[cache_key] = Conv2DOp(norm_stride, norm_dilation, norm_padding, groups)
# op = _conv2d_op_cache[cache_key]
# return op.forward(input_arr, filter_arr)
# def conv2d_transpose(
# input_arr: Array, filter_arr: Array, stride=(1, 1),
# dilation=(1, 1), padding=0, output_padding=0, groups=1
# ) -> Array:
# """Applies a 2D transposed convolution."""
# norm_stride = _normalize_tuple(stride, 2, "stride")
# norm_dilation = _normalize_tuple(dilation, 2, "dilation")
# norm_padding = _normalize_padding_arg(padding, "padding")
# norm_output_padding = _normalize_tuple(output_padding, 2, "output_padding")
# cache_key = (norm_stride, norm_dilation, norm_padding, norm_output_padding, groups)
# if cache_key not in _conv2d_transpose_op_cache:
# _conv2d_transpose_op_cache[cache_key] = Conv2DTransposeOp(
# norm_stride, norm_dilation, norm_padding, norm_output_padding, groups
# )
# op = _conv2d_transpose_op_cache[cache_key]
# return op.forward(input_arr, filter_arr)
# """Numpy-based convolution utilities for eager execution."""
# from typing import Union
# import numpy as np
# def im2col(
# input_data: np.ndarray,
# filter_h: int,
# filter_w: int,
# stride: Union[int, tuple[int, int]] = 1,
# dilation: Union[int, tuple[int, int]] = 1,
# pad: Union[int, tuple[int, int]] = 0,
# ) -> np.ndarray:
# """
# Convert input data to column matrix for convolution.
# Parameters:
# -----------
# input_data : ndarray
# Input data with shape (N, C, H, W)
# filter_h : int
# Filter height
# filter_w : int
# Filter width
# stride : int or tuple
# Stride for convolution
# dilation : int or tuple
# Dilation for convolution
# pad : int or tuple
# Padding for input
# Returns:
# --------
# col : ndarray
# Column matrix with shape (N, C, filter_h, filter_w, out_h, out_w)
# """
# n, c, h, w = input_data.shape
# # Handle stride and dilation as tuples
# if isinstance(stride, int):
# stride_h, stride_w = stride, stride
# else:
# stride_h, stride_w = stride
# if isinstance(dilation, int):
# dilation_h, dilation_w = dilation, dilation
# else:
# dilation_h, dilation_w = dilation
# if isinstance(pad, int):
# pad_h, pad_w = pad, pad
# else:
# pad_h, pad_w = pad
# out_h = (h + 2 * pad_h - dilation_h * (filter_h - 1) - 1) // stride_h + 1
# out_w = (w + 2 * pad_w - dilation_w * (filter_w - 1) - 1) // stride_w + 1
# img = np.pad(
# input_data, ((0, 0), (0, 0), (pad_h, pad_h), (pad_w, pad_w)), mode="constant"
# )
# col = np.empty((n, c, filter_h, filter_w, out_h, out_w), dtype=input_data.dtype)
# for j in range(filter_h):
# j_lim = j * dilation_h + stride_h * out_h
# for i in range(filter_w):
# i_lim = i * dilation_w + stride_w * out_w
# col[:, :, j, i, :, :] = img[
# :,
# :,
# j * dilation_h : j_lim : stride_h,
# i * dilation_w : i_lim : stride_w,
# ]
# return col
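# # Worked example: input (1, 1, 4, 4), 3x3 filter, stride 1, no padding
# # -> out_h = out_w = (4 - 3) // 1 + 1 = 2, so col has shape (1, 1, 3, 3, 2, 2);
# # flattening the (C, filter_h, filter_w) axes reduces the convolution to one matmul.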
# def col2im(
# col: np.ndarray,
# input_shape: tuple[int, int, int, int],
# filter_h: int,
# filter_w: int,
# stride: Union[int, tuple[int, int]] = 1,
# dilation: Union[int, tuple[int, int]] = 1,
# pad: Union[int, tuple[int, int]] = 0,
# ) -> np.ndarray:
# """
# Convert column matrix back to input data shape.
# Parameters:
# -----------
# col : ndarray
# Column matrix with shape (N, C, filter_h, filter_w, out_h, out_w)
# input_shape : tuple
# Original input shape (N, C, H, W)
# filter_h : int
# Filter height
# filter_w : int
# Filter width
# stride : int or tuple
# Stride for convolution
# dilation : int or tuple
# Dilation for convolution
# pad : int or tuple
# Padding for input
# Returns:
# --------
# img : ndarray
# Reconstructed input data with shape (N, C, H, W)
# """
# n, c, h, w = input_shape
# # Handle stride and dilation as tuples
# if isinstance(stride, int):
# stride_h, stride_w = stride, stride
# else:
# stride_h, stride_w = stride
# if isinstance(dilation, int):
# dilation_h, dilation_w = dilation, dilation
# else:
# dilation_h, dilation_w = dilation
# if isinstance(pad, int):
# pad_h, pad_w = pad, pad
# else:
# pad_h, pad_w = pad
# out_h = (h + 2 * pad_h - dilation_h * (filter_h - 1) - 1) // stride_h + 1
# out_w = (w + 2 * pad_w - dilation_w * (filter_w - 1) - 1) // stride_w + 1
# img = np.zeros(
# (n, c, h + 2 * pad_h + stride_h - 1, w + 2 * pad_w + stride_w - 1),
# dtype=col.dtype,
# )
# for j in range(filter_h):
# j_lim = j * dilation_h + stride_h * out_h
# for i in range(filter_w):
# i_lim = i * dilation_w + stride_w * out_w
# img[
# :,
# :,
# j * dilation_h : j_lim : stride_h,
# i * dilation_w : i_lim : stride_w,
# ] += col[:, :, j, i, :, :]
# return img[:, :, pad_h : h + pad_h, pad_w : w + pad_w]
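# # col2im is the adjoint of im2col: values from overlapping patches are summed
# # back into place (the += above), which is exactly the accumulation that
# # convolution gradients require.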
# def conv2d(input_data, filters, dilation=(1, 1), stride=(1, 1), padding=(0, 0)):
# """
# 2D convolution using im2col method.
# Parameters:
# -----------
# input_data : ndarray
# Input data with shape (N, C_in, H, W)
# filters : ndarray
# Filters with shape (C_out, C_in, filter_h, filter_w)
# dilation : tuple
# Dilation factors (dilation_h, dilation_w)
# stride : tuple
# Stride values (stride_h, stride_w)
# padding : tuple
# Padding values (pad_h, pad_w)
# Returns:
# --------
# output : ndarray
# Convolution output with shape (N, C_out, out_h, out_w)
# """
# n, c_in, h, w = input_data.shape
# c_out, c_in_f, filter_h, filter_w = filters.shape
# assert c_in == c_in_f, f"Input channels {c_in} != filter input channels {c_in_f}"
# # Calculate output dimensions
# pad_h, pad_w = padding
# stride_h, stride_w = stride
# dilation_h, dilation_w = dilation
# out_h = (h + 2 * pad_h - dilation_h * (filter_h - 1) - 1) // stride_h + 1
# out_w = (w + 2 * pad_w - dilation_w * (filter_w - 1) - 1) // stride_w + 1
# # Convert input to column matrix
# col = im2col(input_data, filter_h, filter_w, stride, dilation, padding)
# col = col.transpose(0, 4, 5, 1, 2, 3).reshape(n * out_h * out_w, -1)
# # Reshape filters
# w_col = filters.reshape(c_out, -1)
# # Perform convolution via matrix multiplication
# out = np.dot(col, w_col.T)
# out = out.reshape(n, out_h, out_w, c_out).transpose(0, 3, 1, 2)
# return out
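# # Sanity-check sketch (hedged: torch is used here only as a reference):
# #     x = np.random.rand(1, 2, 8, 8).astype(np.float32)
# #     w = np.random.rand(4, 2, 3, 3).astype(np.float32)
# #     ref = F.conv2d(torch.from_numpy(x), torch.from_numpy(w), padding=1).numpy()
# #     np.testing.assert_allclose(conv2d(x, w, padding=(1, 1)), ref, rtol=1e-4, atol=1e-5)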
# def transposed_conv2d(
# input_data,
# filters,
# dilation=(1, 1),
# stride=(1, 1),
# padding=(0, 0),
# output_padding=(0, 0),
# ):
# """
# 2D transposed convolution using JAX-compatible algorithm.
# JAX's conv_transpose implementation:
# 1. Upsample input by inserting (stride-1) zeros between elements
# 2. Apply regular convolution with effective padding
# For transposed convolution, the effective padding is:
# effective_pad = dilation * (kernel_size - 1) - original_pad
# Parameters:
# -----------
# input_data : ndarray
# Input data with shape (N, C_in, H, W)
# filters : ndarray
# Filters with shape (C_out, C_in, filter_h, filter_w)
# dilation : tuple
# Dilation factors (dilation_h, dilation_w)
# stride : tuple
# Stride values (stride_h, stride_w)
# padding : tuple
# Original padding values (pad_h, pad_w) from the forward convolution
# output_padding : tuple
# Output padding values (out_pad_h, out_pad_w), appended as zero rows/columns
# on the bottom/right of the result (see Step 4 below)
# Returns:
# --------
# output : ndarray
# Transposed convolution output
# """
# n, c_in, h, w = input_data.shape
# c_out, c_in_f, filter_h, filter_w = filters.shape
# assert c_in == c_in_f, f"Input channels {c_in} != filter input channels {c_in_f}"
# pad_h, pad_w = padding
# stride_h, stride_w = stride
# dilation_h, dilation_w = dilation
# # Step 1: Upsample input by inserting (stride-1) zeros between elements
# if stride_h > 1 or stride_w > 1:
# # Calculate upsampled dimensions
# upsampled_h = h + (h - 1) * (stride_h - 1)
# upsampled_w = w + (w - 1) * (stride_w - 1)
# # Create upsampled array filled with zeros
# upsampled = np.zeros(
# (n, c_in, upsampled_h, upsampled_w), dtype=input_data.dtype
# )
# # Insert original values at strided positions
# upsampled[:, :, ::stride_h, ::stride_w] = input_data
# else:
# # No upsampling needed for stride=1
# upsampled = input_data
# # Step 2: Calculate effective padding for transposed convolution
# # For a forward conv with padding P, kernel size K, and dilation D, the
# # equivalent regular conv on the zero-upsampled input needs padding D*(K-1) - P
# effective_pad_h = dilation_h * (filter_h - 1) - pad_h
# effective_pad_w = dilation_w * (filter_w - 1) - pad_w
# # Step 3: Apply regular convolution with effective padding
# # Use stride=1 since upsampling already handled the stride effect
# result = conv2d(
# upsampled,
# filters,
# dilation=dilation,
# stride=(1, 1),
# padding=(effective_pad_h, effective_pad_w),
# )
# # Step 4: Apply output_padding if specified
# # Output padding adds zeros to the right and bottom of the output
# out_pad_h, out_pad_w = output_padding
# if out_pad_h > 0 or out_pad_w > 0:
# n, c_out, h_out, w_out = result.shape
# padded_result = np.zeros(
# (n, c_out, h_out + out_pad_h, w_out + out_pad_w), dtype=result.dtype
# )
# padded_result[:, :, :h_out, :w_out] = result
# result = padded_result
# return result
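# # Worked example of Step 1: a 2x2 input with stride 2 is upsampled to 3x3
# # (h + (h - 1) * (stride - 1) = 2 + 1 = 3) with zeros interleaved; the stride-1
# # convolution in Step 3 then realizes the transposed convolution.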