softmax#

Signature#

nabla.softmax(arg: 'Array', axis: 'int') -> 'Array'

Description#

Computes the softmax function for an array.

The softmax function transforms a vector of real numbers into a probability distribution. Each element in the output is in the range (0, 1), and the elements along the specified axis sum to 1. It is calculated in a numerically stable way as exp(x - logsumexp(x)).

Parameters#

  • arg (Array): The input array.

  • axis (int, optional): The axis along which the softmax computation is performed. The default is -1, which is the last axis.

Returns#

  • Array: An array of the same shape as the input, containing the softmax probabilities.

Examples#

>>> import nabla as nb
>>> x = nb.array([1.0, 2.0, 3.0])
>>> nb.softmax(x)
Array([0.09003057, 0.24472848, 0.66524094], dtype=float32)

>>> logits = nb.array([[1, 2, 3], [1, 1, 1]])
>>> nb.softmax(logits, axis=1)
Array([[0.09003057, 0.24472848, 0.66524094],
       [0.33333334, 0.33333334, 0.33333334]], dtype=float32)