Source code for nabla.ops.unary

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Unary operations for the Nabla framework."""

import numpy as np
from max.driver import Device
from max.dtype import DType
from max.graph import DeviceRef, TensorValue, ops

from ..core.array import Array
from .operation import UnaryOperation

# Public API
__all__ = [
    "negate",
    "cast",
    "sin",
    "cos",
    "tanh",
    "sigmoid",
    "abs",
    "floor",
    "logical_not",
    "incr_batch_dim_ctr",
    "decr_batch_dim_ctr",
    "relu",
    "log",
    "exp",
    "sqrt",
    "transfer_to",
]


class NegateOp(UnaryOperation):
    """Element-wise negation operation."""

    def __init__(self):
        super().__init__("neg")

    def maxpr(self, args: list[TensorValue], output: Array) -> None:
        output.tensor_value = ops.negate(args[0])

    def eagerxpr(self, args: list[Array], output: Array) -> None:
        np_result = -args[0].to_numpy()
        # Ensure result is an array, not a scalar
        if np.isscalar(np_result):
            np_result = np.array(np_result)
        output.impl_(np_result)

    def vjp_rule(
        self, primals: list[Array], cotangent: Array, output: Array
    ) -> list[Array]:
        return [negate(cotangent)]

    def jvp_rule(
        self, primals: list[Array], tangents: list[Array], output: Array
    ) -> Array:
        return negate(tangents[0])


[docs] def negate(arg: Array) -> Array: """Element-wise negation.""" return _negate_op.forward(arg)
class CastOp(UnaryOperation): """Type casting operation.""" def __init__(self, dtype: DType): super().__init__(f"convert_element_type[new_dtype={dtype}]") self.target_dtype = dtype # def compute_output_shape(self, *input_shapes: tuple) -> tuple: # """Compatible signature - output shape same as input shape.""" # if len(input_shapes) != 1: # raise ValueError( # f"Cast operation requires 1 input shape, got {len(input_shapes)}" # ) # return input_shapes[0] def compute_output_dtype(self, arg: Array) -> DType: return self.target_dtype # def forward(self, *args: Array) -> Array: # """Override forward to set dtype with compatible signature.""" # if len(args) != 1: # raise ValueError(f"Cast operation requires 1 argument, got {len(args)}") # return super().forward(*args) def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.cast(args[0], output.dtype) def eagerxpr(self, args: list[Array], output: Array) -> None: np_result = args[0].to_numpy().astype(DType.to_numpy(output.dtype)) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: return [cast(cotangent, primals[0].dtype)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: return cast(tangents[0], output.dtype)
[docs] def cast(arg: Array, dtype: DType) -> Array: """Cast array to different dtype.""" if not isinstance(dtype, DType): raise TypeError(f"Dtype must be an instance of DType, got {type(dtype)}") op = CastOp(dtype) return op.forward(arg)
class SinOp(UnaryOperation): """Element-wise sine operation.""" def __init__(self): super().__init__("sin") def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.sin(args[0]) def eagerxpr(self, args: list[Array], output: Array) -> None: np_result = np.sin(args[0].to_numpy()) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .binary import mul return [mul(cotangent, cos(primals[0]))] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .binary import mul return mul(tangents[0], cos(primals[0]))
[docs] def sin(arg: Array, dtype: DType | None = None) -> Array: """Element-wise sine.""" res = _sin_op.forward(arg) if dtype: return cast(res, dtype) return res
class CosOp(UnaryOperation): """Element-wise cosine operation.""" def __init__(self): super().__init__("cos") def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.cos(args[0]) def eagerxpr(self, args: list[Array], output: Array) -> None: np_result = np.cos(args[0].to_numpy()) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .binary import mul return [negate(mul(cotangent, sin(primals[0])))] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .binary import mul return negate(mul(tangents[0], sin(primals[0])))
[docs] def cos(arg: Array) -> Array: """Element-wise cosine.""" return _cos_op.forward(arg)
class TanhOp(UnaryOperation): """Element-wise hyperbolic tangent operation.""" def __init__(self): super().__init__("tanh") def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.tanh(args[0]) def eagerxpr(self, args: list[Array], output: Array) -> None: np_result = np.tanh(args[0].to_numpy()) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .binary import mul, sub from .creation import ones_like # d/dx tanh(x) = 1 - tanh²(x) = 1 - output² ones_like_output = ones_like(output) tanh_squared = mul(output, output) derivative = sub(ones_like_output, tanh_squared) return [mul(cotangent, derivative)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .binary import mul, sub from .creation import ones_like # d/dx tanh(x) = 1 - tanh²(x) ones_like_output = ones_like(output) tanh_squared = mul(output, output) derivative = sub(ones_like_output, tanh_squared) return mul(tangents[0], derivative)
[docs] def tanh(arg: Array) -> Array: """Element-wise hyperbolic tangent.""" return _tanh_op.forward(arg)
class AbsOp(UnaryOperation): """Element-wise absolute value operation.""" def __init__(self): super().__init__("abs") def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.abs(args[0]) def eagerxpr(self, args: list[Array], output: Array) -> None: np_result = np.abs(args[0].to_numpy()) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .binary import mul # d/dx |x| = sign(x) = 1 if x > 0, -1 if x < 0, undefined at x = 0 # We use the convention that sign(0) = 0 # # Workaround: Use the fact that abs(x) = x * sign(x), so sign(x) = abs(x) / x # But we need to handle x = 0 case. # Alternative: use the identity that d/dx |x| = x / |x| for x != 0, and 0 for x = 0 x = primals[0] abs_x = output # This is |x| # For x != 0: sign = x / |x| # For x == 0: sign = 0 (we'll handle this by checking if abs_x is zero) # Check if we're at zero (abs_x is very small) # Use a small epsilon to avoid division by zero eps = 1e-12 abs_x_safe = abs_x + eps # Add small epsilon to avoid division by zero # sign = x / abs_x_safe from .binary import div sign = div(x, abs_x_safe) # For true zeros, the sign should be zero. Since abs(0) = 0, # and we added eps, the division 0/(0+eps) = 0, which is correct. return [mul(cotangent, sign)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .binary import div, mul # d/dx |x| = sign(x) = x / |x| for x != 0, 0 for x = 0 x = primals[0] abs_x = output # This is |x| # Use same approach as VJP: x / (|x| + eps) eps = 1e-12 abs_x_safe = abs_x + eps sign = div(x, abs_x_safe) return mul(tangents[0], sign)
[docs] def abs(arg: Array) -> Array: """Element-wise absolute value.""" return _abs_op.forward(arg)
class FloorOp(UnaryOperation): """Element-wise floor operation.""" def __init__(self): super().__init__("floor") def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.floor(args[0]) def eagerxpr(self, args: list[Array], output: Array) -> None: np_result = np.floor(args[0].to_numpy()) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .creation import zeros_like return [zeros_like(cotangent)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .creation import zeros_like return zeros_like(tangents[0])
[docs] def floor(arg: Array) -> Array: """Element-wise floor function.""" return _floor_op.forward(arg)
class LogicalNotOp(UnaryOperation): """Element-wise logical NOT operation for boolean arrays.""" def __init__(self): super().__init__("logical_not") def compute_output_dtype(self, arg: Array) -> DType: """Logical NOT always returns boolean dtype.""" return DType.bool def maxpr(self, args: list[TensorValue], output: Array) -> None: # Convert input to boolean if needed (due to scalar boolean workaround) input_tensor = args[0] if input_tensor.dtype != DType.bool: # Cast to boolean first input_tensor = ops.cast(input_tensor, DType.bool) # Use MAX's logical not operation output.tensor_value = ops.logical_not(input_tensor) def eagerxpr(self, args: list[Array], output: Array) -> None: import numpy as np np_result = np.logical_not(args[0].to_numpy()) # Ensure result is always a numpy array if np.isscalar(np_result): np_result = np.array(np_result) # WORKAROUND: MAX library bug with scalar boolean tensors # The MAX tensor library fails when creating scalar boolean tensors # due to a bug in the _view method (line 49 in tensor.py) if np_result.shape == () and np_result.dtype == bool: # Convert scalar boolean to 1D boolean array, create tensor # The output will appear as scalar but be stored as 1D internally np_result_1d = np.array([np_result.item()], dtype=bool) output.impl_(np_result_1d) # Override the shape to appear as scalar output.shape = () else: # Normal path for non-scalar boolean or any non-boolean results output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .creation import zeros_like return [zeros_like(cotangent)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .creation import zeros_like return zeros_like(tangents[0])
[docs] def logical_not(arg: Array) -> Array: """Element-wise logical NOT operation.""" return _logical_not_op.forward(arg)
class SigmoidOp(UnaryOperation): """Element-wise sigmoid operation.""" def __init__(self): super().__init__("sigmoid") def maxpr(self, args: list[TensorValue], output: Array) -> None: # Sigmoid = 1 / (1 + exp(-x)) # Use MAX's built-in sigmoid if available, otherwise construct from primitives output.tensor_value = ops.sigmoid(args[0]) def eagerxpr(self, args: list[Array], output: Array) -> None: # Numerically stable sigmoid implementation x = args[0].to_numpy() # For positive values: 1 / (1 + exp(-x)) # For negative values: exp(x) / (1 + exp(x)) np_result = np.where( x >= 0, 1.0 / (1.0 + np.exp(-x)), np.exp(x) / (1.0 + np.exp(x)) ) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .binary import mul, sub from .creation import ones_like # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) = output * (1 - output) ones_like_output = ones_like(output) one_minus_output = sub(ones_like_output, output) derivative = mul(output, one_minus_output) return [mul(cotangent, derivative)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .binary import mul, sub from .creation import ones_like # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) ones_like_output = ones_like(output) one_minus_output = sub(ones_like_output, output) derivative = mul(output, one_minus_output) return mul(tangents[0], derivative)
[docs] def sigmoid(arg: Array) -> Array: """Element-wise sigmoid function.""" return _sigmoid_op.forward(arg)
class IncrBatchDimCtr(UnaryOperation): """Increment batch dimension counter for debugging.""" def __init__(self, arg_batch_dims: tuple[int, ...], arg_shape: tuple[int, ...]): super().__init__("incr_batch_dim_ctr") self.arg_batch_dims = arg_batch_dims self.arg_shape = arg_shape def compute_output_shape(self, *input_shapes: tuple) -> tuple: """Output shape is the same as input shape.""" if not self.arg_shape: raise ValueError( f"IncrBatchDimCtr requires a non-empty arg_shape, got {self.arg_shape}" ) return self.arg_shape[1:] def compute_output_batch_dims(self, *input_batch_dims): if not self.arg_shape: raise ValueError( f"IncrBatchDimCtr requires a non-empty arg_shape, got {self.arg_shape}" ) return self.arg_batch_dims + (self.arg_shape[0],) def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = args[0] def eagerxpr(self, args: list[Array], output: Array) -> None: output.impl_(args[0]._impl) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: return [decr_batch_dim_ctr(cotangent)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: return incr_batch_dim_ctr(tangents[0])
[docs] def incr_batch_dim_ctr(arg: Array) -> Array: """Increment batch dimension counter for debugging.""" return IncrBatchDimCtr(arg.batch_dims, arg.shape).forward(arg)
class DecrBatchDimCtr(UnaryOperation): """Decrement batch dimension counter for debugging.""" def __init__(self, arg_batch_dims: tuple[int, ...], arg_shape: tuple[int, ...]): super().__init__("decr_batch_dim_ctr") self.arg_batch_dims = arg_batch_dims self.arg_shape = arg_shape def compute_output_shape(self, *input_shapes: tuple) -> tuple: """Output shape is the same as input shape.""" if not self.arg_batch_dims: raise ValueError( f"DecrBatchDimCtr requires a non-empty arg_batch_dims, got {self.arg_batch_dims}" ) return (self.arg_batch_dims[-1],) + self.arg_shape def compute_output_batch_dims(self, *input_batch_dims): if not self.arg_batch_dims: raise ValueError( f"DecrBatchDimCtr requires a non-empty arg_batch_dims, got {self.arg_batch_dims}" ) return self.arg_batch_dims[:-1] def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = args[0] def eagerxpr(self, args: list[Array], output: Array) -> None: output.impl_(args[0]._impl) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: return [incr_batch_dim_ctr(cotangent)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: return decr_batch_dim_ctr(tangents[0])
[docs] def decr_batch_dim_ctr(arg: Array) -> Array: """Decrement batch dimension counter for debugging.""" return DecrBatchDimCtr(arg.batch_dims, arg.shape).forward(arg)
class ReLUOp(UnaryOperation): """Element-wise ReLU operation.""" def __init__(self): super().__init__("relu") def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.relu(args[0]) def eagerxpr(self, args: list[Array], output: Array) -> None: np_result = np.maximum(0, args[0].to_numpy()) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .binary import div, mul # ReLU derivative: 1 if x > 0, 0 if x <= 0 # Since output = max(0, x), we have: # - If x > 0: output = x, so derivative = 1 # - If x <= 0: output = 0, so derivative = 0 x = primals[0] # Use the fact that for ReLU: # - When x > 0: output = x, so output/x = 1 (derivative should be 1) # - When x <= 0: output = 0, so output/x = 0 (derivative should be 0) # Add small epsilon to avoid division by zero eps = 1e-12 x_abs = abs(x) # This should work since we fixed abs x_safe = x_abs + eps # Always positive, so x_safe = |x| + eps # For x > 0: output = x, x_safe = x + eps ≈ x, so output/x_safe ≈ 1 # For x <= 0: output = 0, x_safe = |x| + eps > 0, so output/x_safe = 0 derivative = div(output, x_safe) return [mul(cotangent, derivative)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .binary import div, mul # ReLU derivative: 1 if x > 0, 0 if x <= 0 # Use same approach as VJP: output / (|x| + eps) x = primals[0] # Add small epsilon to avoid division by zero eps = 1e-12 x_abs = abs(x) # This should work since we fixed abs x_safe = x_abs + eps # Always positive, so x_safe = |x| + eps # For x > 0: output = x, x_safe = x + eps ≈ x, so output/x_safe ≈ 1 # For x <= 0: output = 0, x_safe = |x| + eps > 0, so output/x_safe = 0 derivative = div(output, x_safe) return mul(tangents[0], derivative)
[docs] def relu(arg: Array) -> Array: """Element-wise ReLU (Rectified Linear Unit) function.""" return _relu_op.forward(arg)
class LogOp(UnaryOperation): """Element-wise natural logarithm operation.""" def __init__(self): super().__init__("log") def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.log(args[0]) def eagerxpr(self, args: list[Array], output: Array) -> None: input_array = args[0].to_numpy() epsilon = 1e-15 safe_input = np.maximum(input_array, epsilon) np_result = np.log(safe_input) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .binary import div return [div(cotangent, primals[0])] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .binary import div return div(tangents[0], primals[0])
[docs] def log(arg: Array) -> Array: """Element-wise natural logarithm.""" return _log_op.forward(arg)
class ExpOp(UnaryOperation): """Element-wise exponential operation.""" def __init__(self): super().__init__("exp") def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.exp(args[0]) def eagerxpr(self, args: list[Array], output: Array) -> None: np_result = np.exp(args[0].to_numpy()) # Ensure result is an array, not a scalar if np.isscalar(np_result): np_result = np.array(np_result) output.impl_(np_result) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: from .binary import mul # d/dx exp(x) = exp(x), and output = exp(x) return [mul(cotangent, output)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: from .binary import mul # d/dx exp(x) = exp(x) return mul(output, tangents[0])
[docs] def exp(arg: Array) -> Array: """Element-wise exponential function.""" return _exp_op.forward(arg)
[docs] def sqrt(arg: Array) -> Array: """Element-wise square root function. Implemented as pow(arg, 0.5) for compatibility with the automatic differentiation system. """ from .binary import pow as binary_pow from .creation import array # Create 0.5 as a scalar Array half = array(0.5, dtype=arg.dtype) return binary_pow(arg, half)
# class ZerosLikeOp(UnaryOperation): # """Create an array of zeros with the same shape and dtype as the input.""" # def __init__(self): # super().__init__("zeros_like") # def compute_output_shape(self, *input_shapes: tuple) -> tuple: # """Output shape is the same as input shape.""" # if len(input_shapes) != 1: # raise ValueError(f"ZerosLikeOp requires 1 input shape, got {len(input_shapes)}") # return input_shapes[0] # def compute_output_dtype(self, arg: Array) -> DType: # """Output dtype is the same as input dtype.""" # return arg.dtype # def maxpr(self, args: list[TensorValue], output: Array) -> None: # output.tensor_value = ops.shape_to_tensor(args[0].shape) # def eagerxpr(self, args: list[Array], output: Array) -> None: # np_result = np.zeros(args[0].shape, dtype=args[0].dtype.to_numpy()) # output.impl_(np_result) # def vjp_rule( # self, primals: list[Array], cotangent: Array, output: Array # ) -> list[Array]: # return [zeros_like(cotangent)] # def jvp_rule( # self, primals: list[Array], tangents: list[Array], output: Array # ) -> Array: # return zeros_like(tangents[0]) class TransferToOp(UnaryOperation): """Transfer operation to a different device.""" def __init__(self, arg_device: Device, target_device: Device): super().__init__(f"transfer_to[{target_device}]") self.arg_device = arg_device self.target_device = target_device def forward(self, *args: Array) -> Array: """Forward pass for unary operations.""" if len(args) != 1: raise ValueError(f"Unary operation requires 1 argument, got {len(args)}") arg = args[0] output_shape = self.compute_output_shape(arg.shape) output_batch_dims = self.compute_output_batch_dims(arg.batch_dims) output_dtype = self.compute_output_dtype(arg) res = Array( shape=output_shape, dtype=output_dtype, device=self.target_device, materialize=False, name=self.name, batch_dims=output_batch_dims, ) res.set_maxpr(self.maxpr) res.add_arguments(arg) res.vjp_rule = self.vjp_rule res.jvp_rule = self.jvp_rule res.custom_kernel_path = self.custom_kernel_path() if not res.stage_realization: self.eagerxpr([arg], res) return res def maxpr(self, args: list[TensorValue], output: Array) -> None: output.tensor_value = ops.transfer_to( args[0], DeviceRef.from_device(self.target_device) ) def eagerxpr(self, args: list[Array], output: Array) -> None: output.impl_(args[0]._impl) def vjp_rule( self, primals: list[Array], cotangent: Array, output: Array ) -> list[Array]: return [transfer_to(cotangent, self.arg_device)] def jvp_rule( self, primals: list[Array], tangents: list[Array], output: Array ) -> Array: return transfer_to(tangents[0], self.target_device)
[docs] def transfer_to(arg: Array, device: Device) -> Array: """Transfer an array to a different device.""" if not isinstance(device, Device): raise TypeError(f"Device must be an instance of Device, got {type(device)}") # if arg.device.id == device.id: # return arg return TransferToOp(arg.device, device).forward(arg)
# Add global instances _negate_op = NegateOp() _sin_op = SinOp() _cos_op = CosOp() _tanh_op = TanhOp() _abs_op = AbsOp() _floor_op = FloorOp() _logical_not_op = LogicalNotOp() _sigmoid_op = SigmoidOp() _log_op = LogOp() _exp_op = ExpOp() _relu_op = ReLUOp()