Reduction
reduce_sum
def reduce_sum(x: 'Tensor', *, axis: 'int | tuple[int, ...] | list[int] | None' = None, keepdims: 'bool' = False) -> 'Tensor':
reduce_max
def reduce_max(x: 'Tensor', *, axis: 'int | tuple[int, ...] | list[int] | None' = None, keepdims: 'bool' = False) -> 'Tensor':
reduce_min
def reduce_min(x: 'Tensor', *, axis: 'int | tuple[int, ...] | list[int] | None' = None, keepdims: 'bool' = False) -> 'Tensor':
mean
def mean(x: 'Tensor', *, axis: 'int | tuple[int, ...] | list[int] | None' = None, keepdims: 'bool' = False) -> 'Tensor':
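Compute the arithmetic mean along the specified axis or axes.
Implemented as sum(x) / product(shape[axes]) — the sum over the reduced axes divided by the number of elements along them — so that distributed (sharded) tensors are handled correctly.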
argmax
def argmax(x: 'Tensor', axis: 'int' = -1, keepdims: 'bool' = False) -> 'Tensor':
argmin
def argmin(x: 'Tensor', axis: 'int' = -1, keepdims: 'bool' = False) -> 'Tensor':
cumsum
def cumsum(x: 'Tensor', axis: 'int' = -1, exclusive: 'bool' = False, reverse: 'bool' = False) -> 'Tensor':
reduce_sum_physical
def reduce_sum_physical(x: 'Tensor', axis: 'int', keepdims: 'bool' = False) -> 'Tensor':
mean_physical
def mean_physical(x: 'Tensor', axis: 'int', keepdims: 'bool' = False) -> 'Tensor':