Example 6: Pipeline Parallelism (GPipe)#

Pipeline parallelism splits a model into sequential stages across devices. Rather than waiting for one batch to traverse every stage before the next begins, we split the batch into micro-batches and overlap computation: new micro-batches enter the pipeline while earlier ones are still being processed at later stages.

Time →  t0    t1    t2    t3    t4    t5    t6
Stage 0  [x0]  [x1]  [x2]  [x3]   ·     ·     ·
Stage 1   ·    [x0]  [x1]  [x2]  [x3]   ·     ·
Stage 2   ·     ·    [x0]  [x1]  [x2]  [x3]   ·
Stage 3   ·     ·     ·    [x0]  [x1]  [x2]  [x3]
                                  ↑ results start emerging
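
This schedule can be reproduced with a few lines of plain Python. The sketch below is illustrative only (it does not use Nabla) and assumes the 4 micro-batches drawn above; micro-batch m reaches stage s at tick m + s:

# Illustrative only: reproduce the GPipe schedule drawn above
n_stages, n_micro = 4, 4   # 4 micro-batches, as in the diagram
for s in range(n_stages):
    cells = []
    for t in range(n_stages + n_micro - 1):
        m = t - s  # micro-batch index occupying stage s at tick t
        cells.append(f"[x{m}]" if 0 <= m < n_micro else "  ·  ")
    print(f"Stage {s}  " + " ".join(cells))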

In this example we’ll:

  1. Shard an MLP across 4 pipeline stages

  2. Use ppermute for stage-to-stage communication

  3. Compute gradients through the full pipeline

[10]:
import numpy as np

import nabla as nb
from nabla import ops
from nabla.core.sharding import DeviceMesh, DimSpec, PartitionSpec as P
from nabla.ops import communication
from nabla.transforms import vmap

print("Nabla Pipeline Parallelism example")
Nabla Pipeline Parallelism example

1. Setup: Model and Device Mesh#

We’ll use a simple 4-layer MLP, with each layer on a separate pipeline stage. A DeviceMesh defines the logical layout of devices:

[11]:
# Pipeline configuration
STAGES = 4           # Number of pipeline stages (= devices)
MICRO_BATCHES = 8    # Number of micro-batches to stream through
MICRO_BATCH_SIZE = 4 # Samples per micro-batch
DIM = 16             # Hidden dimension

# Create a 1D device mesh for pipeline parallelism
mesh = DeviceMesh("pp", (STAGES,), ("stage",))
print(f"Device mesh: {STAGES} stages")
print(f"Pipeline: {MICRO_BATCHES} micro-batches × {MICRO_BATCH_SIZE} samples")
Device mesh: 4 stages
Pipeline: 8 micro-batches × 4 samples

2. Pipeline Primitives#

Each pipeline stage applies one linear layer followed by ReLU. We define three small helpers:

Function          Purpose
stage_compute     Apply one stage’s weight + bias → ReLU
pipeline_step     One tick: compute → shift → inject next micro-batch
pipeline_loop     Iterate steps, collecting outputs

[12]:
def stage_compute(x, w, b):
    """One pipeline stage: linear + ReLU."""
    return ops.relu(ops.matmul(x, w) + b)

pipeline_step is the core of GPipe: after computing all stages in parallel, ppermute shifts outputs one stage forward (stage 0→1, 1→2, …, with the last stage wrapping around to stage 0). The final stage’s result, which lands on stage 0 after the wrap-around, is extracted via a mask, and the fresh micro-batch is injected into stage 0:

[13]:
def pipeline_step(
    current_state, fresh_input, weight_stack, bias_stack, mask_0, step_fn, perm
):
    """Single GPipe step: compute -> shift -> extract result -> inject input."""
    computed = step_fn(current_state, weight_stack, bias_stack)
    shifted = communication.ppermute(computed, perm)
    # Extract the final stage's output (mask selects stage 0 after the shift)
    res_part = ops.where(mask_0, shifted, ops.zeros_like(shifted))
    result = ops.reduce_sum(res_part, axis=0)
    # Inject the fresh micro-batch at stage 0, pass shifted activations elsewhere
    next_state = ops.where(mask_0, fresh_input, shifted)
    return next_state, result

The pipeline loop feeds MICRO_BATCHES + STAGES ticks through the pipeline. During the first STAGES ticks the pipeline is “filling up”; the first complete result emerges at tick STAGES, which is why the loss function later discards the first STAGES outputs:

[14]:
def pipeline_loop(
    padded_inputs, weight_stack, bias_stack, current_state, mask_0,
    step_fn, perm, total_steps,
):
    """Stream micro-batches through the pipeline for `total_steps` ticks."""
    results = []
    for t in range(total_steps):
        start_idx = (t, 0, 0)
        slice_size = (1, MICRO_BATCH_SIZE, DIM)
        fraction = ops.slice_tensor(padded_inputs, start=start_idx, size=slice_size)
        fresh = ops.squeeze(fraction, axis=0)

        current_state, res = pipeline_step(
            current_state, fresh, weight_stack, bias_stack, mask_0, step_fn, perm
        )
        results.append(res)

    return ops.stack(results, axis=0), current_state

3. Shard Data Across Stages#

Each weight matrix lives on one stage. We use ops.shard with a PartitionSpec to place the first dimension on the "stage" mesh axis. We also need zero-padded inputs (the final STAGES ticks only drain the pipeline, so they consume dummy zero slices) and a boolean mask that selects stage 0:

[15]:
np.random.seed(42)

# Random weights (one per stage), inputs, and targets
w_np = np.random.randn(STAGES, DIM, DIM).astype(np.float32)
b_np = np.random.randn(STAGES, DIM).astype(np.float32)
x_np = np.random.randn(MICRO_BATCHES, MICRO_BATCH_SIZE, DIM).astype(np.float32)
y_np = np.random.randn(MICRO_BATCHES, MICRO_BATCH_SIZE, DIM).astype(np.float32)

# Shard weights: first axis → "stage" mesh axis
w_spec = [DimSpec.from_raw(d) for d in P("stage", None, None)]
b_spec = [DimSpec.from_raw(d) for d in P("stage", None)]

w_sharded = ops.shard(nb.Tensor.from_dlpack(w_np), mesh, w_spec)
b_sharded = ops.shard(nb.Tensor.from_dlpack(b_np), mesh, b_spec)

# Pad inputs with STAGES zero-slices for pipeline warm-up
padding = np.zeros((STAGES, MICRO_BATCH_SIZE, DIM), dtype=np.float32)
x_padded = nb.Tensor.from_dlpack(np.concatenate([x_np, padding], axis=0))
y_nb = nb.Tensor.from_dlpack(y_np)

# Initial pipeline state: zeros on each stage
state_sharded = ops.shard(
    nb.zeros((STAGES, MICRO_BATCH_SIZE, DIM)), mesh, w_spec
)

# Stage-0 mask for injecting fresh inputs
mask_np = np.eye(STAGES, 1).reshape(STAGES, 1, 1).astype(bool)
mask_0 = ops.shard(nb.Tensor.from_dlpack(mask_np), mesh, w_spec)

nb.realize_all(w_sharded, b_sharded, state_sharded, mask_0)
print(f"Weights sharded: {w_sharded.shape}, Inputs padded: {x_padded.shape}")
Weights sharded: [Dim(4), Dim(16), Dim(16)], Inputs padded: [Dim(12), Dim(4), Dim(16)]

4. Communication & Vectorized Stages#

ppermute shifts tensors between devices according to a permutation list. For a 4-stage pipeline, stage i sends its output to stage i+1 (with wrap-around):

perm = [(0,1), (1,2), (2,3), (3,0)]

We then use vmap with spmd_axis_name="stage" to auto-vectorize stage_compute over the stage dimension — each stage computes with its own weight/bias slice:

[16]:
# Build the circular permutation for ppermute
idx = mesh.axis_names.index("stage")
size = mesh.shape[idx]
perm = [(i, (i + 1) % size) for i in range(size)]
print(f"ppermute permutation: {perm}")

# Vectorize stage_compute over the stage axis
step_fn = vmap(
    stage_compute,
    in_axes=(0, 0, 0),
    out_axes=0,
    spmd_axis_name="stage",
    mesh=mesh,
)
print("step_fn ready — each stage runs its own weights in parallel")
ppermute permutation: [(0, 1), (1, 2), (2, 3), (3, 0)]
step_fn ready — each stage runs its own weights in parallel

5. Define the Pipeline Loss#

The loss function runs the full pipeline loop, slices out the valid outputs (the first STAGES ticks produce incomplete results), and computes MSE against the targets:

[17]:
def pipeline_loss(inputs, weights, biases, state, mask, targets):
    """MSE loss through the full GPipe pipeline."""
    total_steps = MICRO_BATCHES + STAGES
    stream_outputs, _ = pipeline_loop(
        inputs, weights, biases, state, mask, step_fn, perm, total_steps
    )

    # Slice valid range [STAGES : STAGES + MICRO_BATCHES]
    indices = ops.arange(STAGES, STAGES + MICRO_BATCHES, dtype=nb.DType.int64)
    valid_preds = ops.gather(stream_outputs, indices, axis=0)

    # MSE loss
    diff = valid_preds - targets
    return ops.mean(diff * diff)
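
Before differentiating, the loss can be evaluated on its own as a quick sanity check that the loop wiring is correct. The call below is a minimal forward-only sketch that reuses the tensors and step_fn from the earlier cells; it assumes nb.realize_all also accepts a single tensor, following the same realize-then-print pattern as above:

# Forward-only check (illustrative): build and realize the scalar loss
loss_val = pipeline_loss(x_padded, w_sharded, b_sharded, state_sharded, mask_0, y_nb)
nb.realize_all(loss_val)
print("Pipeline MSE loss:", loss_val)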

6. Compute Gradients Through the Pipeline#

nb.grad differentiates through the entire pipeline — including ppermute shifts and per-stage vmap — computing gradients for inputs, weights, and biases simultaneously:

[18]:
grad_fn = nb.grad(pipeline_loss, argnums=(0, 1, 2), realize=False)

x_grad, w_grad, b_grad = grad_fn(
    x_padded, w_sharded, b_sharded, state_sharded, mask_0, y_nb
)

# Materialize results as numpy arrays
x_grad_np, w_grad_np, b_grad_np = nb.Tensor.to_numpy_all(x_grad, w_grad, b_grad)
x_grad_np = x_grad_np[:MICRO_BATCHES]  # trim padding region

print(f"Input gradient shape:  {x_grad_np.shape}")
print(f"Weight gradient shape: {w_grad_np.shape}")
print(f"Bias gradient shape:   {b_grad_np.shape}")
print(f"Weight grad range:     [{w_grad_np.min():.4f}, {w_grad_np.max():.4f}]")
Input gradient shape:  (8, 4, 16)
Weight gradient shape: (4, 16, 16)
Bias gradient shape:   (4, 16)
Weight grad range:     [-294.8232, 1527.5933]
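
With the gradients materialized as NumPy arrays, applying them is an ordinary optimizer step. The sketch below is a hypothetical single SGD update on host-side copies (w_updated, b_updated, and lr are names introduced here, not part of the example); it deliberately leaves w_np and b_np untouched so the reference check in the next section still compares against the original weights:

# Hypothetical SGD step on host-side copies; w_np / b_np stay unchanged
lr = 1e-3  # arbitrary learning rate for illustration
w_updated = w_np - lr * w_grad_np
b_updated = b_np - lr * b_grad_np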

7. Verify Against JAX Reference#

As a sanity check, we run the same computation sequentially in JAX and compare gradients. The pipeline scheduling should not change the mathematical result — only how computation is distributed:

[19]:
try:
    import jax
    import jax.numpy as jnp
    jax.config.update("jax_enable_x64", False)

    def jax_ref(x, params_w, params_b, y):
        def apply(curr, w, b):
            return jax.nn.relu(curr @ w + b)

        preds = []
        for i in range(MICRO_BATCHES):
            a = x[i]
            for w, b in zip(params_w, params_b, strict=False):
                a = apply(a, w, b)
            preds.append(a)
        preds = jnp.stack(preds)
        return jnp.mean((preds - y) ** 2)

    grad_ref_fn = jax.jit(jax.grad(jax_ref, argnums=(0, 1, 2)))
    x_grad_ref, w_grad_ref, b_grad_ref = grad_ref_fn(x_np, w_np, b_np, y_np)

    x_diff = np.max(np.abs(x_grad_np - x_grad_ref))
    w_diff = np.max(np.abs(w_grad_np - w_grad_ref))
    b_diff = np.max(np.abs(b_grad_np - b_grad_ref))

    print(f"Max input grad diff:  {x_diff:.6f}")
    print(f"Max weight grad diff: {w_diff:.6f}")
    print(f"Max bias grad diff:   {b_diff:.6f}")
    assert w_diff < 5e-4 and b_diff < 5e-4 and x_diff < 5e-4, "Gradient mismatch!"
    print("✅ Nabla pipeline gradients match JAX sequential reference")

except ImportError:
    print("JAX not installed — skipping reference comparison")
Max input grad diff:  0.000061
Max weight grad diff: 0.000122
Max bias grad diff:   0.000015
✅ Nabla pipeline gradients match JAX sequential reference

Key takeaways:

  • DeviceMesh + PartitionSpec place tensors on specific stages

  • ppermute handles inter-stage communication without explicit send/recv

  • vmap with spmd_axis_name vectorizes computation across stages

  • nb.grad differentiates through the entire sharded pipeline