Operation#

Base operation classes for a clean OOP design.

class nabla.ops.operation.Operation(name)[source]#

Bases: ABC

Abstract base class for all operations.

__init__(name)[source]#

abstract forward(*args)[source]#

Forward pass: creates the result Array.

abstract compute_output_shape(*input_shapes)[source]#

Compute the output shape given input shapes.

abstract maxpr(args, output)[source]#

MAX graph computation.

abstract eagerxpr(args, output)[source]#

Eager computation using NumPy.

abstract vjp_rule(primals, cotangent, output)[source]#

Vector-Jacobian product rule for reverse-mode autodiff.

abstract jvp_rule(primals, tangents, output)[source]#

Jacobian-vector product rule for forward-mode autodiff.

custom_kernel_path()[source]#

Optional: path to custom kernel implementation.

class nabla.ops.operation.UnaryOperation(name)[source]#

Bases: Operation

Base class for unary operations.

forward(*args)[source]#

Forward pass for unary operations.

compute_output_shape(*input_shapes)[source]#

Default: output shape same as input shape.

compute_output_dtype(arg)[source]#

Default: output dtype same as input dtype.

compute_output_batch_dims(input_batch_dims)[source]#

Default: output batch dims same as input batch dims.

nabla.ops.operation.move_to_best_device(*args)[source]#

Move all arrays to the best available device.

class nabla.ops.operation.BinaryOperation(name)[source]#

Bases: Operation

Base class for binary operations.

forward(*args)[source]#

Forward pass for binary operations.

compute_output_shape(*input_shapes)[source]#

Compute broadcasted output shape.

compute_output_dtype(arg1, arg2)[source]#

Default: output dtype same as first input dtype.

compute_output_batch_dims(*input_batch_dims)[source]#

Default: output batch dims same as input batch dims.

class nabla.ops.operation.ReductionOperation(name, axes=None, keep_dims=False)[source]#

Bases: UnaryOperation

Base class for reduction operations.

__init__(name, axes=None, keep_dims=False)[source]#

compute_output_shape(*input_shapes)[source]#

Compute output shape for reduction.

compute_output_batch_dims(*input_batch_dims)[source]#

Compute output batch dims for reduction.

class nabla.ops.operation.ViewOperation(name)[source]#

Bases: UnaryOperation

Base class for view operations (reshape, transpose, etc.).

__init__(name)[source]#

compute_output_batch_dims(*input_batch_dims)[source]#

Default: output batch dims same as input batch dims.