permute_batch_dims
Signature
nabla.permute_batch_dims(input_array: 'Array', axes: 'tuple[int, ...]') -> 'Array'
Description
Permute (reorder) the batch dimensions of an array.
This operation reorders the batch_dims of an Array according to axes, analogous to how the regular permute reorders shape dimensions: position i of the result holds the input batch dimension that axes[i] refers to. The shape dimensions remain unchanged.
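As an illustration of these semantics, the sketch below models an array's batch dimensions as the leading axes of a plain NumPy array, followed by its regular shape dimensions. That storage layout and the helper name permute_batch_dims_ref are assumptions made for this example, not nabla's implementation; the point is only that permuting batch dims amounts to a transpose that leaves the trailing shape axes in place.

```python
import numpy as np


def permute_batch_dims_ref(data: np.ndarray, n_batch: int, axes: tuple[int, ...]) -> np.ndarray:
    """Hypothetical reference model: permute the leading n_batch axes of data.

    Assumes batch dims are stored as the leading axes of data, followed by
    the regular shape dims. This is NOT nabla's implementation.
    """
    if len(axes) != n_batch:
        raise ValueError("axes must name every batch dimension exactly once")
    # A negative batch index -k refers to physical axis n_batch - k.
    batch_perm = [n_batch + a for a in axes]
    # The shape dims keep their original order after the batch dims.
    shape_axes = list(range(n_batch, data.ndim))
    return np.transpose(data, batch_perm + shape_axes)


# Mirror the docstring example: batch_dims=(2, 3, 4), shape=(5, 6).
x = np.ones((2, 3, 4, 5, 6))
y = permute_batch_dims_ref(x, n_batch=3, axes=(-1, -3, -2))
print(y.shape)  # (4, 2, 3, 5, 6)
```

Using np.transpose keeps the model faithful to the "regular permute" analogy: both take the full new axis order, and here only the batch portion of that order is changed.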
Parameters
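input_array : Array
    Input array with batch dimensions to permute.
axes : tuple[int, ...]
    Tuple specifying the new order of batch dimensions. All indices should be negative and together form a permutation of the batch axes, i.e. of (-n, ..., -1) where n = len(input_array.batch_dims).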
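The constraints on axes can be checked mechanically. The following is a minimal sketch of such a check, assuming only the two rules stated above (every index negative, and the indices form a permutation of the batch axes); the helper name normalize_batch_axes is hypothetical and not part of nabla's API. It also converts the negative indices to the equivalent non-negative batch positions.

```python
def normalize_batch_axes(axes: tuple[int, ...], n_batch: int) -> tuple[int, ...]:
    """Hypothetical helper: validate axes against the documented rules and
    return the equivalent non-negative batch-axis positions."""
    # Rule 1: every index must be negative.
    if any(a >= 0 for a in axes):
        raise ValueError(f"all batch axes must be negative, got {axes}")
    # Rule 2: the indices must be exactly -n_batch, ..., -1 in some order.
    if sorted(axes) != list(range(-n_batch, 0)):
        raise ValueError(f"{axes} is not a permutation of the {n_batch} batch axes")
    # -1 is the innermost batch dim, -n_batch the outermost.
    return tuple(n_batch + a for a in axes)


print(normalize_batch_axes((-1, -3, -2), n_batch=3))  # (2, 0, 1)
```

Under these rules, (0, -1, -2) is rejected because it contains a non-negative index, and (-1, -1, -2) is rejected because it repeats an axis.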
Returns
Array
    Array with batch dimensions reordered according to axes; its shape is unchanged.
Examples
>>> import nabla as nb
>>> # Array with batch_dims=(2, 3, 4) and shape=(5, 6)
>>> x = nb.ones((5, 6))
>>> x.batch_dims = (2, 3, 4) # Simulated for example
>>> y = nb.permute_batch_dims(x, (-1, -3, -2))  # Reorder as (4, 2, 3)
>>> # Result has batch_dims=(4, 2, 3) and shape=(5, 6)