Source code for nabla.utils.shape_utils
# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Broadcasting and shape manipulation utilities."""
Shape = tuple[int, ...]
[docs]
def get_broadcasted_shape(
shape1: Shape,
shape2: Shape,
ignore_axes: list[int] | None = None,
replace_ignored_dims: list[int] | None = None,
) -> Shape:
"""
Compute the broadcasted shape of two input shapes.
Args:
shape1: First input shape
shape2: Second input shape
ignore_axes: Axes to ignore during broadcasting
replace_ignored_dims: Replacement dimensions for ignored axes
Returns:
The broadcasted output shape
Raises:
ValueError: If shapes cannot be broadcast together
"""
if ignore_axes is None:
ignore_axes = []
if replace_ignored_dims is None:
replace_ignored_dims = []
if len(replace_ignored_dims) != len(ignore_axes):
raise ValueError(
"replace_ignored_dims must have the same length as ignore_axes"
)
s1_len = len(shape1)
s2_len = len(shape2)
max_rank = max(s1_len, s2_len)
# Initialize result shape with 1s (common default for broadcasting)
res_shape_list = [1] * max_rank
# Normalize ignore_axes to positive indices and store replacement values
normalized_ignored_map = {}
for i, axis_spec in enumerate(ignore_axes):
replacement_dim = replace_ignored_dims[i]
# Validate and normalize the axis_spec relative to max_rank
if not (-max_rank <= axis_spec < max_rank):
raise ValueError(
f"ignore_axis {axis_spec} is out of bounds for max_rank {max_rank}"
)
normalized_idx = axis_spec if axis_spec >= 0 else max_rank + axis_spec
normalized_ignored_map[normalized_idx] = replacement_dim
res_shape_list[normalized_idx] = replacement_dim
# Pad original shapes with leading 1s to align them to max_rank
padded_shape1_list = [1] * (max_rank - s1_len) + list(shape1)
padded_shape2_list = [1] * (max_rank - s2_len) + list(shape2)
# Perform broadcasting for non-ignored axes
for i in range(max_rank):
if i in normalized_ignored_map:
# This dimension's value is already set by replace_ignored_dims
continue
d1 = padded_shape1_list[i]
d2 = padded_shape2_list[i]
if d1 == d2:
res_shape_list[i] = d1
elif d1 == 1:
res_shape_list[i] = d2
elif d2 == 1:
res_shape_list[i] = d1
else:
# Dimensions are different and neither is 1, broadcasting error
raise ValueError(
f"Shapes {shape1} and {shape2} cannot be broadcast at dimension index {i} "
f"(0-indexed from left of max_rank {max_rank} shape). "
f"Padded values at this index are {d1} (from shape1) and {d2} (from shape2)."
)
return tuple(res_shape_list)