Example 2: Automatic Differentiation#
Nabla provides a JAX-like functional autodiff system built on composable transforms. Every transform is a higher-order function: it takes a function and returns a new function that computes derivatives.
| Transform | Mode | Computes |
|---|---|---|
| grad | Reverse | Gradient of scalar-valued function |
| value_and_grad | Reverse | (value, gradient) pair |
| jvp | Forward | Jacobian-vector product |
| vjp | Reverse | Vector-Jacobian product |
| jacrev | Reverse | Full Jacobian matrix |
| jacfwd | Forward | Full Jacobian matrix |
[1]:
import numpy as np
import nabla as nb
print("Nabla autodiff example")
Nabla autodiff example
1. grad — Gradient of a Scalar Function#
nb.grad(fn) returns a function that computes the gradient of fn with respect to specified arguments (default: first argument).
[2]:
def f(x):
    """f(x) = x^3 + 2x^2 - 5x + 3, so f'(x) = 3x^2 + 4x - 5."""
    return x ** 3 + 2.0 * x ** 2 - 5.0 * x + 3.0
df = nb.grad(f)
x = nb.Tensor.from_dlpack(np.array([2.0], dtype=np.float32))
grad_val = df(x)
print(f"f(x) = x^3 + 2x^2 - 5x + 3")
print(f"f'(x) = 3x^2 + 4x - 5")
print(f"f'(2.0) = 3*4 + 4*2 - 5 = {3*4 + 4*2 - 5}")
print(f"Nabla grad: {grad_val}")
f(x) = x^3 + 2x^2 - 5x + 3
f'(x) = 3x^2 + 4x - 5
f'(2.0) = 3*4 + 4*2 - 5 = 15
Nabla grad: Tensor([15.] : f32[1])
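As an independent sanity check (plain NumPy, no Nabla), a central finite difference approximates the same derivative:

```python
import numpy as np

def f(x):
    # Same polynomial as above: f(x) = x^3 + 2x^2 - 5x + 3
    return x ** 3 + 2.0 * x ** 2 - 5.0 * x + 3.0

eps = 1e-4
x0 = 2.0
# Central difference: (f(x+eps) - f(x-eps)) / (2*eps) ≈ f'(x)
fd = (f(x0 + eps) - f(x0 - eps)) / (2 * eps)
print(fd)  # ≈ 15.0, matching the autodiff result
```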
2. value_and_grad — Value and Gradient Together#
Often you need both the function value and its gradient. This is more efficient than calling f and grad(f) separately.
[3]:
def quadratic(x):
    """f(x) = sum(x^2), so grad = 2x."""
    return nb.reduce_sum(x * x)
val_and_grad_fn = nb.value_and_grad(quadratic)
x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
value, gradient = val_and_grad_fn(x)
print(f"x = [1, 2, 3]")
print(f"f(x) = sum(x^2) = {value}")
print(f"grad(f) = 2x = {gradient}")
x = [1, 2, 3]
f(x) = sum(x^2) = Tensor(14. : f32[])
grad(f) = 2x = Tensor([2. 4. 6.] : f32[3])
Multiple Arguments with argnums#
Use argnums to specify which arguments to differentiate with respect to.
[4]:
def weighted_sum(w, x):
    """f(w, x) = sum(w * x)."""
    return nb.reduce_sum(w * x)
# Gradient w.r.t. first arg (w) only — default
grad_w = nb.grad(weighted_sum, argnums=0)
w = nb.Tensor.from_dlpack(np.array([1.0, 2.0], dtype=np.float32))
x = nb.Tensor.from_dlpack(np.array([3.0, 4.0], dtype=np.float32))
print(f"grad w.r.t. w: {grad_w(w, x)}")
print(f" (should be x = [3, 4])")
# Gradient w.r.t. second arg (x)
grad_x = nb.grad(weighted_sum, argnums=1)
print(f"grad w.r.t. x: {grad_x(w, x)}")
print(f" (should be w = [1, 2])")
# Gradient w.r.t. both — returns a tuple
grad_both = nb.grad(weighted_sum, argnums=(0, 1))
gw, gx = grad_both(w, x)
print(f"grad w.r.t. (w, x): ({gw}, {gx})")
grad w.r.t. w: Tensor([3. 4.] : f32[2])
(should be x = [3, 4])
grad w.r.t. x: Tensor([1. 2.] : f32[2])
(should be w = [1, 2])
grad w.r.t. (w, x): (Tensor([3. 4.] : f32[2]), Tensor([1. 2.] : f32[2]))
3. jvp — Forward-Mode (Jacobian-Vector Product)#
nb.jvp(fn, primals, tangents) computes:
- the function output fn(*primals)
- the directional derivative J @ tangents (the JVP)
This is efficient when the number of inputs is small (one forward pass per tangent direction).
[5]:
def g(x):
    """g(x) = [x0^2 + x1, x0 * x1]."""
    r0 = nb.reshape(x[0] ** 2 + x[1], (1,))
    r1 = nb.reshape(x[0] * x[1], (1,))
    return nb.concatenate([r0, r1], axis=0)
x = nb.Tensor.from_dlpack(np.array([3.0, 2.0], dtype=np.float32))
v = nb.Tensor.from_dlpack(np.array([1.0, 0.0], dtype=np.float32))
output, jvp_val = nb.jvp(g, (x,), (v,))
print(f"g([3, 2]) = [3^2 + 2, 3*2] = {output}")
print(f"JVP with v=[1,0] (column 1 of Jacobian):")
print(f" J @ v = {jvp_val}")
print(f" Expected: [2*3, 2] = [6, 2]")
g([3, 2]) = [3^2 + 2, 3*2] = Tensor([11. 6.] : f32[2])
JVP with v=[1,0] (column 1 of Jacobian):
J @ v = Tensor([6. 2.] : f32[2])
Expected: [2*3, 2] = [6, 2]
[6]:
# Second column of the Jacobian
v2 = nb.Tensor.from_dlpack(np.array([0.0, 1.0], dtype=np.float32))
_, jvp_val2 = nb.jvp(g, (x,), (v2,))
print(f"JVP with v=[0,1] (column 2 of Jacobian):")
print(f" J @ v = {jvp_val2}")
print(f" Expected: [1, 3]")
JVP with v=[0,1] (column 2 of Jacobian):
J @ v = Tensor([1. 3.] : f32[2])
Expected: [1, 3]
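The two JVPs above are exactly the two columns of the Jacobian. A plain-NumPy finite-difference version of the same idea (independent of Nabla) assembles J column by column:

```python
import numpy as np

def g(x):
    # Same map as above: g(x) = [x0^2 + x1, x0 * x1]
    return np.array([x[0] ** 2 + x[1], x[0] * x[1]])

x = np.array([3.0, 2.0])
eps = 1e-5
cols = []
for i in range(2):
    v = np.zeros(2)
    v[i] = 1.0
    # Directional derivative along e_i ≈ i-th Jacobian column
    cols.append((g(x + eps * v) - g(x - eps * v)) / (2 * eps))
J = np.stack(cols, axis=1)
print(J)  # ≈ [[6, 1], [2, 3]]
```

This is the cost model behind forward mode: one pass per input direction, so n passes for an n-column Jacobian.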
4. vjp — Reverse-Mode (Vector-Jacobian Product)#
nb.vjp(fn, *primals) returns (output, vjp_fn) where vjp_fn(cotangent) gives the VJP = cotangent @ J.
This is efficient when the number of outputs is small (one backward pass per cotangent direction).
[7]:
def mat_fn(x):
    """f(x) = Ax, where A is 2x3 and x is (3,). Output is (2,)."""
    A = nb.Tensor.from_dlpack(
        np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 1.0]], dtype=np.float32)
    )
    return A @ x  # (2,3) @ (3,) = (2,)
x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
output, vjp_fn = nb.vjp(mat_fn, x)
print(f"f(x) = Ax, A = [[1,0,2],[0,3,1]], x = [1,2,3]")
print(f"f(x) = {output}")
print(f" Expected: [1+0+6, 0+6+3] = [7, 9]")
f(x) = Ax, A = [[1,0,2],[0,3,1]], x = [1,2,3]
f(x) = Tensor([7. 9.] : f32[2])
Expected: [1+0+6, 0+6+3] = [7, 9]
Pulling Back Cotangent Vectors#
The VJP returns a pullback function vjp_fn. Given a cotangent vector \(v\) (same shape as the output), it computes the vector-Jacobian product \(v^\top J\). Each standard basis cotangent extracts one row of \(J = A\), i.e. one column of \(A^\top\):
[8]:
# VJP with cotangent [1, 0] — gives the first row of A (= A^T @ [1,0])
v1 = nb.Tensor.from_dlpack(np.array([1.0, 0.0], dtype=np.float32))
(vjp1,) = vjp_fn(v1)
print(f"VJP with v=[1,0]: {vjp1}")
print(f" Expected: A^T @ [1,0] = [1, 0, 2]")
# VJP with cotangent [0, 1] — gives the second row of A (= A^T @ [0,1])
v2 = nb.Tensor.from_dlpack(np.array([0.0, 1.0], dtype=np.float32))
(vjp2,) = vjp_fn(v2)
print(f"VJP with v=[0,1]: {vjp2}")
print(f" Expected: A^T @ [0,1] = [0, 3, 1]")
VJP with v=[1,0]: Tensor([1. 0. 2.] : f32[3])
Expected: A^T @ [1,0] = [1, 0, 2]
VJP with v=[0,1]: Tensor([0. 3. 1.] : f32[3])
Expected: A^T @ [0,1] = [0, 3, 1]
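Because this map is linear, the pullback is just a vector-matrix product, so the expected rows fall out of plain NumPy directly:

```python
import numpy as np

# Same A as in mat_fn above
A = np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 1.0]])

# v^T J = v @ A: each basis cotangent selects one row of A
row0 = np.array([1.0, 0.0]) @ A
row1 = np.array([0.0, 1.0]) @ A
print(row0)  # [1. 0. 2.]
print(row1)  # [0. 3. 1.]
```

This mirrors the cost model of reverse mode: one pass per output direction, so m passes for an m-row Jacobian.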
5. jacrev — Full Jacobian via Reverse Mode#
nb.jacrev(fn) computes the full Jacobian matrix using reverse-mode autodiff (one backward pass per output element, batched via vmap).
[9]:
def h(x):
    """h(x) = Ax + sin(x), nonlinear vector function R^2 -> R^2."""
    A = nb.Tensor.from_dlpack(
        np.array([[2.0, -1.0], [1.0, 3.0]], dtype=np.float32)
    )
    return A @ x + nb.sin(x)
x = nb.Tensor.from_dlpack(np.array([1.0, 0.5], dtype=np.float32))
J = nb.jacrev(h)(x)
print("Jacobian via jacrev:")
print(J)
print("Expected: A + diag(cos(x))")
print(f" [[2+cos(1), -1 ],")
print(f" [1, 3+cos(0.5)]]")
print(f" ≈ [[{2+np.cos(1):.4f}, {-1:.4f}],")
print(f" [{1:.4f}, {3+np.cos(0.5):.4f}]]")
Jacobian via jacrev:
Tensor(
[[ 2.5403 -1. ]
[ 1. 3.8776]] : f32[2,2]
)
Expected: A + diag(cos(x))
[[2+cos(1), -1 ],
[1, 3+cos(0.5)]]
≈ [[2.5403, -1.0000],
[1.0000, 3.8776]]
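As an independent check (plain NumPy, not the Nabla API), finite differences reproduce the same Jacobian entries:

```python
import numpy as np

def h(x):
    # Same function as above: h(x) = Ax + sin(x)
    A = np.array([[2.0, -1.0], [1.0, 3.0]])
    return A @ x + np.sin(x)

x = np.array([1.0, 0.5])
eps = 1e-5
J = np.zeros((2, 2))
for j in range(2):
    e = np.zeros(2)
    e[j] = eps
    # Central difference along e_j fills the j-th column of J
    J[:, j] = (h(x + e) - h(x - e)) / (2 * eps)
print(np.round(J, 4))  # ≈ [[2.5403, -1], [1, 3.8776]]
```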
6. jacfwd — Full Jacobian via Forward Mode#
nb.jacfwd(fn) computes the same Jacobian using forward-mode autodiff (one JVP per input element, batched via vmap). Prefer jacfwd when inputs are few and outputs are many.
[10]:
J_fwd = nb.jacfwd(h)(x)
print("Jacobian via jacfwd:")
print(J_fwd)
Jacobian via jacfwd:
Tensor(
[[ 2.5403 -1. ]
[ 1. 3.8776]] : f32[2,2]
)
When to use jacrev vs jacfwd#
| Scenario | Prefer |
|---|---|
| Few outputs, many inputs | jacrev |
| Few inputs, many outputs | jacfwd |
| Square Jacobian | Either works |
| Hessian (second derivative) | Compose both! |
7. Hessians — Composing Transforms#
Because Nabla’s transforms are composable, you can compute Hessians (second-order derivatives) by nesting Jacobian transforms.
For a scalar function \(f: \mathbb{R}^n \to \mathbb{R}\), the Hessian \(H_{ij} = \frac{\partial^2 f}{\partial x_i \partial x_j}\) can be computed in multiple ways:
[11]:
def scalar_fn(x):
    """f(x) = x0^2 * x1 + x1^3, a polynomial with known Hessian."""
    return x[0] ** 2 * x[1] + x[1] ** 3
x = nb.Tensor.from_dlpack(np.array([2.0, 3.0], dtype=np.float32))
print(f"f(x) = x0^2 * x1 + x1^3")
print(f"x = {x}")
print(f"f(x) = {scalar_fn(x)}")
# Analytical Hessian:
# df/dx0 = 2*x0*x1, df/dx1 = x0^2 + 3*x1^2
# H = [[2*x1, 2*x0], [2*x0, 6*x1]]
# At x=[2,3]: H = [[6, 4], [4, 18]]
print("\nAnalytical Hessian at x=[2,3]:")
print(" [[2*x1, 2*x0], [2*x0, 6*x1]] = [[6, 4], [4, 18]]")
f(x) = x0^2 * x1 + x1^3
x = Tensor([2. 3.] : f32[2])
f(x) = Tensor(39. : f32[])
Analytical Hessian at x=[2,3]:
[[2*x1, 2*x0], [2*x0, 6*x1]] = [[6, 4], [4, 18]]
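The analytical values can also be confirmed with a plain-NumPy second-order central difference (a standalone check, independent of Nabla):

```python
import numpy as np

def f(x):
    # Same scalar function: f(x) = x0^2 * x1 + x1^3
    return x[0] ** 2 * x[1] + x[1] ** 3

x = np.array([2.0, 3.0])
eps = 1e-4
H = np.zeros((2, 2))
for i in range(2):
    for j in range(2):
        ei, ej = np.zeros(2), np.zeros(2)
        ei[i], ej[j] = eps, eps
        # Central second difference for d^2 f / (dx_i dx_j)
        H[i, j] = (f(x + ei + ej) - f(x + ei - ej)
                   - f(x - ei + ej) + f(x - ei - ej)) / (4 * eps ** 2)
print(np.round(H))  # ≈ [[6, 4], [4, 18]]
```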
Method 1: jacfwd(grad(f)) — Forward-over-Reverse#
The most common approach: first compute the gradient (reverse), then differentiate that gradient function again (forward) to get the full Hessian matrix.
[12]:
H1 = nb.jacfwd(nb.grad(scalar_fn))(x)
print("Method 1 — jacfwd(grad(f)):")
print(H1)
Method 1 — jacfwd(grad(f)):
Tensor(
[[ 6. 4.]
[ 4. 18.]] : f32[2,2]
)
Method 2: jacrev(grad(f)) — Reverse-over-Reverse#
Same idea, but use reverse mode for the outer Jacobian. Either combination works — choose based on the shape of the gradient.
[13]:
H2 = nb.jacrev(nb.grad(scalar_fn))(x)
print("Method 2 — jacrev(grad(f)):")
print(H2)
Method 2 — jacrev(grad(f)):
Tensor(
[[ 6. 4.]
[ 4. 18.]] : f32[2,2]
)
Method 3: jacrev(jacfwd(f)) — Full Jacobian Composition#
You can also compose two Jacobian transforms directly. Since jacfwd(f) returns the full Jacobian, wrapping it in jacrev differentiates each Jacobian entry — yielding the Hessian.
[14]:
H3 = nb.jacrev(nb.jacfwd(scalar_fn))(x)
print("Method 3 — jacrev(jacfwd(f)):")
print(H3)
Method 3 — jacrev(jacfwd(f)):
Tensor(
[[ 6. 4.]
[ 4. 18.]] : f32[2,2]
)
Method 4: jacfwd(jacrev(f)) — The Reverse Order#
Swapping the order also works. All four methods produce the same Hessian matrix:
[15]:
H4 = nb.jacfwd(nb.jacrev(scalar_fn))(x)
print("Method 4 — jacfwd(jacrev(f)):")
print(H4)
print("\nAll four methods produce the same Hessian! ✅")
Method 4 — jacfwd(jacrev(f)):
Tensor(
[[ 6. 4.]
[ 4. 18.]] : f32[2,2]
)
All four methods produce the same Hessian! ✅
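A useful corollary of composability: when only Hessian-vector products are needed (for example in conjugate-gradient or Newton-CG solvers), the full matrix never has to be materialized. In JAX-style systems this is a jvp of grad; the sketch below just finite-differences the analytic gradient of the same f in plain NumPy to show what such a product computes:

```python
import numpy as np

def grad_f(x):
    # Analytic gradient of f(x) = x0^2 * x1 + x1^3
    return np.array([2 * x[0] * x[1], x[0] ** 2 + 3 * x[1] ** 2])

x = np.array([2.0, 3.0])
v = np.array([1.0, 1.0])
eps = 1e-5
# H @ v ≈ (grad f(x + eps*v) - grad f(x - eps*v)) / (2*eps)
hvp = (grad_f(x + eps * v) - grad_f(x - eps * v)) / (2 * eps)
print(hvp)  # ≈ [[6, 4], [4, 18]] @ [1, 1] = [10, 22]
```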
8. Gradient of a Multi-Variable Loss#
A more practical example: computing gradients for a simple regression loss.
[16]:
def linear_regression_loss(w, b, X, y):
    """MSE loss for linear regression: ||Xw + b - y||^2 / n."""
    predictions = X @ w + b
    residuals = predictions - y
    return nb.mean(residuals * residuals)
# Create data
np.random.seed(42)
n_samples, n_features = 50, 3
X = nb.Tensor.from_dlpack(np.random.randn(n_samples, n_features).astype(np.float32))
w_true = nb.Tensor.from_dlpack(np.array([[2.0], [-1.0], [0.5]], dtype=np.float32))
y = X @ w_true + 0.1 * nb.gaussian((n_samples, 1))
# Initialize weights
w = nb.zeros((n_features, 1))
b = nb.zeros((1,))
# Compute gradients
grad_fn = nb.grad(linear_regression_loss, argnums=(0, 1))
dw, db = grad_fn(w, b, X, y)
print(f"Gradient w.r.t. weights (shape {dw.shape}):")
print(dw)
print(f"\nGradient w.r.t. bias (shape {db.shape}):")
print(db)
Gradient w.r.t. weights (shape [Dim(3), Dim(1)]):
Tensor(
[[-2.3079]
[ 2.5454]
[-0.4476]] : f32[3,1]
)
Gradient w.r.t. bias (shape [Dim(1)]):
Tensor([0.0037] : f32[1])
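These gradients have closed forms: with residuals r = Xw + b - y and n samples, dL/dw = 2 X^T r / n and dL/db = 2 mean(r). A standalone NumPy check on its own synthetic data (so the numbers differ from the Nabla run above):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3))
y = rng.standard_normal((50, 1))
w = np.zeros((3, 1))
b = 0.0

r = X @ w + b - y            # residuals
dw = 2.0 * X.T @ r / len(X)  # closed-form gradient w.r.t. w
db = 2.0 * r.mean()          # closed-form gradient w.r.t. b

# Cross-check dw[0] against a central finite difference
def loss(w, b):
    rr = X @ w + b - y
    return (rr ** 2).mean()

eps = 1e-6
e0 = np.zeros((3, 1)); e0[0, 0] = eps
fd = (loss(w + e0, b) - loss(w - e0, b)) / (2 * eps)
print(abs(fd - dw[0, 0]) < 1e-6)  # True
```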
9. A Simple Gradient Descent#
Using value_and_grad in a training loop:
[17]:
w = nb.zeros((n_features, 1))
b = nb.zeros((1,))
lr = 0.1
vg_fn = nb.value_and_grad(linear_regression_loss, argnums=(0, 1))
print(f"{'Step':<6} {'Loss':<12}")
print("-" * 20)
for step in range(10):
    loss, (dw, db) = vg_fn(w, b, X, y)
    w = w - lr * dw
    b = b - lr * db
    if (step + 1) % 2 == 0:
        print(f"{step + 1:<6} {loss.item():<12.6f}")
print(f"\nLearned weights: {w}")
print(f"True weights: {w_true}")
Step Loss
--------------------
2 2.658563
4 1.408227
6 0.797857
8 0.478955
10 0.301070
Learned weights: Tensor(
[[ 1.3725]
[-1.022 ]
[ 0.3079]] : f32[3,1]
)
True weights: Tensor(
[[ 2. ]
[-1. ]
[ 0.5]] : f32[3,1]
)
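The same loop can be replayed in plain NumPy using the closed-form MSE gradients (its own synthetic data, so the exact numbers differ from the Nabla run), showing the loss decrease one expects from gradient descent on a convex loss:

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.standard_normal((50, 3))
w_true = np.array([[2.0], [-1.0], [0.5]])
y = X @ w_true + 0.1 * rng.standard_normal((50, 1))

w, b, lr = np.zeros((3, 1)), 0.0, 0.1
losses = []
for step in range(10):
    r = X @ w + b - y
    losses.append((r ** 2).mean())
    # Closed-form MSE gradients: dL/dw = 2 X^T r / n, dL/db = 2 mean(r)
    w -= lr * 2.0 * X.T @ r / len(X)
    b -= lr * 2.0 * r.mean()
print(losses[0] > losses[-1])  # True: the loss shrinks over the run
```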
Summary#
| Transform | Input | Output | Best for |
|---|---|---|---|
| grad | Scalar fn | Gradient vector | Training losses |
| value_and_grad | Scalar fn | (value, gradient) | Training loops |
| jvp | Any fn | (output, J·v) | Few inputs |
| vjp | Any fn | (output, vjp_fn) | Few outputs |
| jacrev | Any fn | Full Jacobian | Few outputs |
| jacfwd | Any fn | Full Jacobian | Few inputs |
| Compose! | — | Hessians, etc. | Higher-order derivatives |