# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from collections.abc import Callable, Sequence
from typing import Any

from ..utils.grad_utils import (
    create_ones_like_cotangent,
    select_gradients_by_argnums,
    validate_scalar_output,
)
from .vjp import vjp


def value_and_grad(
    fun: Callable | None = None,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence = (),
) -> Callable[..., Any]:
"""
Creates a function that evaluates both the value and gradient of fun.
This function uses VJP (Vector-Jacobian Product) directly with a cotangent
of ones_like(output) to compute gradients for scalar-valued functions.
This is simpler and more efficient than using jacrev/jacfwd for scalar outputs.
Parameters:
fun: Function to be differentiated. Should return a scalar.
argnums: Which positional argument(s) to differentiate with respect to (default 0).
has_aux: Whether fun returns (output, aux) pair (default False).
holomorphic: Whether fun is holomorphic - currently ignored (default False).
allow_int: Whether to allow integer inputs - currently ignored (default False).
reduce_axes: Axes to reduce over - currently ignored (default ()).
Returns:
A function that computes both the value and gradient of fun.
Examples:
Basic usage as a function call::
value_and_grad_fn = value_and_grad(my_loss)
value, grads = value_and_grad_fn(x)
Usage as a decorator::
@value_and_grad
def my_loss(x):
return x**2
value, grads = my_loss(3.0)
"""
    # Handle being used as a decorator without arguments
    if fun is None:
        return lambda f: value_and_grad(
            f,
            argnums=argnums,
            has_aux=has_aux,
            holomorphic=holomorphic,
            allow_int=allow_int,
            reduce_axes=reduce_axes,
        )
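
    # For example (a minimal sketch; the function and argument names below are
    # illustrative, not part of this module):
    #
    #     @value_and_grad(argnums=1)
    #     def my_loss(params, x):
    #         return (x * params) ** 2
    #
    # value_and_grad(argnums=1) is called with fun left as None, so it returns
    # the lambda above; applying that lambda to my_loss then differentiates
    # with respect to the second positional argument (x).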

    def value_and_grad_fn(*args: Any) -> Any:
        """
        The actual value_and_grad function that gets returned.

        Validates that the function returns a scalar output, then computes
        both the value and gradient using VJP with a ones_like cotangent.
        """
        # Compute VJP to get both the output and the pullback function
        if has_aux:
            vjp_result = vjp(fun, *args, has_aux=True)
            output, vjp_fn, aux = vjp_result  # type: ignore
        else:
            vjp_result = vjp(fun, *args, has_aux=False)
            output, vjp_fn = vjp_result  # type: ignore

        # Validate scalar output
        validate_scalar_output(output)

        # Create a cotangent of ones_like(output) and compute gradients
        cotangent = create_ones_like_cotangent(output)

        # VJP computes gradients for all inputs; select based on argnums
        all_gradients = vjp_fn(cotangent)
        selected_gradients = select_gradients_by_argnums(all_gradients, args, argnums)

        # Return based on has_aux - JAX returns ((value, aux), grad) when has_aux=True
        if has_aux:
            return (output, aux), selected_gradients
        else:
            return output, selected_gradients

    return value_and_grad_fn
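
# Why a ones_like cotangent gives the gradient (illustrative sketch, not part of
# the public API): for a scalar-valued f, the VJP pullback maps a cotangent c to
# c * df/dx_i for every input x_i, so evaluating it at c = 1 (ones_like of a
# scalar output) yields the gradient itself. A minimal usage sketch mirroring the
# docstring examples (exact return types depend on `vjp` and are not assumed here):
#
#     def loss(x, y):
#         return x**2 + y          # d/dx = 2*x, d/dy = 1
#
#     value, grads = value_and_grad(loss, argnums=(0, 1))(3.0, 1.0)
#     # value is loss(3.0, 1.0); grads is the pair of gradients selected by
#     # select_gradients_by_argnums, i.e. (d/dx, d/dy) evaluated at (3.0, 1.0).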


def grad(
    fun: Callable | None = None,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence = (),
    mode: str = "reverse",  # Kept for API compatibility but ignored
) -> Callable[..., Any]:
"""
Creates a function that evaluates the gradient of fun.
This is implemented as a special case of value_and_grad that only returns
the gradient part. Uses VJP directly for efficiency with scalar outputs.
Parameters:
fun: Function to be differentiated. Should return a scalar.
argnums: Which positional argument(s) to differentiate with respect to (default 0).
has_aux: Whether fun returns (output, aux) pair (default False).
holomorphic: Whether fun is holomorphic - currently ignored (default False).
allow_int: Whether to allow integer inputs - currently ignored (default False).
reduce_axes: Axes to reduce over - currently ignored (default ()).
mode: Kept for API compatibility but ignored (always uses reverse-mode VJP).
Returns:
A function that computes the gradient of fun.
Examples:
Basic usage as a function call::
grad_fn = grad(my_loss)
grads = grad_fn(x)
Usage as a decorator::
@grad
def my_loss(x):
return x**2
grads = my_loss(3.0) # Returns gradient, not function value
"""
    # Handle decorator pattern: if fun is None, return a decorator
    if fun is None:
        return lambda f: grad(
            f,
            argnums=argnums,
            has_aux=has_aux,
            holomorphic=holomorphic,
            allow_int=allow_int,
            reduce_axes=reduce_axes,
            mode=mode,
        )

    # Get the underlying value_and_grad function
    value_and_grad_fn = value_and_grad(
        fun=fun,
        argnums=argnums,
        has_aux=has_aux,
        holomorphic=holomorphic,
        allow_int=allow_int,
        reduce_axes=reduce_axes,
    )

    def grad_fn(*args: Any) -> Any:
        """
        The actual gradient function that gets returned.

        Just calls value_and_grad and returns only the gradient part.
        """
        if has_aux:
            (value, aux), gradients = value_and_grad_fn(*args)
            return gradients, aux
        else:
            value, gradients = value_and_grad_fn(*args)
            return gradients

    return grad_fn
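
# has_aux semantics (illustrative sketch): mirroring JAX, grad(..., has_aux=True)
# returns (gradients, aux), while value_and_grad(..., has_aux=True) returns
# ((value, aux), gradients). A minimal sketch, assuming fun returns a scalar loss
# plus an auxiliary dict as in the docstring examples above:
#
#     def loss_with_metrics(x):
#         loss = x**2
#         return loss, {"loss": loss}
#
#     grads, metrics = grad(loss_with_metrics, has_aux=True)(3.0)
#     (value, metrics), grads = value_and_grad(loss_with_metrics, has_aux=True)(3.0)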