permute_batch_dims#

Signature#

nabla.permute_batch_dims(input_array: Array, axes: tuple[int, ...]) -> Array

Description#

Permute (reorder) the batch dimensions of an array.

This operation reorders the batch_dims of an Array according to the given axes, similar to how regular permute works on shape dimensions. The shape dimensions remain unchanged.

Parameters#

  • input_array (Array): Input array with batch dimensions to permute.

  • axes (tuple[int, ...]): Tuple specifying the new order of batch dimensions. All indices should be negative and form a permutation.

Returns#

  • Array with batch dimensions reordered according to axes.

Examples#

>>> import nabla as nb
>>> # Array with batch_dims=(2, 3, 4) and shape=(5, 6)
>>> x = nb.ones((5, 6))
>>> x.batch_dims = (2, 3, 4)  # Simulated for example
>>> y = nb.permute_batch_dims(x, (-1, -3, -2))  # Reorder as (4, 2, 3)
>>> # Result has batch_dims=(4, 2, 3) and shape=(5, 6)
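
The rule that axes must contain only negative indices and form a permutation can be illustrated without nabla. The sketch below is plain Python operating on a tuple of batch sizes. The helper name permuted_batch_dims is hypothetical and is not part of the nabla API, and the reading that entry i of axes names the old batch axis placed at position i is an assumption inferred from the example above, not a statement about nabla's implementation.

```python
def permuted_batch_dims(batch_dims, axes):
    """Return batch_dims reordered by axes.

    Hypothetical helper for illustration only; not part of nabla.
    """
    n = len(batch_dims)
    # Assumed rule: axes must be a permutation of the negative indices -n..-1.
    if sorted(axes) != list(range(-n, 0)):
        raise ValueError(
            f"axes must be a permutation of {tuple(range(-n, 0))}, got {tuple(axes)}"
        )
    # Entry i of axes names the old batch axis that lands at position i.
    return tuple(batch_dims[a] for a in axes)


print(permuted_batch_dims((2, 3, 4), (-1, -3, -2)))  # (4, 2, 3)
```

Under these assumptions, a non-negative index such as 0 or a repeated index such as (-1, -1, -2) raises ValueError instead of silently producing a wrong ordering.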