Tensor#

The core n-dimensional array data structure in Nabla.

Tensor#

class Tensor(shape: 'Shape', dtype: 'DType' = float32, device: 'Device' = Device(type=cpu,id=0), materialize: 'bool' = False, name: 'str' = '', batch_dims: 'Shape' = ()) -> 'None':

Core tensor class with automatic differentiation support.

Methods#

add_arguments#

def add_arguments(self, *arg_nodes: 'Tensor') -> 'None':

Add argument nodes to this Tensor’s computation graph if it is traced.

astype#

def astype(self, dtype: 'DType') -> 'Tensor':

Convert tensor to a different data type.

Parameters

  • dtype – Target data type

Returns

– New Tensor with the specified data type

at#

def at(self, key, value):

Update the tensor at the specified indices/slices, returning a new tensor.

backward#

def backward(self, grad: 'Tensor | None' = None, retain_graph: 'bool | None' = None, show_graph: 'bool' = False) -> 'None':

Compute gradients flowing into traced leaf inputs that influence this Tensor.

Parameters

  • grad – Optional cotangent tensor; defaults to ones for scalar outputs

  • retain_graph – If False, frees the computation graph. If True, it’s retained. If None (default), it’s retained only if inside a trace.

  • show_graph – If True, prints the compiled graph during backward pass
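The three-way default of retain_graph can be read as a small decision rule. The sketch below is only an illustration of that rule as documented above; `resolve_retain_graph` and its `inside_trace` argument are hypothetical names and are not part of Nabla.

```python
def resolve_retain_graph(retain_graph, inside_trace):
    """Hypothetical helper: mirror the documented retain_graph default.

    None  -> retain the graph only when running inside a trace
    True  -> always retain
    False -> always free
    """
    if retain_graph is None:
        return inside_trace
    return bool(retain_graph)


# Enumerate every combination of flag and tracing state.
for flag in (None, True, False):
    for inside_trace in (True, False):
        print(flag, inside_trace, resolve_retain_graph(flag, inside_trace))
```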

copy_from#

def copy_from(self, other: 'Tensor') -> 'None':

Copy data from another Tensor.

get_arguments#

def get_arguments(self) -> 'list[Tensor]':

Get list of argument Tensors.

impl_#

def impl_(self, value: 'Union[np.ndarray, MAXTensor] | None') -> 'None':

Set the implementation of this Tensor to a NumPy array or a MAXTensor.

item#

def item(self) -> 'Union[float, int, bool]':

Get the single value of a scalar Tensor as a standard Python type.

Raises an error if the Tensor is not a scalar.

permute#

def permute(self, axes: 'tuple[int, ...]') -> 'Tensor':

Permute the dimensions of the tensor.

Parameters

  • axes – Tuple of integers specifying the new order of dimensions

Returns

– Tensor with dimensions permuted according to the specified axes
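To make the effect of axes concrete, here is a minimal pure-Python sketch of the resulting shape. It assumes the NumPy convention, in which output dimension i takes its size from input dimension axes[i]; `permuted_shape` is a hypothetical helper and does not call Nabla.

```python
def permuted_shape(shape, axes):
    """Hypothetical helper: shape after permuting dimensions by `axes`.

    Assumes output dimension i is input dimension axes[i].
    """
    if sorted(axes) != list(range(len(shape))):
        raise ValueError(f"axes {axes} is not a permutation of the dimensions")
    return tuple(shape[a] for a in axes)


print(permuted_shape((2, 3, 4), (2, 0, 1)))  # (4, 2, 3)
```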

realize#

def realize(self) -> 'None':

Force computation of this Tensor.

requires_grad_#

def requires_grad_(self, val: 'bool' = True) -> 'Tensor':

Opt into or out of gradient tracking for imperative workflows.

This is an in-place operation that returns self for method chaining. Similar to PyTorch’s requires_grad_() method.
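The "in-place and returns self" pattern can be sketched with a small stand-in class. `Param` and its `requires_grad` attribute are hypothetical and only illustrate the chaining behavior described above; they are not Nabla's Tensor.

```python
class Param:
    """Hypothetical stand-in used only to illustrate the pattern."""

    def __init__(self):
        self.requires_grad = False

    def requires_grad_(self, val=True):
        # Mutate the object in place, then return it so calls can be chained.
        self.requires_grad = val
        return self


p = Param()
same = p.requires_grad_()   # opt in
print(same is p)            # True: no new object is created
print(p.requires_grad_(False).requires_grad)  # False: opted out again
```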

reshape#

def reshape(self, shape: 'Shape') -> 'Tensor':

Change the shape of a tensor without changing its data.

Parameters

  • shape – New shape for the tensor

Returns

– Tensor with the new shape
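Because the data is unchanged, the old and new shapes must describe the same number of elements. The sketch below checks that invariant in plain Python; `check_reshape` is a hypothetical helper, and whether Nabla accepts an inferred dimension such as -1 is not stated here, so that case is not covered.

```python
from math import prod


def check_reshape(old_shape, new_shape):
    """Hypothetical helper: validate that a reshape preserves element count."""
    if prod(old_shape) != prod(new_shape):
        raise ValueError(
            f"cannot reshape {tuple(old_shape)} into {tuple(new_shape)}: "
            "element counts differ"
        )
    return tuple(new_shape)


print(check_reshape((2, 3), (3, 2)))   # (3, 2)
print(check_reshape((2, 3, 4), (6, 4)))  # (6, 4)
```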

set#

def set(self, key, value) -> 'Tensor':

Set values at specified indices/slices, returning a new tensor.

This is a functional operation that returns a new Tensor with the specified values updated, leaving the original Tensor unchanged.

Parameters

  • key – Index specification (int, slice, tuple of indices/slices, ellipsis)

  • value – Value(s) to set at the specified location

Returns

– New Tensor with updated values
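The functional-update semantics can be illustrated without Nabla. This sketch uses a plain Python list as a stand-in for a one-dimensional tensor; `functional_set` is a hypothetical helper and handles only int and slice keys, not tuples or ellipsis.

```python
import copy


def functional_set(data, key, value):
    """Hypothetical helper: return a copy of `data` with data[key] = value."""
    updated = copy.deepcopy(data)  # the original is never touched
    updated[key] = value
    return updated


original = [1, 2, 3, 4]
updated = functional_set(original, slice(1, 3), [20, 30])
print(original)  # [1, 2, 3, 4]
print(updated)   # [1, 20, 30, 4]
```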

set_maxpr#

def set_maxpr(self, fn: 'MaxprCallable') -> 'None':

Set the MAX PR function for this operation.

sum#

def sum(self, axes=None, keep_dims=False) -> 'Tensor':

Sum tensor elements over given axes.

Parameters

  • axes – Axis or axes along which to sum. Can be int, list of ints, or None (sum all)

  • keep_dims – If True, reduced axes are left as dimensions with size 1

Returns

– Tensor with the sum along the specified axes
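The interaction between axes and keep_dims is easiest to see in the output shape. The following pure-Python sketch computes that shape under the rules listed above; `reduced_shape` is a hypothetical helper and assumes non-negative axis indices.

```python
def reduced_shape(shape, axes=None, keep_dims=False):
    """Hypothetical helper: shape of a sum over `axes`."""
    if axes is None:
        axes = range(len(shape))   # None means reduce over every axis
    elif isinstance(axes, int):
        axes = [axes]              # a single int is treated as one axis
    reduced = set(axes)
    if keep_dims:
        # Reduced axes stay in place with size 1.
        return tuple(1 if i in reduced else d for i, d in enumerate(shape))
    # Reduced axes are dropped from the shape.
    return tuple(d for i, d in enumerate(shape) if i not in reduced)


print(reduced_shape((2, 3, 4)))                          # ()
print(reduced_shape((2, 3, 4), 1))                       # (2, 4)
print(reduced_shape((2, 3, 4), [0, 2], keep_dims=True))  # (1, 3, 1)
```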

to#

def to(self, device: 'Device') -> 'Tensor':

Move Tensor to specified device.

to_numpy#

def to_numpy(self) -> 'np.ndarray':

Get NumPy representation.

transpose#

def transpose(self, axes: 'tuple[int, ...]') -> 'Tensor':

Permute the dimensions of the tensor.

Parameters

  • axes – Tuple of integers specifying the new order of dimensions

Returns

– Tensor with dimensions permuted according to the specified axes
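A permutation of dimensions can be undone by applying its inverse permutation. The sketch below computes that inverse in plain Python, assuming the convention that output dimension i is input dimension axes[i]; both helpers are hypothetical and do not call Nabla.

```python
def apply_axes(shape, axes):
    """Hypothetical helper: shape after permuting dimensions by `axes`."""
    return tuple(shape[a] for a in axes)


def inverse_permutation(axes):
    """Hypothetical helper: axes that undo a permutation by `axes`."""
    inverse = [0] * len(axes)
    for new_pos, old_pos in enumerate(axes):
        inverse[old_pos] = new_pos
    return tuple(inverse)


axes = (2, 0, 1)
inv = inverse_permutation(axes)
print(inv)                                        # (1, 2, 0)
print(apply_axes(apply_axes((2, 3, 4), axes), inv))  # (2, 3, 4)
```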
