# Source code for nabla.ops.binary

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from __future__ import annotations

import numpy as np
from max.dtype import DType
from max.graph import TensorValue, ops

from ..core.array import Array
from .operation import BinaryOperation

# Public API
# Names exported by a star-import of this module. Every entry is one of the
# wrapper functions defined at the bottom of this file.
__all__ = [
    "add",
    "mul",
    "sub",
    "div",
    "floordiv",
    "mod",
    "pow",
    "greater_equal",
    "equal",
    "not_equal",
    "maximum",
    "minimum",
]


def _ensure_array(value) -> Array:
    """Return ``value`` as an Array.

    Arrays are passed through untouched; ``int`` and ``float`` values are
    wrapped with ``creation.array``.

    Raises:
        TypeError: If ``value`` is neither an Array nor an int/float.
    """
    if isinstance(value, Array):
        return value
    if not isinstance(value, (int, float)):
        raise TypeError(f"Cannot convert {type(value)} to Array")

    # Imported lazily, only on the scalar path.
    from .creation import array

    return array(value)


class AddOp(BinaryOperation):
    """Element-wise addition, registered under the operation name "add"."""

    def __init__(self):
        super().__init__("add")

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.add`` of the two graph values on ``output.tensor_value``."""
        lhs, rhs = args[0], args[1]
        output.tensor_value = ops.add(lhs, rhs)

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Add the operands with NumPy and pass the result to ``output.impl_``."""
        summed = np.add(args[0].to_numpy(), args[1].to_numpy())
        # A bare NumPy scalar is wrapped so an ndarray is always handed over.
        output.impl_(np.array(summed) if np.isscalar(summed) else summed)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Return the cotangent unchanged for both operands."""
        return [cotangent] * 2

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Return the sum of the two input tangents."""
        d_lhs, d_rhs = tangents[0], tangents[1]
        return add(d_lhs, d_rhs)


class MulOp(BinaryOperation):
    """Element-wise multiplication, registered under the operation name "mul"."""

    def __init__(self):
        super().__init__("mul")

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.mul`` of the two graph values on ``output.tensor_value``."""
        lhs, rhs = args[0], args[1]
        output.tensor_value = ops.mul(lhs, rhs)

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Multiply the operands with NumPy and pass the result to ``output.impl_``."""
        product = np.multiply(args[0].to_numpy(), args[1].to_numpy())
        # A bare NumPy scalar is wrapped so an ndarray is always handed over.
        output.impl_(np.array(product) if np.isscalar(product) else product)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Return ``[cotangent * primals[1], cotangent * primals[0]]``."""
        grad_lhs = mul(cotangent, primals[1])
        grad_rhs = mul(cotangent, primals[0])
        return [grad_lhs, grad_rhs]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Product rule: ``primals[0] * tangents[1] + primals[1] * tangents[0]``."""
        lhs_term = mul(primals[0], tangents[1])
        rhs_term = mul(primals[1], tangents[0])
        return add(lhs_term, rhs_term)


class SubOp(BinaryOperation):
    """Element-wise subtraction, registered under the operation name "sub"."""

    def __init__(self):
        super().__init__("sub")

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.sub`` of the two graph values on ``output.tensor_value``."""
        lhs, rhs = args[0], args[1]
        output.tensor_value = ops.sub(lhs, rhs)

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Subtract the operands with NumPy and pass the result to ``output.impl_``."""
        difference = np.subtract(args[0].to_numpy(), args[1].to_numpy())
        # A bare NumPy scalar is wrapped so an ndarray is always handed over.
        output.impl_(np.array(difference) if np.isscalar(difference) else difference)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Return the cotangent for the minuend and its negation for the subtrahend."""
        from .unary import negate

        grad_rhs = negate(cotangent)
        return [cotangent, grad_rhs]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Return the difference of the two input tangents."""
        d_lhs, d_rhs = tangents[0], tangents[1]
        return sub(d_lhs, d_rhs)


class DivOp(BinaryOperation):
    """Element-wise division, registered under the operation name "div"."""

    def __init__(self):
        super().__init__("div")

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.div`` of the two graph values on ``output.tensor_value``."""
        numerator, denominator = args[0], args[1]
        output.tensor_value = ops.div(numerator, denominator)

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Divide the operands with NumPy and pass the result to ``output.impl_``."""
        quotient = np.divide(args[0].to_numpy(), args[1].to_numpy())
        # A bare NumPy scalar is wrapped so an ndarray is always handed over.
        output.impl_(np.array(quotient) if np.isscalar(quotient) else quotient)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Return ``[cotangent / y, -(cotangent * x) / (y * y)]`` for ``x / y``."""
        from .unary import negate

        x, y = primals
        grad_x = div(cotangent, y)
        scaled = mul(cotangent, x)
        y_squared = mul(y, y)
        grad_y = negate(div(scaled, y_squared))
        return [grad_x, grad_y]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Quotient rule: ``dx / y - (x * dy) / (y * y)``."""
        from .unary import negate

        x, y = primals
        dx, dy = tangents
        from_numerator = div(dx, y)
        scaled = mul(x, dy)
        y_squared = mul(y, y)
        from_denominator = negate(div(scaled, y_squared))
        return add(from_numerator, from_denominator)


class PowerOp(BinaryOperation):
    """Power operation (x^y), registered under the operation name "pow"."""

    def __init__(self):
        super().__init__("pow")

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.pow`` of the two graph values on ``output.tensor_value``."""
        output.tensor_value = ops.pow(args[0], args[1])

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Raise base to exponent with NumPy and pass the result to ``output.impl_``."""
        # ``np.power`` is used instead of ``np.pow``: the latter is an alias
        # that only exists from NumPy 2.0 on and raises AttributeError on 1.x.
        np_result = np.power(args[0].to_numpy(), args[1].to_numpy())
        # Ensure result is an array, not a scalar
        if np.isscalar(np_result):
            np_result = np.array(np_result)
        output.impl_(np_result)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Return the cotangents for base ``x`` and exponent ``y``.

        With ``output = x^y`` the rule computes
        ``cotangent * y * (output / x)`` for ``x`` and
        ``cotangent * output * log(x)`` for ``y``.
        """
        from .unary import log

        x, y = primals
        cotangent_x = mul(mul(cotangent, y), div(output, x))
        cotangent_y = mul(mul(cotangent, output), log(x))

        return [cotangent_x, cotangent_y]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Return ``y * (output / x) * dx + output * log(x) * dy``."""
        from .unary import log

        x, y = primals
        dx, dy = tangents
        term1 = mul(mul(y, div(output, x)), dx)
        term2 = mul(mul(output, log(x)), dy)

        return add(term1, term2)


class GreaterEqualOp(BinaryOperation):
    """Element-wise ``>=`` comparison, registered as "greater_equal"."""

    def __init__(self):
        super().__init__("greater_equal")

    def compute_output_dtype(self, arg1: Array, arg2: Array) -> DType:
        """Report ``DType.bool`` as the result dtype, whatever the operands are."""
        return DType.bool

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.greater_equal`` of the graph values on ``output.tensor_value``."""
        lhs, rhs = args[0], args[1]
        output.tensor_value = ops.greater_equal(lhs, rhs)

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Compare the operands with NumPy and pass the result to ``output.impl_``."""
        comparison = np.greater_equal(args[0].to_numpy(), args[1].to_numpy())

        # Always work with an ndarray from here on.
        if np.isscalar(comparison):
            comparison = np.array(comparison)

        is_bool_scalar = comparison.shape == () and comparison.dtype == bool
        if not is_bool_scalar:
            # Regular path: non-scalar results, or anything that is not boolean.
            output.impl_(comparison)
            return

        # WORKAROUND: MAX library bug with scalar boolean tensors (reported in
        # the _view method of tensor.py). The value is stored as a one-element
        # 1D boolean array and the shape is then overridden to look scalar.
        output.impl_(np.array([comparison.item()], dtype=bool))
        output.shape = ()

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Return zeros shaped like the cotangent, cast to each primal's dtype."""
        from .creation import zeros_like

        return [
            zeros_like(cotangent).astype(primal.dtype)
            for primal in (primals[0], primals[1])
        ]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Return zeros shaped like the first tangent, cast to ``output.dtype``."""
        from .creation import zeros_like

        zero_tangent = zeros_like(tangents[0])
        return zero_tangent.astype(output.dtype)


class MaximumOp(BinaryOperation):
    """Element-wise maximum operation.

    Both derivative rules weight the operands identically: weight 1 for the
    strictly larger operand, 0 for the strictly smaller one, and 0.5 each
    where the operands are equal (JAX convention). Sharing the weights
    keeps the JVP and the VJP consistent with each other at ties.
    """

    def __init__(self):
        super().__init__("maximum")

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.max`` of the two graph values on ``output.tensor_value``."""
        output.tensor_value = ops.max(args[0], args[1])

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Take the maximum with NumPy and pass the result to ``output.impl_``."""
        np_result = np.maximum(args[0].to_numpy(), args[1].to_numpy())
        # Ensure result is an array, not a scalar
        if np.isscalar(np_result):
            np_result = np.array(np_result)
        output.impl_(np_result)

    def _operand_weights(self, x: Array, y: Array, dtype: DType) -> tuple[Array, Array]:
        """Return per-element derivative weights ``(x_weight, y_weight)``.

        Args:
            x: First operand.
            y: Second operand.
            dtype: Dtype the boolean comparison masks are cast to.
        """
        from ..ops.unary import cast

        # Cast boolean masks to a numeric dtype for multiplication
        x_mask = cast(greater_equal(x, y), dtype)
        y_mask = cast(greater_equal(y, x), dtype)

        # When x == y both masks are set, so the contribution is split in half
        both_equal = mul(x_mask, y_mask)
        half_tie = mul(both_equal, 0.5)

        x_weight = add(sub(x_mask, both_equal), half_tie)
        y_weight = add(sub(y_mask, both_equal), half_tie)
        return x_weight, y_weight

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Route the cotangent to the larger operand, splitting it at ties."""
        x, y = primals
        x_weight, y_weight = self._operand_weights(x, y, cotangent.dtype)
        return [mul(cotangent, x_weight), mul(cotangent, y_weight)]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Combine the tangents with the same weights the VJP uses.

        At ties the result is ``0.5 * dx + 0.5 * dy``, matching the gradient
        split performed by ``vjp_rule``.
        """
        x, y = primals
        dx, dy = tangents
        x_weight, y_weight = self._operand_weights(x, y, dx.dtype)
        return add(mul(dx, x_weight), mul(dy, y_weight))


class MinimumOp(BinaryOperation):
    """Element-wise minimum operation.

    Both derivative rules weight the operands identically: weight 1 for the
    strictly smaller operand, 0 for the strictly larger one, and 0.5 each
    where the operands are equal (JAX convention). Sharing the weights
    keeps the JVP and the VJP consistent with each other at ties.
    """

    def __init__(self):
        super().__init__("minimum")

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.min`` of the two graph values on ``output.tensor_value``."""
        output.tensor_value = ops.min(args[0], args[1])

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Take the minimum with NumPy and pass the result to ``output.impl_``."""
        np_result = np.minimum(args[0].to_numpy(), args[1].to_numpy())
        # Ensure result is an array, not a scalar
        if np.isscalar(np_result):
            np_result = np.array(np_result)
        output.impl_(np_result)

    def _operand_weights(self, x: Array, y: Array, dtype: DType) -> tuple[Array, Array]:
        """Return per-element derivative weights ``(x_weight, y_weight)``.

        Args:
            x: First operand.
            y: Second operand.
            dtype: Dtype the boolean comparison masks are cast to.
        """
        from ..ops.unary import cast

        # Cast boolean masks to a numeric dtype for multiplication
        x_mask = cast(greater_equal(y, x), dtype)  # x <= y
        y_mask = cast(greater_equal(x, y), dtype)  # y <= x

        # When x == y both masks are set, so the contribution is split in half
        both_equal = mul(x_mask, y_mask)
        half_tie = mul(both_equal, 0.5)

        x_weight = add(sub(x_mask, both_equal), half_tie)
        y_weight = add(sub(y_mask, both_equal), half_tie)
        return x_weight, y_weight

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Route the cotangent to the smaller operand, splitting it at ties."""
        x, y = primals
        x_weight, y_weight = self._operand_weights(x, y, cotangent.dtype)
        return [mul(cotangent, x_weight), mul(cotangent, y_weight)]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Combine the tangents with the same weights the VJP uses.

        At ties the result is ``0.5 * dx + 0.5 * dy``, matching the gradient
        split performed by ``vjp_rule``.
        """
        x, y = primals
        dx, dy = tangents
        x_weight, y_weight = self._operand_weights(x, y, dx.dtype)
        return add(mul(dx, x_weight), mul(dy, y_weight))


class EqualOp(BinaryOperation):
    """Element-wise ``==`` comparison, registered under the name "equal"."""

    def __init__(self):
        super().__init__("equal")

    def compute_output_dtype(self, arg0: Array, arg1: Array) -> DType:
        """Report ``DType.bool`` as the result dtype, whatever the operands are."""
        return DType.bool

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.equal`` of the two graph values on ``output.tensor_value``."""
        lhs, rhs = args[0], args[1]
        output.tensor_value = ops.equal(lhs, rhs)

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Compare the operands with NumPy and pass the result to ``output.impl_``."""
        lhs = args[0].to_numpy()
        rhs = args[1].to_numpy()
        matches = lhs == rhs

        # Always work with an ndarray from here on.
        if np.isscalar(matches):
            matches = np.array(matches)

        is_bool_scalar = matches.shape == () and matches.dtype == bool
        if is_bool_scalar:
            # WORKAROUND: MAX library bug with scalar boolean tensors. The
            # scalar is stored as float32 (1.0 or 0.0) and the output dtype is
            # updated to describe what was actually stored.
            # NOTE(review): GreaterEqualOp and NotEqualOp work around the same
            # bug by storing a 1D bool array instead, and this path disagrees
            # with compute_output_dtype -- confirm which variant is intended.
            output.impl_(matches.astype(np.float32))
            output.dtype = DType.float32
        else:
            # Regular path: non-scalar results, or anything that is not boolean.
            output.impl_(matches)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Return zeros shaped like the cotangent, cast to each primal's dtype."""
        from .creation import zeros_like

        return [
            zeros_like(cotangent).astype(primal.dtype)
            for primal in (primals[0], primals[1])
        ]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Return zeros shaped like the first tangent.

        NOTE(review): unlike GreaterEqualOp and NotEqualOp, the result is not
        cast to ``output.dtype`` -- confirm whether that is deliberate.
        """
        from .creation import zeros_like

        first_tangent = tangents[0]
        return zeros_like(first_tangent)


class NotEqualOp(BinaryOperation):
    """Element-wise ``!=`` comparison, registered under the name "not_equal"."""

    def __init__(self):
        super().__init__("not_equal")

    def compute_output_dtype(self, arg0: Array, arg1: Array) -> DType:
        """Report ``DType.bool`` as the result dtype, whatever the operands are."""
        return DType.bool

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.not_equal`` of the graph values on ``output.tensor_value``."""
        lhs, rhs = args[0], args[1]
        output.tensor_value = ops.not_equal(lhs, rhs)

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Compare the operands with NumPy and pass the result to ``output.impl_``."""
        lhs = args[0].to_numpy()
        rhs = args[1].to_numpy()
        differs = lhs != rhs

        # Always work with an ndarray from here on.
        if np.isscalar(differs):
            differs = np.array(differs)

        is_bool_scalar = differs.shape == () and differs.dtype == bool
        if not is_bool_scalar:
            # Regular path: non-scalar results, or anything that is not boolean.
            output.impl_(differs)
            return

        # WORKAROUND: MAX library bug with scalar boolean tensors (reported in
        # the _view method of tensor.py). The value is stored as a one-element
        # 1D boolean array and the shape is then overridden to look scalar.
        output.impl_(np.array([differs.item()], dtype=bool))
        output.shape = ()

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Return zeros shaped like the cotangent, cast to each primal's dtype."""
        from .creation import zeros_like

        return [
            zeros_like(cotangent).astype(primal.dtype)
            for primal in (primals[0], primals[1])
        ]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Return zeros shaped like the first tangent, cast to ``output.dtype``."""
        from .creation import zeros_like

        zero_tangent = zeros_like(tangents[0])
        return zero_tangent.astype(output.dtype)


class ModOp(BinaryOperation):
    """Element-wise modulo, registered under the operation name "mod".

    The derivative rules treat the result as ``x - floor(x / y) * y``.
    """

    def __init__(self):
        super().__init__("mod")

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        """Store ``ops.mod`` of the two graph values on ``output.tensor_value``."""
        dividend, divisor = args[0], args[1]
        output.tensor_value = ops.mod(dividend, divisor)

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        """Compute ``np.remainder`` and pass the result to ``output.impl_``."""
        remainder = np.remainder(args[0].to_numpy(), args[1].to_numpy())
        # A bare NumPy scalar is wrapped so an ndarray is always handed over.
        output.impl_(np.array(remainder) if np.isscalar(remainder) else remainder)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        """Return ``[cotangent, cotangent * (-floor(x / y))]``.

        Follows from c = x - floor(x/y) * y: dc/dx = 1 and dc/dy = -floor(x/y).
        """
        from .unary import floor

        x, y = primals
        quotient_floor = floor(div(x, y))
        negated_quotient = mul(quotient_floor, -1)
        return [cotangent, mul(cotangent, negated_quotient)]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        """Return ``dx - floor(x / y) * dy``."""
        from .unary import floor

        x, y = primals
        dx, dy = tangents
        quotient_floor = floor(div(x, y))
        divisor_term = mul(quotient_floor, dy)
        return sub(dx, divisor_term)


# Create operation instances
# One module-level instance per operation class. The public wrapper functions
# below call `.forward(...)` on these instances. `floordiv` has no instance of
# its own because it is composed from `div` and `floor`.
_add_op = AddOp()
_mul_op = MulOp()
_sub_op = SubOp()
_div_op = DivOp()
_power_op = PowerOp()
_greater_equal_op = GreaterEqualOp()
_maximum_op = MaximumOp()
_minimum_op = MinimumOp()
_equal_op = EqualOp()
_not_equal_op = NotEqualOp()
_mod_op = ModOp()


def add(arg0, arg1) -> Array:
    """Element-wise addition of two arrays or array and scalar.

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``AddOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _add_op.forward(arg0, arg1)
def mul(arg0, arg1) -> Array:
    """Element-wise multiplication of two arrays or array and scalar.

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``MulOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _mul_op.forward(arg0, arg1)
def sub(arg0, arg1) -> Array:
    """Element-wise subtraction of two arrays or array and scalar.

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``SubOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _sub_op.forward(arg0, arg1)
def div(arg0, arg1) -> Array:
    """Element-wise division of two arrays or array and scalar.

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``DivOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _div_op.forward(arg0, arg1)
def floordiv(arg0, arg1) -> Array:
    """Element-wise floor division of two arrays or array and scalar.

    Floor division is implemented as floor(a / b) which rounds towards
    negative infinity, matching Python's // operator behavior.
    """
    from ..ops.unary import floor

    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)

    # Perform regular division then floor
    result = div(arg0, arg1)
    return floor(result)
# Intentionally shadowing built-in 'pow' for API consistency.
def pow(arg0, arg1) -> Array:  # noqa: A001
    """Element-wise power operation (arg0^arg1).

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``PowerOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _power_op.forward(arg0, arg1)
def greater_equal(arg0: Array | int | float, arg1: Array | int | float) -> Array:
    """Element-wise greater than or equal to operation.

    Each operand is passed through ``_ensure_array`` (so plain ``int`` and
    ``float`` values are accepted as well as Arrays) before being forwarded
    to the module-level ``GreaterEqualOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _greater_equal_op.forward(arg0, arg1)
def maximum(arg0, arg1) -> Array:
    """Element-wise maximum of two arrays or array and scalar.

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``MaximumOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _maximum_op.forward(arg0, arg1)
def minimum(arg0, arg1) -> Array:
    """Element-wise minimum of two arrays or array and scalar.

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``MinimumOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _minimum_op.forward(arg0, arg1)
def equal(arg0, arg1) -> Array:
    """Element-wise equality comparison.

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``EqualOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _equal_op.forward(arg0, arg1)
def not_equal(arg0, arg1) -> Array:
    """Element-wise not-equal comparison.

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``NotEqualOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _not_equal_op.forward(arg0, arg1)
def mod(arg0, arg1) -> Array:
    """Element-wise modulo operation (arg0 % arg1).

    Each operand is passed through ``_ensure_array`` before being forwarded
    to the module-level ``ModOp`` instance.
    """
    arg0 = _ensure_array(arg0)
    arg1 = _ensure_array(arg1)
    return _mod_op.forward(arg0, arg1)