Reduce

Reduction operations, which aggregate array elements along one or more axes.

nabla.ops.reduce.sum(arg, axes=None, keep_dims=False)

Sum array elements over the given axes.
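The role of `axes` and `keep_dims` can be illustrated with a small pure-Python sketch on a 2-D nested list. It assumes NumPy-style conventions, which this page does not spell out: `axes=None` reduces every axis, and `keep_dims=True` keeps each reduced axis with length 1. The helper `sum_2d` is hypothetical and is not part of nabla.

```python
# Hypothetical reference sketch (not nabla's implementation) of
# sum(arg, axes, keep_dims) on a 2-D nested list, assuming NumPy-style
# conventions: axes=None reduces everything, and keep_dims keeps reduced
# axes as length-1 axes.
def sum_2d(rows, axes=None, keep_dims=False):
    if axes is None:
        # Reduce over both axes to a single scalar.
        total = sum(sum(r) for r in rows)
        return [[total]] if keep_dims else total
    if axes == 0:
        # Collapse the row axis: one sum per column.
        out = [sum(col) for col in zip(*rows)]
        return [out] if keep_dims else out
    if axes == 1:
        # Collapse the column axis: one sum per row.
        out = [sum(r) for r in rows]
        return [[v] for v in out] if keep_dims else out
    raise ValueError("this sketch only supports axes in {None, 0, 1}")


x = [[1, 2, 3],
     [4, 5, 6]]
print(sum_2d(x))                          # 21
print(sum_2d(x, axes=0))                  # [5, 7, 9]
print(sum_2d(x, axes=1, keep_dims=True))  # [[6], [15]]
```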

nabla.ops.reduce.sum_batch_dims(arg, axes=None, keep_dims=False)

Sum array elements over the given batch-dimension axes.
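The distinction from `sum` can be sketched under an assumption: a batched value is a stack of per-example arrays, and summing over the batch dimension leaves the per-example shape intact. This pure-Python model treats the outer list index as the batch axis. It is a guess at the semantics suggested by the name, not nabla's internal representation of batch dimensions. The helper `sum_over_batch` is hypothetical.

```python
# Hypothetical sketch: a "batched" value modeled as a list of per-example
# vectors, where the outer list index is the batch dimension. This models the
# assumed semantics of sum_batch_dims, not nabla's internal representation.
def sum_over_batch(batch, keep_dims=False):
    # Element-wise sum across the batch axis; the per-example shape survives.
    out = [sum(vals) for vals in zip(*batch)]
    return [out] if keep_dims else out


per_example = [[0.5, 1.0],
               [1.5, -1.0],
               [2.0, 0.0]]
print(sum_over_batch(per_example))                  # [4.0, 0.0]
print(sum_over_batch(per_example, keep_dims=True))  # [[4.0, 0.0]]
```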

nabla.ops.reduce.mean(arg, axes=None, keep_dims=False)

Compute the mean of array elements over the given axes.
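A mean over an axis is the sum over that axis divided by the number of elements along it. The pure-Python sketch below shows this on a 2-D nested list. It assumes the same NumPy-style `axes`/`keep_dims` conventions as `sum`, and `mean_2d` is a hypothetical helper, not nabla's implementation.

```python
# Hypothetical sketch of mean semantics on a 2-D nested list: the sum over the
# reduced axis divided by the length of that axis (NumPy-style conventions
# assumed).
def mean_2d(rows, axes=None, keep_dims=False):
    if axes is None:
        # Average over every element.
        n = len(rows) * len(rows[0])
        m = sum(sum(r) for r in rows) / n
        return [[m]] if keep_dims else m
    if axes == 0:
        # One mean per column.
        out = [sum(col) / len(rows) for col in zip(*rows)]
        return [out] if keep_dims else out
    if axes == 1:
        # One mean per row.
        out = [sum(r) / len(r) for r in rows]
        return [[v] for v in out] if keep_dims else out
    raise ValueError("this sketch only supports axes in {None, 0, 1}")


x = [[1, 2, 3],
     [4, 5, 6]]
print(mean_2d(x))           # 3.5
print(mean_2d(x, axes=0))   # [2.5, 3.5, 4.5]
print(mean_2d(x, axes=1))   # [2.0, 5.0]
```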

nabla.ops.reduce.max(arg, axes=None, keep_dims=False)

Compute the maximum of array elements over the given axes.
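The same axis conventions apply to `max`, which returns the largest values rather than their positions. This pure-Python sketch on a 2-D nested list assumes NumPy-style `axes`/`keep_dims` behavior. The helper `max_2d` is hypothetical.

```python
# Hypothetical sketch of max semantics on a 2-D nested list
# (NumPy-style axes/keep_dims conventions assumed; not nabla's code).
def max_2d(rows, axes=None, keep_dims=False):
    if axes is None:
        # Largest element of the whole array.
        m = max(max(r) for r in rows)
        return [[m]] if keep_dims else m
    if axes == 0:
        # Largest element in each column.
        out = [max(col) for col in zip(*rows)]
        return [out] if keep_dims else out
    if axes == 1:
        # Largest element in each row.
        out = [max(r) for r in rows]
        return [[v] for v in out] if keep_dims else out
    raise ValueError("this sketch only supports axes in {None, 0, 1}")


x = [[3, 9, 2],
     [7, 1, 8]]
print(max_2d(x))                          # 9
print(max_2d(x, axes=0))                  # [7, 9, 8]
print(max_2d(x, axes=1, keep_dims=True))  # [[9], [8]]
```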

nabla.ops.reduce.argmax(arg, axes=None, keep_dims=False)

Find the indices of the maximum array elements over a given axis, matching JAX’s API.
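In JAX (following NumPy), `jax.numpy.argmax` returns the index of the first occurrence of the maximum. When no axis is given, it indexes into the flattened array. The pure-Python sketch below models those semantics on a 2-D nested list, assuming nabla's `argmax` behaves the same way, as the description suggests. The helper `argmax_2d` is hypothetical.

```python
# Hypothetical sketch of JAX/NumPy-style argmax semantics on a 2-D nested
# list: ties resolve to the first occurrence, and axes=None indexes into the
# row-major flattened array. Not nabla's implementation.
def argmax_2d(rows, axes=None, keep_dims=False):
    def first_argmax(values):
        # max() returns the first maximal item it encounters, so ties
        # resolve to the lowest index.
        values = list(values)
        return max(range(len(values)), key=values.__getitem__)

    if axes is None:
        # Flatten row-major, then take the index of the maximum.
        idx = first_argmax(v for r in rows for v in r)
        return [[idx]] if keep_dims else idx
    if axes == 0:
        # Row index of the maximum in each column.
        out = [first_argmax(col) for col in zip(*rows)]
        return [out] if keep_dims else out
    if axes == 1:
        # Column index of the maximum in each row.
        out = [first_argmax(r) for r in rows]
        return [[i] for i in out] if keep_dims else out
    raise ValueError("this sketch only supports axes in {None, 0, 1}")


x = [[3, 9, 2],
     [7, 9, 8]]
print(argmax_2d(x))           # 1
print(argmax_2d(x, axes=0))   # [1, 0, 1]
print(argmax_2d(x, axes=1))   # [1, 1]
```

Column 1 holds 9 in both rows, so the tie resolves to row index 0.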