argmax#
Signature#
nabla.argmax(arg: 'Array', axes: 'int | None' = None, keep_dims: 'bool' = False) -> 'Array'
Description#
Finds the indices of maximum array elements over a given axis.
If the maximum value occurs more than once along the axis, the index of its first occurrence is returned.
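The first-occurrence rule can be illustrated with a minimal pure-Python sketch. The helper `argmax_1d` is hypothetical and is not part of nabla; it only mirrors the tie-breaking behavior described above for a one-dimensional sequence.

```python
def argmax_1d(values):
    """Return the index of the first maximum in a non-empty sequence.

    Hypothetical helper for illustration only; not part of nabla.
    """
    if not values:
        raise ValueError("argmax of an empty sequence is undefined")
    best_index = 0
    for index in range(1, len(values)):
        # Strict '>' keeps the earliest index when values tie.
        if values[index] > values[best_index]:
            best_index = index
    return best_index


print(argmax_1d([1, 5, 2, 5]))  # 1: the second 5, at index 3, does not replace the first
```

Using a strict comparison is what makes the earliest index win; a `>=` comparison would return the last occurrence instead.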
Parameters#
arg (Array): The input array.
axes (int | None, optional): The axis along which to find the indices of the maximum values. If None (the default), the array is flattened before finding the index of the overall maximum value.
keep_dims (bool, optional): If True, the reduced axis is left in the result as a dimension with size one. This is not supported when axes is None. Defaults to False.
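To make the interaction of the axes and keep_dims parameters concrete, here is a minimal pure-Python sketch for a two-dimensional nested list. The helper `argmax_2d` is hypothetical and not part of nabla. It assumes row-major flattening when axes is None, and it raises a ValueError for the unsupported combination of keep_dims=True with axes=None; how nabla itself reports that case is not stated here.

```python
def argmax_2d(rows, axes=None, keep_dims=False):
    """Reference argmax for a non-empty rectangular list of lists.

    Hypothetical helper for illustration only; not part of nabla.
    """
    def first_max(values):
        # max() returns the first maximal item, so ties resolve to the lowest index.
        return max(range(len(values)), key=values.__getitem__)

    if axes is None:
        if keep_dims:
            # Assumption: the unsupported combination is rejected with an error.
            raise ValueError("keep_dims is not supported when axes is None")
        # Assumption: flatten in row-major order, then search the whole array.
        flat = [value for row in rows for value in row]
        return first_max(flat)
    if axes == 0:
        # Reduce over rows: one index per column.
        columns = [list(column) for column in zip(*rows)]
        result = [first_max(column) for column in columns]
        # keep_dims restores the reduced axis as a leading size-one dimension.
        return [result] if keep_dims else result
    if axes == 1:
        # Reduce over columns: one index per row.
        result = [first_max(row) for row in rows]
        return [[index] for index in result] if keep_dims else result
    raise ValueError("axes must be None, 0, or 1 for a 2-D input")


y = [[1, 5, 2], [4, 3, 6]]
print(argmax_2d(y))                          # 5
print(argmax_2d(y, axes=1))                  # [1, 2]
print(argmax_2d(y, axes=0, keep_dims=True))  # [[1, 0, 1]]
```

With axes=None the result is a single index into the flattened data, which is why there is no reduced axis to keep.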
Returns#
Array: An array of int64 integers containing the indices of the maximum elements.
Examples#
>>> import nabla as nb
>>> x = nb.array([1, 5, 2, 5])
>>> nb.argmax(x)
Array(1, dtype=int64)
>>> y = nb.array([[1, 5, 2], [4, 3, 6]])
>>> nb.argmax(y, axes=1)
Array([1, 2], dtype=int64)
>>> nb.argmax(y, axes=0, keep_dims=True)
Array([[1, 0, 1]], dtype=int64)