Example 3: Graph Tracing — Under the Hood#

Nabla is a tracing-based framework. Every operation you write — nb.sin, nb.matmul, + — builds a computation graph behind the scenes. This graph is what transforms like grad, vmap, and compile operate on.

In this notebook you’ll learn to:

  1. Trace a function to see its computation graph

  2. See how grad transforms that graph (adding backward ops)

  3. See how vmap transforms it (adding batch dimensions)

  4. Understand why tracing matters for nb.compile

[11]:
import numpy as np

import nabla as nb
from nabla.core import trace  # low-level tracing API

print("Nabla graph tracing example")
Nabla graph tracing example

1. Tracing a Simple Function#

Let’s define a small function and trace it. Tracing runs the function with symbolic inputs, recording every operation into a graph — without actually computing any values.

[12]:
def f(x, y):
    """A simple function: f(x, y) = sin(x) * y + exp(x)."""
    return nb.sin(x) * y + nb.exp(x)
[13]:
# Create concrete inputs (their shapes/dtypes matter, not values)
x = nb.Tensor.from_dlpack(np.array([1.0, 2.0], dtype=np.float32))
y = nb.Tensor.from_dlpack(np.array([3.0, 4.0], dtype=np.float32))

# Trace the function
traced_graph = trace(f, x, y)
print(traced_graph)
fn(
    %a1: f32[2],
    %a2: f32[2]
) {
  %v1: f32[2] = sin(%a1)
  %v2: f32[2] = mul(%v1, %a2)
  %v3: f32[2] = exp(%a1)
  %v4: f32[2] = add(%v2, %v3)
  return %v4
}

Reading the Trace Output#

The trace prints an IR (intermediate representation) of your function:

  • %a1, %a2 are the input arguments (your x and y)

  • %v1, %v2, … are intermediate values produced by operations

  • Each line shows: variable: type = operation(inputs)

  • The return statement shows which value is the final output

This is exactly the graph that gets compiled when you use @nb.compile.
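
As a quick sanity check (not part of the tracing API), you can run f eagerly and compare it against the same formula in plain NumPy; the numbers should line up with what the trace describes:

[ ]:
# Evaluate f eagerly and compare against a NumPy reference.
x_np = np.array([1.0, 2.0], dtype=np.float32)
y_np = np.array([3.0, 4.0], dtype=np.float32)

reference = np.sin(x_np) * y_np + np.exp(x_np)  # expected: ~[5.2427, 11.0262]
print("NumPy reference:", reference)
print("Nabla eager:    ", f(x, y))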

2. How grad Transforms the Graph#

When you call nb.grad(f), Nabla doesn’t just run backpropagation at runtime. It transforms the graph itself — adding reverse-mode differentiation operations. Let’s see what that looks like:

[14]:
grad_f = nb.grad(f, argnums=0)  # gradient w.r.t. x

traced_grad = trace(grad_f, x, y)
print(traced_grad)
fn(
    %a1: f32[2],
    %a2: f32[2]
) {
  %v1: f32[2] = ones(device=Device(type=cpu,id=0), dtype=@float32, shape=(2,))
  %v2: f32[2] = exp(%a1)
  %v3: f32[2] = mul(%v1, %v2)
  %v4: f32[2] = mul(%v1, %a2)
  %v5: f32[2] = cos(%a1)
  %v6: f32[2] = mul(%v4, %v5)
  %v7: f32[2] = add(%v3, %v6)
  return %v7
}

Notice how the graph has changed. It now contains:

  1. A seed gradient (the ones op), standing in for the gradient of the output with respect to itself

  2. Backward ops that implement the chain rule: cos (the derivative of sin), exp (its own derivative), and the muls and add that combine them

Because grad(f) returns only the gradient, forward ops whose results the gradient doesn't need (the sin itself and the original add) no longer appear. What remains computes exactly y * cos(x) + exp(x), propagating the gradient from the output back to the input x.

The key insight: grad is not magic — it’s a graph-to-graph transformation.
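
As a quick check, the chain rule for this f gives df/dx = y * cos(x) + exp(x), which is exactly what the traced ops compute. You can confirm it eagerly, using only the APIs already shown above:

[ ]:
# The traced gradient graph computes y*cos(x) + exp(x); verify it eagerly.
x_np = np.array([1.0, 2.0], dtype=np.float32)
y_np = np.array([3.0, 4.0], dtype=np.float32)

analytic = y_np * np.cos(x_np) + np.exp(x_np)
print("Analytic df/dx:", analytic)
print("nb.grad result:", grad_f(x, y))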

3. How vmap Transforms the Graph#

vmap (vectorized map) adds a batch dimension to every operation in the graph. Instead of processing one input at a time, the transformed function processes an entire batch in parallel:

[15]:
batched_f = nb.vmap(f)

# Create batched inputs: (batch=3, features=2)
x_batch = nb.Tensor.from_dlpack(
    np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
)
y_batch = nb.Tensor.from_dlpack(
    np.array([[0.5, 0.5], [1.0, 1.0], [2.0, 2.0]], dtype=np.float32)
)

traced_vmap = trace(batched_f, x_batch, y_batch)
print(traced_vmap)
fn(
    %a1: f32[3,2],
    %a2: f32[3,2]
) {
  %v1: f32[3 | 2] = incr_batch_dims(%a1)
  %v2: f32[3 | 2] = sin(%v1)
  %v3: f32[3 | 2] = incr_batch_dims(%a2)
  %v4: f32[3 | 2] = mul(%v2, %v3)
  %v5: f32[3 | 2] = exp(%v1)
  %v6: f32[3 | 2] = add(%v4, %v5)
  %v7: f32[3,2] = decr_batch_dims(%v6)
  return %v7
}

The operations are the same (sin, mul, exp, add), but every intermediate now carries a batch dimension. The trace writes it in front of the | in the type: f32[3 | 2] is a batch of 3 over a feature size of 2. The incr_batch_dims and decr_batch_dims ops mark where the inputs enter, and the output leaves, the batched representation.

vmap doesn’t write a loop — it lifts every operation to work on batches natively.
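
To see the equivalence concretely, you can compare the single batched call against an explicit Python loop over the rows, built only from the constructors used above (the printed formats differ, but the numbers should match):

[ ]:
# One vmapped call vs. an explicit loop over the batch.
x_np = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
y_np = np.array([[0.5, 0.5], [1.0, 1.0], [2.0, 2.0]], dtype=np.float32)

print(batched_f(x_batch, y_batch))  # single call, no Python loop
for i in range(3):                  # row by row, for comparison
    xi = nb.Tensor.from_dlpack(x_np[i].copy())
    yi = nb.Tensor.from_dlpack(y_np[i].copy())
    print(f(xi, yi))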

4. Composing Transforms#

The real power is composing transforms. Since each transform is just a graph-to-graph function, you can stack them:

  • grad(vmap(f)) — gradient of a batched function

  • vmap(grad(f)) — per-sample gradients

  • jacrev(f) — full Jacobian via batched VJPs

Let’s trace a composed transform:

[16]:
# Per-sample gradients: vmap over grad
per_sample_grad = nb.vmap(nb.grad(f, argnums=0))

traced_composed = trace(per_sample_grad, x_batch, y_batch)
print(traced_composed)
fn(
    %a1: f32[3,2],
    %a2: f32[3,2]
) {
  %v1: f32[3 | 2] = ones(device=Device(type=cpu,id=0), dtype=@float32, shape=(3,2))
  %v2: f32[3 | 2] = incr_batch_dims(%a1)
  %v3: f32[3 | 2] = exp(%v2)
  %v4: f32[3 | 2] = mul(%v1, %v3)
  %v5: f32[3 | 2] = incr_batch_dims(%a2)
  %v6: f32[3 | 2] = mul(%v1, %v5)
  %v7: f32[3 | 2] = cos(%v2)
  %v8: f32[3 | 2] = mul(%v6, %v7)
  %v9: f32[3 | 2] = add(%v4, %v8)
  %v10: f32[3,2] = decr_batch_dims(%v9)
  return %v10
}

This graph contains both the backward ops from grad and the batching from vmap — composed automatically. Each sample in the batch gets its own gradient.
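
Running the composed transform eagerly should give one gradient per sample, with the same (3, 2) shape as x_batch:

[ ]:
# One gradient row per sample in the batch; expected shape (3, 2).
per_sample = per_sample_grad(x_batch, y_batch)
print(per_sample)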

5. Why Tracing Matters for nb.compile#

When you decorate a function with @nb.compile, Nabla:

  1. Traces the function (just like we did above)

  2. Optimizes the graph (fusing ops, eliminating redundancy)

  3. Compiles it to run on the target hardware (CPU/GPU)

The trace is the bridge between your Python code and efficient compiled execution. Let’s verify — a compiled function produces the same results:

[17]:
compiled_f = nb.compile(f)

# Compare eager vs compiled
eager_result = f(x, y)
compiled_result = compiled_f(x, y)

print(f"Eager result:    {eager_result}")
print(f"Compiled result: {compiled_result}")
Eager result:    Tensor([ 5.2427 11.0262] : f32[2])
Compiled result: Tensor([ 5.2427 11.0262] : f32[2])
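
The same thing is usually written as a decorator. The sketch below assumes the decorator form behaves exactly like wrapping the function explicitly, as described above:

[ ]:
# Decorator form; equivalent to compiled_f = nb.compile(f) above.
@nb.compile
def f_compiled(x, y):
    return nb.sin(x) * y + nb.exp(x)

print(f_compiled(x, y))  # should match the eager and compiled results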

6. Tracing a More Realistic Function#

Let’s trace something closer to real ML — a tiny neural network layer with a loss function, then see what value_and_grad does to the graph:

[18]:
def simple_layer(params, x):
    """One linear layer + ReLU: f(x) = relu(x @ W + b)."""
    return nb.relu(x @ params["W"] + params["b"])

def loss_fn(params, x, target):
    """MSE loss on the layer output."""
    pred = simple_layer(params, x)
    return nb.mean((pred - target) ** 2)
[19]:
# Create small inputs
params = {
    "W": nb.Tensor.from_dlpack(np.random.randn(3, 2).astype(np.float32)),
    "b": nb.Tensor.from_dlpack(np.zeros(2, dtype=np.float32)),
}
x_in = nb.Tensor.from_dlpack(np.random.randn(4, 3).astype(np.float32))
target = nb.Tensor.from_dlpack(np.random.randn(4, 2).astype(np.float32))

# Trace the training step
train_step = nb.value_and_grad(loss_fn, argnums=0)
traced_train = trace(train_step, params, x_in, target)
print(traced_train)
fn(
    %a1: f32[3,2],
    %a2: f32[2],
    %a3: f32[4,3],
    %a4: f32[4,2]
) {
  %v1: f32[4,2] = matmul(%a3, %a1)
  %v2: f32[1,2] = unsqueeze(%a2, axis=0)
  %v3: f32[4,2] = broadcast_to(%v2, shape=(4,2))
  %v4: f32[4,2] = add(%v1, %v3)
  %v5: f32[4,2] = relu(%v4)
  %v6: f32[4,2] = sub(%v5, %a4)
  %v7: f32[1] = unsqueeze(?, axis=0)
  %v8: f32[1,1] = unsqueeze(%v7, axis=0)
  %v9: f32[4,2] = broadcast_to(%v8, shape=(4,2))
  %v10: f32[4,2] = pow(%v6, %v9)
  %v11: f32[4,1] = reduce_sum(%v10, axis=1, keepdims=True)
  %v12: f32[1,1] = reduce_sum(%v11, axis=0, keepdims=True)
  %v13: f32[1] = reshape(%v12, shape=(1,))
  %v14: f32[] = squeeze(%v13, axis=0)
  %v15: f32[] = div(%v14, ?)
  %v16: f32[3,4] = swap_axes(%a3, axis1=-2, axis2=-1)
  %v17: f32[1] = unsqueeze(?, axis=0)
  %v18: f32[1,1] = unsqueeze(%v17, axis=0)
  %v19: f32[4,2] = broadcast_to(%v18, shape=(4,2))
  %v20: bool[4,2] = greater(%v4, %v19)
  %v21: f32[] = ones(device=Device(type=cpu,id=0), dtype=@float32, shape=())
  %v22: f32[] = div(%v21, ?)
  %v23: f32[1] = unsqueeze(%v22, axis=0)
  %v24: f32[1,1] = reshape(%v23, shape=(1,1))
  %v25: f32[4,1] = broadcast_to(%v24, shape=(4,1))
  %v26: f32[4,2] = broadcast_to(%v25, shape=(4,2))
  %v27: f32[1] = unsqueeze(?, axis=0)
  %v28: f32[1,1] = unsqueeze(%v27, axis=0)
  %v29: f32[4,2] = broadcast_to(%v28, shape=(4,2))
  %v30: f32[4,2] = sub(%v9, %v29)
  %v31: f32[4,2] = pow(%v6, %v30)
  %v32: f32[4,2] = mul(%v9, %v31)
  %v33: f32[4,2] = mul(%v26, %v32)
  %v34: f32[4,2] = zeros(device=Device(type=cpu,id=0), dtype=@float32, shape=(4,2))
  %v35: f32[4,2] = where(%v20, %v33, %v34)
  %v36: f32[3,2] = matmul(%v16, %v35)
  %v37: f32[1,2] = reduce_sum(%v35, axis=0, keepdims=True)
  %v38: f32[2] = squeeze(%v37, axis=0)
  return (%v15, %v36, %v38)
}

You can see the full forward+backward pass for a training step. This is exactly the graph that @nb.compile would optimize and execute efficiently.

The operations flow matmul → add → relu → subtract → square → mean (forward), then the reverse chain computes gradients for W and b.
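
To close the loop, here is a minimal, hypothetical gradient-descent update built on train_step. It assumes value_and_grad returns a (loss, grads) pair whose grads mirror the structure of params (JAX-style) and that tensors can be scaled by a Python float; check the Nabla API docs before relying on either.

[ ]:
# Hypothetical SGD step (see assumptions above).
lr = 0.1
loss, grads = train_step(params, x_in, target)
params = {name: p - lr * grads[name] for name, p in params.items()}
print("loss:", loss)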

Summary#

Concept          | What it does
-----------------|------------------------------------------------------
trace(fn, *args) | Captures the computation graph without executing it
grad             | Graph → graph: adds backward (chain rule) operations
vmap             | Graph → graph: adds batch dimensions to all ops
compile          | Traces → optimizes → compiles the graph for hardware
Composition      | Transforms stack: vmap(grad(f)) works automatically

Understanding tracing helps you reason about what Nabla does under the hood, debug unexpected behavior, and write more efficient code.