Example 3: Graph Tracing — Under the Hood#
Nabla is a tracing-based framework. Every operation you write — nb.sin, nb.matmul, + — builds a computation graph behind the scenes. This graph is what transforms like grad, vmap, and compile operate on.
In this notebook you’ll learn to:
Trace a function to see its computation graph
See how grad transforms that graph (adding backward ops)
See how vmap transforms it (adding batch dimensions)
Understand why tracing matters for nb.compile
[11]:
import numpy as np
import nabla as nb
from nabla.core import trace # low-level tracing API
print("Nabla graph tracing example")
Nabla graph tracing example
1. Tracing a Simple Function#
Let’s define a small function and trace it. Tracing runs the function with symbolic inputs, recording every operation into a graph — without actually computing any values.
[12]:
def f(x, y):
    """A simple function: f(x, y) = sin(x) * y + exp(x)."""
    return nb.sin(x) * y + nb.exp(x)
[13]:
# Create concrete inputs (their shapes/dtypes matter, not values)
x = nb.Tensor.from_dlpack(np.array([1.0, 2.0], dtype=np.float32))
y = nb.Tensor.from_dlpack(np.array([3.0, 4.0], dtype=np.float32))
# Trace the function
traced_graph = trace(f, x, y)
print(traced_graph)
fn(
%a1: f32[2],
%a2: f32[2]
) {
%v1: f32[2] = sin(%a1)
%v2: f32[2] = mul(%v1, %a2)
%v3: f32[2] = exp(%a1)
%v4: f32[2] = add(%v2, %v3)
return %v4
}
Reading the Trace Output#
The trace prints an IR (intermediate representation) of your function:
%a1, %a2 are the input arguments (your x and y)
%v1, %v2, … are intermediate values produced by operations
Each line shows: variable: type = operation(inputs)
The return statement shows which value is the final output
This is exactly the graph that gets compiled when you use @nb.compile.
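To connect the IR to concrete numbers, here is a plain-NumPy walkthrough of the same four steps (a sketch that only mirrors the printed trace; it does not touch Nabla's IR):
# Replay the trace above with NumPy, one IR line per statement.
a1 = np.array([1.0, 2.0], dtype=np.float32)  # %a1 (our x)
a2 = np.array([3.0, 4.0], dtype=np.float32)  # %a2 (our y)
v1 = np.sin(a1)   # %v1 = sin(%a1)
v2 = v1 * a2      # %v2 = mul(%v1, %a2)
v3 = np.exp(a1)   # %v3 = exp(%a1)
v4 = v2 + v3      # %v4 = add(%v2, %v3)
print(v4)  # roughly [ 5.2427 11.0262], matching the eager result in section 5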
2. How grad Transforms the Graph#
When you call nb.grad(f), Nabla doesn’t just run backpropagation at runtime. It transforms the graph itself — adding reverse-mode differentiation operations. Let’s see what that looks like:
[14]:
grad_f = nb.grad(f, argnums=0) # gradient w.r.t. x
traced_grad = trace(grad_f, x, y)
print(traced_grad)
fn(
%a1: f32[2],
%a2: f32[2]
) {
%v1: f32[2] = ones(device=Device(type=cpu,id=0), dtype=@float32, shape=(2,))
%v2: f32[2] = exp(%a1)
%v3: f32[2] = mul(%v1, %v2)
%v4: f32[2] = mul(%v1, %a2)
%v5: f32[2] = cos(%a1)
%v6: f32[2] = mul(%v4, %v5)
%v7: f32[2] = add(%v3, %v6)
return %v7
}
Notice how the graph is now different. It contains:
A ones op that seeds the output gradient with 1
The pieces of the forward pass the gradient still needs (here exp, since d/dx exp(x) = exp(x))
New backward ops that implement the chain rule (cos, plus the extra mul and add that combine the partial derivatives with the incoming gradient)
The original sin op no longer appears: grad_f returns only the gradient, and the derivative of sin is cos, so the unused forward value is pruned from the graph.
Each backward op computes a partial derivative. Together, they propagate the gradient from the output back to the input x.
The key insight: grad is not magic — it’s a graph-to-graph transformation.
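As a quick sanity check (a sketch; it assumes grad_f can be called on the same concrete tensors and printed directly, as the compiled example in section 5 does), the transformed graph should agree with the hand-derived gradient:
# Hand-derived: d/dx [sin(x) * y + exp(x)] = cos(x) * y + exp(x)
x_np = np.array([1.0, 2.0], dtype=np.float32)
y_np = np.array([3.0, 4.0], dtype=np.float32)
expected = np.cos(x_np) * y_np + np.exp(x_np)
print("Nabla grad:", grad_f(x, y))
print("expected:  ", expected)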
3. How vmap Transforms the Graph#
vmap (vectorized map) adds a batch dimension to every operation in the graph. Instead of processing one input at a time, the transformed function processes an entire batch in parallel:
[15]:
batched_f = nb.vmap(f)
# Create batched inputs: (batch=3, features=2)
x_batch = nb.Tensor.from_dlpack(
    np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
)
y_batch = nb.Tensor.from_dlpack(
    np.array([[0.5, 0.5], [1.0, 1.0], [2.0, 2.0]], dtype=np.float32)
)
traced_vmap = trace(batched_f, x_batch, y_batch)
print(traced_vmap)
fn(
%a1: f32[3,2],
%a2: f32[3,2]
) {
%v1: f32[3 | 2] = incr_batch_dims(%a1)
%v2: f32[3 | 2] = sin(%v1)
%v3: f32[3 | 2] = incr_batch_dims(%a2)
%v4: f32[3 | 2] = mul(%v2, %v3)
%v5: f32[3 | 2] = exp(%v1)
%v6: f32[3 | 2] = add(%v4, %v5)
%v7: f32[3,2] = decr_batch_dims(%v6)
return %v7
}
The operations are the same (sin, mul, exp, add), but now every tensor carries an extra batch dimension. The trace marks it with a bar in the type signatures: f32[3 | 2] means a batch dimension of 3 in front of the original shape of 2, and incr_batch_dims / decr_batch_dims move the inputs into and out of batched form at the boundaries.
vmap doesn’t write a loop — it lifts every operation to work on batches natively.
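As a quick check on the batched result (a sketch; it recomputes f row by row with plain NumPy and assumes batched_f returns a printable Tensor like the other examples here):
# Recompute sin(x) * y + exp(x) for every row with NumPy for comparison.
x_np = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
y_np = np.array([[0.5, 0.5], [1.0, 1.0], [2.0, 2.0]], dtype=np.float32)
expected = np.sin(x_np) * y_np + np.exp(x_np)
print("vmap result:", batched_f(x_batch, y_batch))
print("expected:   ", expected)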
4. Composing Transforms#
The real power is composing transforms. Since each transform is just a graph-to-graph function, you can stack them:
grad(vmap(f)) — gradient of a batched function
vmap(grad(f)) — per-sample gradients
jacrev(f) — full Jacobian via batched VJPs
Let’s trace a composed transform:
[16]:
# Per-sample gradients: vmap over grad
per_sample_grad = nb.vmap(nb.grad(f, argnums=0))
traced_composed = trace(per_sample_grad, x_batch, y_batch)
print(traced_composed)
fn(
%a1: f32[3,2],
%a2: f32[3,2]
) {
%v1: f32[3 | 2] = ones(device=Device(type=cpu,id=0), dtype=@float32, shape=(3,2))
%v2: f32[3 | 2] = incr_batch_dims(%a1)
%v3: f32[3 | 2] = exp(%v2)
%v4: f32[3 | 2] = mul(%v1, %v3)
%v5: f32[3 | 2] = incr_batch_dims(%a2)
%v6: f32[3 | 2] = mul(%v1, %v5)
%v7: f32[3 | 2] = cos(%v2)
%v8: f32[3 | 2] = mul(%v6, %v7)
%v9: f32[3 | 2] = add(%v4, %v8)
%v10: f32[3,2] = decr_batch_dims(%v9)
return %v10
}
This graph contains both the backward ops from grad and the batching from vmap — composed automatically. Each sample in the batch gets its own gradient.
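The same composition works for the other argument (a sketch; it assumes argnums=1 selects the second positional argument, mirroring argnums=0 above):
# Per-sample gradient w.r.t. y instead of x.
per_sample_grad_y = nb.vmap(nb.grad(f, argnums=1))
# Analytically, d/dy [sin(x) * y + exp(x)] = sin(x), so each row should equal sin of that row of x.
print(per_sample_grad_y(x_batch, y_batch))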
5. Why Tracing Matters for nb.compile#
When you decorate a function with @nb.compile, Nabla:
Traces the function (just like we did above)
Optimizes the graph (fusing ops, eliminating redundancy)
Compiles it to run on the target hardware (CPU/GPU)
The trace is the bridge between your Python code and efficient compiled execution. Let’s verify — a compiled function produces the same results:
[17]:
compiled_f = nb.compile(f)
# Compare eager vs compiled
eager_result = f(x, y)
compiled_result = compiled_f(x, y)
print(f"Eager result: {eager_result}")
print(f"Compiled result: {compiled_result}")
Eager result: Tensor([ 5.2427 11.0262] : f32[2])
Compiled result: Tensor([ 5.2427 11.0262] : f32[2])
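The same compilation is also available in the decorator form the text mentions (a sketch; it assumes @nb.compile can be used as a decorator, as described at the start of this section):
# Decorator form of compilation, equivalent in spirit to nb.compile(f).
@nb.compile
def f_decorated(x, y):
    return nb.sin(x) * y + nb.exp(x)
print(f_decorated(x, y))  # should match the eager and compiled results above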
6. Tracing a More Realistic Function#
Let’s trace something closer to real ML — a tiny neural network layer with a loss function, then see what value_and_grad does to the graph:
[18]:
def simple_layer(params, x):
    """One linear layer + ReLU: f(x) = relu(x @ W + b)."""
    return nb.relu(x @ params["W"] + params["b"])

def loss_fn(params, x, target):
    """MSE loss on the layer output."""
    pred = simple_layer(params, x)
    return nb.mean((pred - target) ** 2)
[19]:
# Create small inputs
params = {
    "W": nb.Tensor.from_dlpack(np.random.randn(3, 2).astype(np.float32)),
    "b": nb.Tensor.from_dlpack(np.zeros(2, dtype=np.float32)),
}
x_in = nb.Tensor.from_dlpack(np.random.randn(4, 3).astype(np.float32))
target = nb.Tensor.from_dlpack(np.random.randn(4, 2).astype(np.float32))
# Trace the training step
train_step = nb.value_and_grad(loss_fn, argnums=0)
traced_train = trace(train_step, params, x_in, target)
print(traced_train)
fn(
%a1: f32[3,2],
%a2: f32[2],
%a3: f32[4,3],
%a4: f32[4,2]
) {
%v1: f32[4,2] = matmul(%a3, %a1)
%v2: f32[1,2] = unsqueeze(%a2, axis=0)
%v3: f32[4,2] = broadcast_to(%v2, shape=(4,2))
%v4: f32[4,2] = add(%v1, %v3)
%v5: f32[4,2] = relu(%v4)
%v6: f32[4,2] = sub(%v5, %a4)
%v7: f32[1] = unsqueeze(?, axis=0)
%v8: f32[1,1] = unsqueeze(%v7, axis=0)
%v9: f32[4,2] = broadcast_to(%v8, shape=(4,2))
%v10: f32[4,2] = pow(%v6, %v9)
%v11: f32[4,1] = reduce_sum(%v10, axis=1, keepdims=True)
%v12: f32[1,1] = reduce_sum(%v11, axis=0, keepdims=True)
%v13: f32[1] = reshape(%v12, shape=(1,))
%v14: f32[] = squeeze(%v13, axis=0)
%v15: f32[] = div(%v14, ?)
%v16: f32[3,4] = swap_axes(%a3, axis1=-2, axis2=-1)
%v17: f32[1] = unsqueeze(?, axis=0)
%v18: f32[1,1] = unsqueeze(%v17, axis=0)
%v19: f32[4,2] = broadcast_to(%v18, shape=(4,2))
%v20: bool[4,2] = greater(%v4, %v19)
%v21: f32[] = ones(device=Device(type=cpu,id=0), dtype=@float32, shape=())
%v22: f32[] = div(%v21, ?)
%v23: f32[1] = unsqueeze(%v22, axis=0)
%v24: f32[1,1] = reshape(%v23, shape=(1,1))
%v25: f32[4,1] = broadcast_to(%v24, shape=(4,1))
%v26: f32[4,2] = broadcast_to(%v25, shape=(4,2))
%v27: f32[1] = unsqueeze(?, axis=0)
%v28: f32[1,1] = unsqueeze(%v27, axis=0)
%v29: f32[4,2] = broadcast_to(%v28, shape=(4,2))
%v30: f32[4,2] = sub(%v9, %v29)
%v31: f32[4,2] = pow(%v6, %v30)
%v32: f32[4,2] = mul(%v9, %v31)
%v33: f32[4,2] = mul(%v26, %v32)
%v34: f32[4,2] = zeros(device=Device(type=cpu,id=0), dtype=@float32, shape=(4,2))
%v35: f32[4,2] = where(%v20, %v33, %v34)
%v36: f32[3,2] = matmul(%v16, %v35)
%v37: f32[1,2] = reduce_sum(%v35, axis=0, keepdims=True)
%v38: f32[2] = squeeze(%v37, axis=0)
return (%v15, %v36, %v38)
}
You can see the full forward+backward pass for a training step. This is exactly the graph that @nb.compile would optimize and execute efficiently.
The operations flow: matmul → add → relu → subtract → square → mean (forward), then the reverse chain computes gradients for W and b.
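To close the loop, a minimal training loop over this traced step might look like the sketch below. It assumes value_and_grad(loss_fn, argnums=0) returns a (loss, grads) pair with grads mirroring the params dict, which is what the three returned values in the trace (loss, dW, db) suggest; the learning rate and step count are illustrative.
lr = 0.1  # illustrative learning rate
for step in range(5):
    # Assumed return convention: (loss, grads) with grads keyed like params.
    loss, grads = train_step(params, x_in, target)
    # Plain SGD update (assumes Tensor supports scalar * and elementwise -, like the ops above).
    params = {name: p - lr * grads[name] for name, p in params.items()}
    print(f"step {step}: loss = {loss}")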
Summary#
| Concept | What it does |
|---|---|
| trace | Captures the computation graph without executing it |
| grad | Graph → graph: adds backward (chain rule) operations |
| vmap | Graph → graph: adds batch dimensions to all ops |
| nb.compile | Traces → optimizes → compiles the graph for hardware |
| Composition | Transforms stack: vmap(grad(f)), grad(vmap(f)), and so on |
Understanding tracing helps you reason about what Nabla does under the hood, debug unexpected behavior, and write more efficient code.
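As a final sketch tying the table together, the composed per-sample-gradient function from section 4 can itself be handed to the compiler (assuming nb.compile accepts any traceable callable, as section 5 describes):
# Stack transforms, then compile the resulting graph for the target hardware.
fast_per_sample_grad = nb.compile(nb.vmap(nb.grad(f, argnums=0)))
print(fast_per_sample_grad(x_batch, y_batch))  # same values as the traced version in section 4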