Example 8: Pipeline Parallel Inference#

This example demonstrates inference-only pipeline execution (no gradients). Compared to training (Examples 6–7), inference is simpler:

  • No bias terms (to keep the example minimal)

  • No loss function or backward pass

  • Results are compared against sequential NumPy to verify correctness

This pattern is useful for serving large models that don’t fit on a single device.

[1]:
import numpy as np

import nabla as nb
from nabla import ops
from nabla.core.sharding import DeviceMesh, DimSpec, PartitionSpec as P
from nabla.ops import communication
from nabla.transforms import vmap

# Pipeline configuration
STAGES = 4
MICRO_BATCHES = 8
MICRO_BATCH_SIZE = 4
DIM = 16

print("Nabla Pipeline Inference example")
Nabla Pipeline Inference example

1. Inference Pipeline Primitives#

For inference we only need weights (no biases here). The stage_compute function applies a single linear layer followed by a ReLU at each stage:

[2]:
def stage_compute(x, w):
    return ops.relu(ops.matmul(x, w))


def pipeline_step(current_state, fresh_input, weight_stack, mask_0, step_fn, perm):
    """Single GPipe step: compute -> shift -> extract -> inject."""
    computed = step_fn(current_state, weight_stack)  # all stages compute in parallel
    shifted = communication.ppermute(computed, perm)  # stage i's output moves to stage i+1
    res_part = ops.where(mask_0, shifted, ops.zeros_like(shifted))  # last stage's output now sits at slot 0
    result = ops.reduce_sum(res_part, axis=0)  # collapse the stage axis -> one finished micro-batch
    next_state = ops.where(mask_0, fresh_input, shifted)  # inject the fresh micro-batch at stage 0
    return next_state, result


def pipeline_inference_loop(
    padded_inputs, weight_stack, current_state, mask_0, step_fn, perm, total_steps
):
    results = []
    for t in range(total_steps):
        # Slice out micro-batch t (zero padding during the drain steps) and drop the leading axis
        start_idx = (t, 0, 0)
        slice_size = (1, MICRO_BATCH_SIZE, DIM)
        fraction = ops.slice_tensor(padded_inputs, start=start_idx, size=slice_size)
        fresh = ops.squeeze(fraction, axis=0)

        current_state, res = pipeline_step(
            current_state, fresh, weight_stack, mask_0, step_fn, perm
        )
        results.append(res)

    return ops.stack(results, axis=0), current_state
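
If it helps to see the shift / extract / inject pattern without meshes or sharding, here is a minimal NumPy-only sketch of one step, using np.roll in place of ppermute; the array names are illustrative, not part of the Nabla API:

# NumPy-only sketch of one pipeline step (illustrative, not Nabla API)
import numpy as np

STAGES, MB, D = 4, 4, 16
state = np.zeros((STAGES, MB, D), np.float32)             # per-stage activations (all idle)
weights = np.random.randn(STAGES, D, D).astype(np.float32)
fresh = np.random.randn(MB, D).astype(np.float32)         # next micro-batch to inject
mask_0 = (np.arange(STAGES) == 0).reshape(STAGES, 1, 1)   # True only at stage 0

computed = np.maximum(state @ weights, 0)                 # each stage: ReLU(x @ W)
shifted = np.roll(computed, 1, axis=0)                    # stage i -> slot i+1, last stage -> slot 0
result = np.where(mask_0, shifted, 0.0).sum(axis=0)       # slot 0 holds the last stage's finished output
next_state = np.where(mask_0, fresh, shifted)             # overwrite slot 0 with the fresh micro-batch

Because the permutation sends stage i's output to stage i+1 (with the last stage wrapping around to slot 0), a single stage-0 mask is enough both to read finished results and to inject new work.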

2. Shard Weights and Prepare Inputs#

Same pattern as Example 6 — shard the weight stack across pipeline stages:

[3]:
mesh = DeviceMesh("pp", (STAGES,), ("stage",))
np.random.seed(42)

w_np = np.random.randn(STAGES, DIM, DIM).astype(np.float32)
x_np = np.random.randn(MICRO_BATCHES, MICRO_BATCH_SIZE, DIM).astype(np.float32)

# Shard weights across stages
w_spec = [DimSpec.from_raw(d) for d in P("stage", None, None)]
w_sharded = ops.shard(nb.Tensor.from_dlpack(w_np), mesh, w_spec).realize()

# Append STAGES dummy micro-batches so the pipeline has inputs while it drains
padding = np.zeros((STAGES, MICRO_BATCH_SIZE, DIM), dtype=np.float32)
x_padded = nb.Tensor.from_dlpack(np.concatenate([x_np, padding], axis=0))

# Initial state and stage-0 mask
state_sharded = ops.shard(
    nb.Tensor.from_dlpack(np.zeros((STAGES, MICRO_BATCH_SIZE, DIM), dtype=np.float32)),
    mesh, w_spec
).realize()

mask_np = np.eye(STAGES, 1).reshape(STAGES, 1, 1).astype(bool)  # True only at stage 0
mask_0 = ops.shard(nb.Tensor.from_dlpack(mask_np), mesh, w_spec).realize()

print(f"Mesh: {mesh}")
print(f"Weights: {w_sharded.shape}, Inputs: {x_padded.shape}")
Mesh: @pp = <["stage"=4]>
Weights: [Dim(4), Dim(16), Dim(16)], Inputs: [Dim(12), Dim(4), Dim(16)]
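
As a quick sanity check on the two less obvious prep steps: the padded input simply appends STAGES zero micro-batches (consumed while the pipeline drains), and np.eye(STAGES, 1) is just a column vector with a 1 in the first row, so the mask marks stage 0 only. A plain-NumPy restatement (shapes only, no sharding):

# Illustrative check of the padded-input and mask construction
import numpy as np

STAGES, MICRO_BATCHES, MICRO_BATCH_SIZE, DIM = 4, 8, 4, 16
x = np.random.randn(MICRO_BATCHES, MICRO_BATCH_SIZE, DIM).astype(np.float32)
pad = np.zeros((STAGES, MICRO_BATCH_SIZE, DIM), np.float32)
x_padded = np.concatenate([x, pad], axis=0)
assert x_padded.shape == (MICRO_BATCHES + STAGES, MICRO_BATCH_SIZE, DIM)  # (12, 4, 16)

mask = np.eye(STAGES, 1).reshape(STAGES, 1, 1).astype(bool)
assert mask.ravel().tolist() == [True, False, False, False]  # marks stage 0 only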

3. Run Inference Pipeline#

Set up the ppermute permutation, vectorize across stages with vmap, and run the full pipeline:

[4]:
# Communication setup
idx = mesh.axis_names.index("stage")
size = mesh.shape[idx]
perm = [(i, (i + 1) % size) for i in range(size)]

# Vectorize stage_compute over the stage axis
step_fn = vmap(
    stage_compute, in_axes=(0, 0), out_axes=0, spmd_axis_name="stage", mesh=mesh
)

# Run the full inference pipeline
total_steps = MICRO_BATCHES + STAGES  # MICRO_BATCHES results + STAGES extra steps to flush the pipeline
results, _ = pipeline_inference_loop(
    x_padded, w_sharded, state_sharded, mask_0, step_fn, perm, total_steps
)

# Extract valid predictions (skip warm-up ticks)
preds = results[STAGES : STAGES + MICRO_BATCHES]
preds_np = preds.to_numpy()
print(f"Predictions shape: {preds_np.shape}")
print(f"Output range: [{preds_np.min():.4f}, {preds_np.max():.4f}]")
Predictions shape: (8, 4, 16)
Output range: [0.0000, 417.7554]
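
The slicing above follows from the pipeline schedule: micro-batch m is injected at step m, advances one stage per step, and its finished output appears at step m + STAGES, so the first STAGES results are warm-up garbage and the last valid result lands at step MICRO_BATCHES + STAGES - 1. A small plain-Python sketch of that timeline (illustrative only, nothing here touches Nabla):

# Illustrative timeline: which step emits which micro-batch's result
STAGES, MICRO_BATCHES = 4, 8
for t in range(MICRO_BATCHES + STAGES):
    m = t - STAGES  # micro-batch finishing at step t
    print(f"step {t:2d}: {'micro-batch ' + str(m) if m >= 0 else 'warm-up (discarded)'}")

This is exactly why total_steps = MICRO_BATCHES + STAGES and why only results[STAGES : STAGES + MICRO_BATCHES] are kept.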

4. Verify Against Sequential NumPy#

Run the same computation sequentially in NumPy to confirm the pipeline produces identical results:

[5]:
# Sequential NumPy reference
ref_outs = []
for i in range(MICRO_BATCHES):
    act = x_np[i]
    for s in range(STAGES):
        act = np.maximum(act @ w_np[s], 0)  # ReLU(x @ W)
    ref_outs.append(act)
ref = np.stack(ref_outs)

diff = np.max(np.abs(preds_np - ref))
print(f"Max difference vs NumPy reference: {diff:.6f}")
assert diff < 1e-4, f"Mismatch: {diff}"
print("✅ Pipeline inference matches sequential computation")
Max difference vs NumPy reference: 0.000000
✅ Pipeline inference matches sequential computation

Key takeaways:

  • Pipeline inference uses the same ppermute + vmap pattern as training

  • Without gradients, we simply call the loop directly — no nb.grad needed

  • The pipeline produces numerically identical results to sequential execution