Example 8: Pipeline Parallel Inference
This example demonstrates inference-only pipeline execution (forward pass only, no gradients). Compared to training (Examples 6–7), this example differs as follows:
No loss function or backward pass, since only the forward computation is needed
No bias terms (omitted to keep the example minimal)
Results are compared against a sequential NumPy computation to verify correctness
This pattern is useful for serving models that are too large for a single device, because each device stores only the weights of its own stage.
import numpy as np
import nabla as nb
from nabla import ops
from nabla.core.sharding import DeviceMesh, DimSpec, PartitionSpec as P
from nabla.ops import communication
from nabla.transforms import vmap
# Pipeline configuration
STAGES = 4
MICRO_BATCHES = 8
MICRO_BATCH_SIZE = 4
DIM = 16
print("Nabla Pipeline Inference example")
Nabla Pipeline Inference example
1. Inference Pipeline Primitives
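For inference, each stage needs only its weight matrix (this example uses no biases). stage_compute applies one linear layer followed by ReLU. pipeline_step advances every stage by one tick. pipeline_inference_loop runs that step once per padded micro-batch: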
def stage_compute(x, w):
    # One pipeline stage: ReLU(x @ W)
    return ops.relu(ops.matmul(x, w))


def pipeline_step(current_state, fresh_input, weight_stack, mask_0, step_fn, perm):
    """Single GPipe step: compute -> shift -> extract -> inject."""
    # compute: every stage applies its own layer to the activation it holds
    computed = step_fn(current_state, weight_stack)
    # shift: send each stage's output to the next stage (the last stage wraps to stage 0)
    shifted = communication.ppermute(computed, perm)
    # extract: the last stage's output, now sitting at stage 0, is this tick's result
    res_part = ops.where(mask_0, shifted, ops.zeros_like(shifted))
    result = ops.reduce_sum(res_part, axis=0)
    # inject: stage 0 takes the next micro-batch; other stages keep their shifted input
    next_state = ops.where(mask_0, fresh_input, shifted)
    return next_state, result


def pipeline_inference_loop(
    padded_inputs, weight_stack, current_state, mask_0, step_fn, perm, total_steps
):
    results = []
    for t in range(total_steps):
        # Slice micro-batch t, shape (1, MICRO_BATCH_SIZE, DIM), then drop the leading axis
        start_idx = (t, 0, 0)
        slice_size = (1, MICRO_BATCH_SIZE, DIM)
        fraction = ops.slice_tensor(padded_inputs, start=start_idx, size=slice_size)
        fresh = ops.squeeze(fraction, axis=0)
        current_state, res = pipeline_step(
            current_state, fresh, weight_stack, mask_0, step_fn, perm
        )
        results.append(res)
    return ops.stack(results, axis=0), current_state
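The compute, shift, extract, inject step can be emulated in a single process with plain NumPy, which helps make it concrete. This is a sketch under stated assumptions, not Nabla code:

- np.roll along the stage axis stands in for communication.ppermute with the ring permutation [(i, (i + 1) % n)].
- A batched matmul stands in for the vmap'd stage_compute.
- ppermute_ring, pipeline_step_np and run_pipeline_np are hypothetical names used only for illustration.

```python
import numpy as np

STAGES, MICRO_BATCHES, MICRO_BATCH_SIZE, DIM = 4, 8, 4, 16


def ppermute_ring(x):
    # Emulates ppermute with perm [(i, (i + 1) % n)]: slot i moves to slot i + 1,
    # and the last stage's block wraps around to slot 0.
    return np.roll(x, shift=1, axis=0)


def pipeline_step_np(state, fresh, w, mask_0):
    # compute: (S, B, D) @ (S, D, D) applies stage s's weights to slot s only
    computed = np.maximum(state @ w, 0)
    # shift: hand every activation to the next stage
    shifted = ppermute_ring(computed)
    # extract: slot 0 now holds the last stage's output; summing over stages picks it out
    result = np.where(mask_0, shifted, np.zeros_like(shifted)).sum(axis=0)
    # inject: stage 0 takes the fresh micro-batch, the other stages keep the shifted data
    next_state = np.where(mask_0, fresh, shifted)
    return next_state, result


def run_pipeline_np(x, w):
    stages = w.shape[0]
    m, b, d = x.shape
    # Append `stages` zero micro-batches so the loop can keep injecting while draining
    padded = np.concatenate([x, np.zeros((stages, b, d), x.dtype)], axis=0)
    state = np.zeros((stages, b, d), x.dtype)
    mask_0 = (np.arange(stages) == 0).reshape(stages, 1, 1)
    results = []
    for t in range(m + stages):
        state, res = pipeline_step_np(state, padded[t], w, mask_0)
        results.append(res)
    # Micro-batch k emerges at tick k + stages, so skip the first `stages` ticks
    return np.stack(results)[stages : stages + m]


rng = np.random.default_rng(0)
w = rng.standard_normal((STAGES, DIM, DIM)).astype(np.float32)
x = rng.standard_normal((MICRO_BATCHES, MICRO_BATCH_SIZE, DIM)).astype(np.float32)

out = run_pipeline_np(x, w)

# Sequential reference: every micro-batch passes through all stages in order
ref = x
for s in range(STAGES):
    ref = np.maximum(ref @ w[s], 0)

print(out.shape, np.allclose(out, ref, rtol=1e-5, atol=1e-5))
```

Because np.roll and the stage mask play the roles of ppermute and mask_0, the slicing rule results[STAGES : STAGES + MICRO_BATCHES] falls out of the emulation directly.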
2. Shard Weights and Prepare Inputs
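As in Example 6, shard the weight stack across pipeline stages so that each device holds one stage's weights. Then pad the inputs and build the initial state and the stage-0 mask: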
mesh = DeviceMesh("pp", (STAGES,), ("stage",))
np.random.seed(42)
w_np = np.random.randn(STAGES, DIM, DIM).astype(np.float32)
x_np = np.random.randn(MICRO_BATCHES, MICRO_BATCH_SIZE, DIM).astype(np.float32)
# Shard weights across stages
w_spec = [DimSpec.from_raw(d) for d in P("stage", None, None)]
w_sharded = ops.shard(nb.Tensor.from_dlpack(w_np), mesh, w_spec).realize()
# Append STAGES zero micro-batches so the loop keeps injecting inputs while the pipeline drains
padding = np.zeros((STAGES, MICRO_BATCH_SIZE, DIM), dtype=np.float32)
x_padded = nb.Tensor.from_dlpack(np.concatenate([x_np, padding], axis=0))
# Zero initial state (one slot per stage) and the stage-0 mask
state_sharded = ops.shard(
    nb.Tensor.from_dlpack(np.zeros((STAGES, MICRO_BATCH_SIZE, DIM), dtype=np.float32)),
    mesh,
    w_spec,
).realize()
mask_np = np.eye(STAGES, 1).reshape(STAGES, 1, 1).astype(bool)
mask_0 = ops.shard(nb.Tensor.from_dlpack(mask_np), mesh, w_spec).realize()
print(f"Mesh: {mesh}")
print(f"Weights: {w_sharded.shape}, Inputs: {x_padded.shape}")
Mesh: @pp = <["stage"=4]>
Weights: [Dim(4), Dim(16), Dim(16)], Inputs: [Dim(12), Dim(4), Dim(16)]
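The layout produced above can be pictured with plain NumPy. This is only a mental model, and it rests on one assumption: that P("stage", None, None) on a 4-device mesh splits axis 0 evenly, so device i holds w[i:i+1]. np.split mimics that layout; it does not reproduce Nabla's internals. The same sketch checks the padded input shape and the stage-0 mask built by the cell above.

```python
import numpy as np

STAGES, MICRO_BATCHES, MICRO_BATCH_SIZE, DIM = 4, 8, 4, 16

w = np.zeros((STAGES, DIM, DIM), dtype=np.float32)
# Assumed layout: splitting axis 0 into STAGES equal pieces mirrors P("stage", None, None),
# so each device ends up with exactly one stage's (DIM, DIM) weight matrix.
local_shards = np.split(w, STAGES, axis=0)

x = np.zeros((MICRO_BATCHES, MICRO_BATCH_SIZE, DIM), dtype=np.float32)
# STAGES extra zero micro-batches keep the loop fed while the pipeline drains
padding = np.zeros((STAGES, MICRO_BATCH_SIZE, DIM), dtype=np.float32)
x_padded = np.concatenate([x, padding], axis=0)

# Same construction as the example: True only in the stage-0 slot
mask_np = np.eye(STAGES, 1).reshape(STAGES, 1, 1).astype(bool)

print([s.shape for s in local_shards])  # four (1, 16, 16) blocks
print(x_padded.shape)                   # (12, 4, 16)
print(mask_np.ravel())                  # [ True False False False]
```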
3. Run Inference Pipeline
Set up the ppermute ring permutation (stage i sends to stage (i + 1) % STAGES). Then vectorize stage_compute across stages with vmap and run the full pipeline:
# Communication setup
idx = mesh.axis_names.index("stage")
size = mesh.shape[idx]
perm = [(i, (i + 1) % size) for i in range(size)]
# Vectorize stage_compute over the stage axis
step_fn = vmap(
    stage_compute, in_axes=(0, 0), out_axes=0, spmd_axis_name="stage", mesh=mesh
)
# Run the full inference pipeline: one tick per micro-batch plus STAGES ticks to drain
total_steps = MICRO_BATCHES + STAGES
results, _ = pipeline_inference_loop(
    x_padded, w_sharded, state_sharded, mask_0, step_fn, perm, total_steps
)
# Extract valid predictions: the first STAGES ticks only emit the zero initial state
preds = results[STAGES : STAGES + MICRO_BATCHES]
preds_np = preds.to_numpy()
print(f"Predictions shape: {preds_np.shape}")
print(f"Output range: [{preds_np.min():.4f}, {preds_np.max():.4f}]")
Predictions shape: (8, 4, 16)
Output range: [0.0000, 417.7554]
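Why are the valid predictions exactly results[STAGES : STAGES + MICRO_BATCHES]? The pure-Python sketch below answers this by replaying the same compute, shift, extract, inject schedule. It tracks micro-batch ids instead of activations. emitted_schedule is a hypothetical helper written for illustration; it is not part of Nabla. The sketch also counts how many stage computations touch real data.

```python
STAGES, MICRO_BATCHES = 4, 8


def emitted_schedule(stages, micro_batches):
    """Return the micro-batch id extracted at each tick (None = zero/padding data)."""
    slots = [None] * stages  # slots[s]: id of the micro-batch held by stage s
    emitted = []
    for t in range(micro_batches + stages):
        # compute leaves ids in place; the ring ppermute moves slot s to slot s + 1
        shifted = [slots[-1]] + slots[:-1]
        # extract: the last stage's output has wrapped around to slot 0
        emitted.append(shifted[0])
        # inject: stage 0 receives micro-batch t, or zero padding once inputs run out
        shifted[0] = t if t < micro_batches else None
        slots = shifted
    return emitted


schedule = emitted_schedule(STAGES, MICRO_BATCHES)
print(schedule)  # [None, None, None, None, 0, 1, 2, 3, 4, 5, 6, 7]

# Every tick runs all STAGES stages, but only MICRO_BATCHES * STAGES of them see real data
useful = MICRO_BATCHES * STAGES
total = (MICRO_BATCHES + STAGES) * STAGES
print(f"useful stage computations: {useful}/{total} = {useful / total:.3f}")  # 32/48 = 0.667
```

In this loop the useful fraction is MICRO_BATCHES / (MICRO_BATCHES + STAGES). Feeding more micro-batches per run therefore shrinks the relative cost of filling and draining the pipeline.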
4. Verify Against Sequential NumPy
Run the same computation sequentially in NumPy, passing each micro-batch through every stage in order. Then confirm that the pipelined outputs match within an absolute tolerance of 1e-4:
# Sequential NumPy reference
ref_outs = []
for i in range(MICRO_BATCHES):
    act = x_np[i]
    for s in range(STAGES):
        act = np.maximum(act @ w_np[s], 0)  # ReLU(x @ W)
    ref_outs.append(act)
ref = np.stack(ref_outs)
diff = np.max(np.abs(preds_np - ref))
print(f"Max difference vs NumPy reference: {diff:.6f}")
assert diff < 1e-4, f"Mismatch: {diff}"
print("✅ Pipeline inference matches sequential computation")
Max difference vs NumPy reference: 0.000000
✅ Pipeline inference matches sequential computation
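The assertion above uses a fixed absolute tolerance of 1e-4. The outputs reach about 418, where neighboring float32 values are 2^-15 (about 3.05e-5) apart, so that bound allows only a few units in the last place. Suppose a different backend or reduction order made results drift by a rounding step or two. In that case a magnitude-aware check such as np.testing.assert_allclose would be more robust.

Here is a minimal sketch with made-up values. The one-ulp perturbation is an assumption standing in for such drift; these are not real pipeline outputs.

```python
import numpy as np

# Illustrative values spanning the printed output range (not real pipeline outputs)
ref = np.array([0.0, 1.5, 417.7554], dtype=np.float32)
# Simulate a pipeline that rounded differently: move each value up by one float32 ulp
preds = np.nextafter(ref, np.float32(np.inf))

# Gap between adjacent float32 values near the largest output (2**-15)
print(np.spacing(np.float32(417.7554)))

# rtol scales with magnitude; the small atol covers exact zeros produced by ReLU
np.testing.assert_allclose(preds, ref, rtol=1e-5, atol=1e-6)
print("within tolerance")
```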
Key takeaways:
Pipeline inference uses the same ppermute + vmap pattern as training.
Without gradients, the loop is called directly; no nb.grad is needed.
The pipelined outputs match sequential execution (maximum absolute difference printed as 0.000000).