# Differentiation
## grad

```python
def grad(
    fun: Callable,
    argnums: int | tuple[int, ...] = 0,
    create_graph: bool = False,
    realize: bool = True,
) -> Callable
```

Return a function that computes the gradient of `fun` with respect to the argument(s) selected by `argnums`. `fun` must return a scalar.
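A minimal usage sketch. The `autodiff` import path and the `Tensor` constructor are hypothetical placeholders for illustration; only `grad` and its parameters come from this reference.

```python
from autodiff import grad, Tensor  # hypothetical import path

def loss(w, b):
    # grad requires a scalar-valued function.
    return ((w * 3.0 + b) ** 2).sum()

w, b = Tensor([1.0, 2.0]), Tensor([0.5, 0.5])

dw = grad(loss)(w, b)                      # gradient w.r.t. argument 0
dw, db = grad(loss, argnums=(0, 1))(w, b)  # gradients w.r.t. both arguments
```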
## value_and_grad

```python
def value_and_grad(
    fun: Callable,
    argnums: int | tuple[int, ...] = 0,
    create_graph: bool = False,
    realize: bool = True,
) -> Callable
```

Return a function that computes `(value, grad)` of `fun`: the value of `fun` together with its gradient with respect to the argument(s) selected by `argnums`.
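A sketch of the common training-loop pattern, where the loss value is wanted for logging alongside the gradient; the import path and `Tensor` remain hypothetical.

```python
from autodiff import value_and_grad, Tensor  # hypothetical import path

def loss(w):
    return (w * w).sum()  # scalar-valued

w = Tensor([1.0, 2.0])
val, dw = value_and_grad(loss)(w)  # one call returns both value and gradient
```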
## vjp

```python
def vjp(
    fn: Callable[..., Any],
    *primals: Any,
    has_aux: bool = False,
) -> tuple[Any, Callable[..., tuple[Any, ...]]] | tuple[Any, Callable[..., tuple[Any, ...]], Any]
```

Compute the vector-Jacobian product (VJP) of `fn` at `primals`. Returns `(output, vjp_fn)`, or `(output, vjp_fn, aux)` when `has_aux=True`.
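A sketch of the VJP pattern: `vjp_fn` maps a cotangent shaped like the output back to one gradient per primal, matching the tuple return type above. `Tensor` and the import path are assumptions.

```python
from autodiff import vjp, Tensor  # hypothetical import path

def f(x):
    return x * x  # elementwise: output has the same shape as x

x = Tensor([1.0, 2.0, 3.0])
out, vjp_fn = vjp(f, x)       # primals are passed positionally

ct = Tensor([1.0, 1.0, 1.0])  # cotangent, same shape as out
(gx,) = vjp_fn(ct)            # tuple with one entry per primal; here 2*x
```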
## jvp

```python
def jvp(
    fn: Callable[..., Any],
    primals: tuple[Any, ...],
    tangents: tuple[Any, ...],
    *,
    has_aux: bool = False,
) -> tuple[Any, Any] | tuple[Any, Any, Any]
```

Compute the Jacobian-vector product (JVP) of `fn` at `primals` in the direction given by `tangents`. Returns `(out, tangent_out)`, or `(out, tangent_out, aux)` when `has_aux=True`.
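A sketch of a directional derivative via `jvp`; note that `primals` and `tangents` are passed as tuples with one entry per argument of `fn`. `Tensor` and the import path are assumptions.

```python
from autodiff import jvp, Tensor  # hypothetical import path

def f(x):
    return x * x

x = Tensor([1.0, 2.0, 3.0])  # primal point
v = Tensor([1.0, 0.0, 0.0])  # tangent direction

out, tangent_out = jvp(f, (x,), (v,))  # tangent_out == J(x) @ v == [2., 0., 0.]
```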
## jacrev

```python
def jacrev(
    fn: Callable[..., Any],
    argnums: int | tuple[int, ...] | list[int] | None = None,
    has_aux: bool = False,
) -> Callable[..., Any]
```

Compute the Jacobian of `fn` via reverse-mode autodiff, using one VJP per output element. This is the efficient choice when `fn` has many inputs and few outputs.
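A sketch of computing a full Jacobian with `jacrev`, under the same hypothetical import and `Tensor` assumptions as above.

```python
from autodiff import jacrev, Tensor  # hypothetical import path

def f(x):
    return x * x  # R^3 -> R^3, elementwise

x = Tensor([1.0, 2.0, 3.0])
J = jacrev(f)(x)  # 3x3 Jacobian; for this f, diag(2.0, 4.0, 6.0)
```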
## jacfwd

```python
def jacfwd(
    fn: Callable[..., Any],
    argnums: int | tuple[int, ...] | list[int] | None = None,
    has_aux: bool = False,
) -> Callable[..., Any]
```

Compute the Jacobian of `fn` via forward-mode autodiff, using one JVP per input element. This is the efficient choice when `fn` has few inputs and many outputs.
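The same Jacobian via forward mode; both modes should agree, and the choice between them is a matter of shape (`jacrev` for wide Jacobians, `jacfwd` for tall ones). Import path and `Tensor` are again assumptions.

```python
from autodiff import jacfwd, Tensor  # hypothetical import path

def f(x):
    return x * x

x = Tensor([1.0, 2.0, 3.0])
J = jacfwd(f)(x)  # same 3x3 Jacobian as jacrev(f)(x): diag(2.0, 4.0, 6.0)
```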