Operation
Base operation classes for a clean OOP design.
-
class nabla.ops.operation.Operation(name)[source]
Bases: ABC
Abstract base class for all operations.
-
__init__(name)[source]
-
abstract forward(*args)[source]
Forward pass — creates the result Array.
-
abstract compute_output_shape(*input_shapes)[source]
Compute the output shape given input shapes.
-
abstract maxpr(args, output)[source]
MAX graph computation.
-
abstract eagerxpr(args, output)[source]
Eager computation using NumPy.
-
abstract vjp_rule(primals, cotangent, output)[source]
Vector-Jacobian product rule for reverse-mode autodiff.
-
abstract jvp_rule(primals, tangents, output)[source]
Jacobian-vector product rule for forward-mode autodiff.
-
custom_kernel_path()[source]
Optional: path to a custom kernel implementation.
-
class nabla.ops.operation.UnaryOperation(name)[source]
Bases: Operation
Base class for unary operations.
-
forward(*args)[source]
Forward pass for unary operations.
-
compute_output_shape(*input_shapes)[source]
Default: output shape same as input shape.
-
compute_output_dtype(arg)[source]
Default: output dtype same as input dtype.
-
compute_output_batch_dims(input_batch_dims)[source]
Default: output batch dims same as input batch dims.
-
nabla.ops.operation.move_to_best_device(*args)[source]
Move all arrays to the best available device.
-
class nabla.ops.operation.BinaryOperation(name)[source]
Bases: Operation
Base class for binary operations.
-
forward(*args)[source]
Forward pass for binary operations.
-
compute_output_shape(*input_shapes)[source]
Compute the broadcast output shape.
-
compute_output_dtype(arg1, arg2)[source]
Default: output dtype is the same as the first input's dtype.
-
compute_output_batch_dims(*input_batch_dims)[source]
Default: output batch dims same as input batch dims.
-
class nabla.ops.operation.ReductionOperation(name, axes=None, keep_dims=False)[source]
Bases: UnaryOperation
Base class for reduction operations.
-
__init__(name, axes=None, keep_dims=False)[source]
-
compute_output_shape(*input_shapes)[source]
Compute output shape for reduction.
-
compute_output_batch_dims(*input_batch_dims)[source]
Compute output batch dims for reduction.
-
class nabla.ops.operation.ViewOperation(name)[source]
Bases: UnaryOperation
Base class for view operations (reshape, transpose, etc.).
-
__init__(name)[source]
-
compute_output_batch_dims(*input_batch_dims)[source]
Default: output batch dims same as input batch dims.