gather
Signature
nabla.gather(input_array: 'Array', indices: 'Array', axis: 'int') -> 'Array'
Description
Selects elements from an input array using indices along a specified axis.
This function is analogous to numpy.take_along_axis: it selects elements from input_array at the positions specified by indices.
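Written index-wise, the selection rule can be stated as a formula. This assumes nabla.gather follows numpy.take_along_axis exactly and ignores NumPy's broadcasting of non-axis dimensions. The symbols are introduced here for illustration only: x is input_array, I is indices, n is the number of dimensions, and k is the axis after normalization (k + n when k is negative).

```latex
\begin{aligned}
\operatorname{out}[i_0, \dots, i_{n-1}] &= x\bigl[i_0, \dots, i_{k-1},\; I[i_0, \dots, i_{n-1}],\; i_{k+1}, \dots, i_{n-1}\bigr] \\
k = 1:\quad \operatorname{out}[i, j] &= x\bigl[i,\; I[i, j]\bigr] \\
k = 0:\quad \operatorname{out}[i, j] &= x\bigl[I[i, j],\; j\bigr]
\end{aligned}
```

Under this assumption, only the coordinate along the chosen axis comes from I. The output therefore has the same shape as I.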
Parameters
input_array (Array): The source array from which to gather values.
indices (Array): The array of indices to gather. Must be an integer-typed array.
axis (int, optional): The axis along which to gather. A negative value counts from the last dimension. Defaults to -1.
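The parameter rules above can be modeled without nabla in a short plain-Python sketch on nested lists. In the sketch, axis defaults to -1, a negative axis is counted from the last dimension, and indices must hold integers. The names gather_reference, _shape, _get, and _build are hypothetical and are not part of nabla. The sketch assumes numpy.take_along_axis semantics without broadcasting: indices must have the same number of dimensions as input_array and must not be longer than it in any non-axis dimension.

```python
# Hypothetical pure-Python model of nabla.gather, assuming it matches
# numpy.take_along_axis without broadcasting. Not part of nabla.


def _shape(a):
    """Shape of a non-empty, rectangular nested list (scalars have shape ())."""
    shape = []
    while isinstance(a, list):
        shape.append(len(a))
        a = a[0]  # assumes every nesting level is non-empty
    return tuple(shape)


def _get(a, pos):
    """Read the element of nested list `a` at index sequence `pos`."""
    for p in pos:
        a = a[p]
    return a


def _build(shape, fn, prefix=()):
    """Build a nested list of `shape` whose entry at index tuple `pos` is fn(pos)."""
    if len(prefix) == len(shape):
        return fn(prefix)
    return [_build(shape, fn, prefix + (i,)) for i in range(shape[len(prefix)])]


def gather_reference(input_array, indices, axis=-1):
    """For every position pos of `indices`, read input_array with pos[axis]
    replaced by indices[pos]."""
    ndim = len(_shape(input_array))
    idx_shape = _shape(indices)
    if len(idx_shape) != ndim:
        raise ValueError("indices must have the same number of dimensions as input_array")
    if axis < 0:
        axis += ndim  # a negative axis counts from the last dimension
    if not 0 <= axis < ndim:
        raise ValueError("axis is out of range")

    def pick(pos):
        i = _get(indices, pos)
        if not isinstance(i, int):
            raise TypeError("indices must contain integers")
        src = list(pos)
        src[axis] = i  # only the coordinate along `axis` comes from indices
        return _get(input_array, src)

    # The result has the shape of `indices`.
    return _build(idx_shape, pick)


x = [[10, 20, 30], [40, 50, 60]]
print(gather_reference(x, [[0, 2], [1, 0]], axis=1))  # [[10, 30], [50, 40]]
print(gather_reference(x, [[0, 1, 0]], axis=0))       # [[10, 50, 30]]
```

Both calls reproduce the outputs listed under Examples. Adding NumPy-style broadcasting would mean clamping a non-axis coordinate to 0 wherever indices has size 1 in that dimension and then computing the output shape from both arrays. The sketch leaves this out to stay short.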
Returns
Array: A new array containing the elements of input_array at the given indices.
Examples
>>> import nabla as nb
>>> x = nb.array([[10, 20, 30], [40, 50, 60]])
>>> indices = nb.array([[0, 2], [1, 0]])
>>> # Gather along axis 1
>>> nb.gather(x, indices, axis=1)
Array([[10, 30],
       [50, 40]], dtype=int32)
>>> # Gather along axis 0
>>> indices = nb.array([[0, 1, 0]])
>>> nb.gather(x, indices, axis=0)
Array([[10, 50, 30]], dtype=int32)