argmax

Signature

nabla.argmax(arg: 'Array', axes: 'int | None' = None, keep_dims: 'bool' = False) -> 'Array'

Description

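Returns the indices of the maximum elements of arg along the axis given by axes, matching the API of JAX’s jax.numpy.argmax. When keep_dims is True, the reduced axis is kept in the result with length 1.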
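The following pure-Python sketch spells out what "matching JAX's API" is assumed to mean for a 2-D input. It assumes that nabla.argmax follows the semantics of jax.numpy.argmax: with axes=None the input is flattened in row-major order, ties resolve to the first occurrence, negative axes count from the end, and keep_dims=True keeps the reduced axis with length 1. argmax_2d is a hypothetical helper for illustration only. It is not part of nabla and does not describe how nabla implements argmax.

```python
from typing import List, Optional, Union

Matrix = List[List[float]]


def argmax_2d(
    x: Matrix, axes: Optional[int] = None, keep_dims: bool = False
) -> Union[int, List[int], List[List[int]]]:
    """Hypothetical reference for argmax semantics on a 2-D nested list."""
    rows, cols = len(x), len(x[0])

    if axes is None:
        # Flatten in row-major order. Python's max() returns the first
        # maximal item, so ties resolve to the lowest flat index.
        flat = [v for row in x for v in row]
        idx = max(range(len(flat)), key=lambda i: flat[i])
        # With keep_dims, both reduced axes survive with length 1: shape (1, 1).
        return [[idx]] if keep_dims else idx

    # Normalize negative axes (-1 -> 1, -2 -> 0).
    if axes < 0:
        axes += 2

    if axes == 0:
        # Reduce over rows: one row index per column, result shape (cols,).
        out = [max(range(rows), key=lambda r: x[r][c]) for c in range(cols)]
        return [out] if keep_dims else out  # keep_dims -> shape (1, cols)
    if axes == 1:
        # Reduce over columns: one column index per row, result shape (rows,).
        out = [max(range(cols), key=lambda c: x[r][c]) for r in range(rows)]
        return [[i] for i in out] if keep_dims else out  # shape (rows, 1)
    raise ValueError(f"axes out of range for a 2-D input: {axes}")


x = [[1, 5, 2],
     [7, 0, 7]]
print(argmax_2d(x))                        # 3 (first 7 in flattened order)
print(argmax_2d(x, axes=0))                # [1, 0, 1]
print(argmax_2d(x, axes=1))                # [1, 0] (tie in row 1 -> index 0)
print(argmax_2d(x, axes=1, keep_dims=True))  # [[1], [0]]
```

Under this reading, the result has the dtype of an index (an integer), not the dtype of arg, and its shape is the input shape with the reduced axis either dropped or kept as length 1.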