Example 7: 2D Parallel Training (Pipeline + Data Parallelism)#
This notebook extends pipeline parallelism (Example 6) by adding a data-parallel dimension, creating a 2D device mesh:
                 Pipeline stages →
                 Stage 0   Stage 1   Stage 2   Stage 3
Data    DP 0      [w0]      [w1]      [w2]      [w3]     ← same weights, different data
Par.    DP 1      [w0]      [w1]      [w2]      [w3]     ← same weights, different data
Key idea: Weights are sharded across pipeline stages and replicated across data-parallel replicas. Input batches are sharded across DP replicas.
We’ll:

- Build a 2D `DeviceMesh` with named axes `"dp"` and `"pp"`
- Shard weights on `"pp"` and data on `"dp"`
- Use the same pipeline primitives from Example 6
- Compute gradients with `nb.value_and_grad`
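As a mental model for this layout, here is a plain-NumPy sketch of what the device at coordinate (dp, pp) holds. It is illustrative only: the shapes and the names `w_local`/`x_local` are simplified stand-ins, not part of the notebook's cells.

import numpy as np

# Illustrative only: per-device shards on a (DP_SIZE, PP_SIZE) = (2, 4) mesh.
DP_SIZE, PP_SIZE, DIM, BATCH = 2, 4, 16, 8
w = np.random.randn(PP_SIZE, DIM, DIM)   # stack of per-stage weights
x = np.random.randn(BATCH, DIM)          # one global data batch

for dp in range(DP_SIZE):
    for pp in range(PP_SIZE):
        w_local = w[pp]                            # sharded on "pp", replicated on "dp"
        x_local = np.array_split(x, DP_SIZE)[dp]   # sharded on "dp", replicated on "pp"
        print(f"device ({dp}, {pp}): weights {w_local.shape}, data {x_local.shape}")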
[7]:
import numpy as np
import nabla as nb
from nabla import ops
from nabla.core.sharding import DeviceMesh, DimSpec
from nabla.ops import communication
from nabla.transforms import vmap
print("Nabla 2D Parallelism example")
Nabla 2D Parallelism example
1. Configuration and Device Mesh#
The 2D mesh has shape (DP_SIZE, PP_SIZE) with named axes "dp" and "pp". Each device is identified by a (dp_rank, pp_rank) pair:
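For intuition, assuming the row-major device ordering used later in Section 4 (flat device id = dp * PP_SIZE + pp), the flat ids map to coordinates like this:

# Assumed row-major ordering: flat device id = dp * PP_SIZE + pp
coords = {d: divmod(d, 4) for d in range(2 * 4)}
# {0: (0, 0), 1: (0, 1), 2: (0, 2), 3: (0, 3), 4: (1, 0), 5: (1, 1), 6: (1, 2), 7: (1, 3)}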
[8]:
# 2D mesh dimensions
DP_SIZE = 2 # Data-parallel replicas
PP_SIZE = 4 # Pipeline stages
MICRO_BATCHES = 4
MICRO_BATCH_SIZE = 4
DIM = 16
mesh = DeviceMesh("2d", (DP_SIZE, PP_SIZE), ("dp", "pp"))
print(f"2D device mesh: {DP_SIZE} DP replicas × {PP_SIZE} PP stages = {DP_SIZE * PP_SIZE} devices")
2D device mesh: 2 DP replicas × 4 PP stages = 8 devices
2. Pipeline Primitives (same as Example 6)#
The stage compute, step, and loop functions are identical to Example 6. Only the sharding specification changes — the mesh now has two axes:
[9]:
def stage_compute(x, w, b):
    """One pipeline stage: linear + ReLU."""
    return ops.relu(ops.matmul(x, w) + b)


def pipeline_step(current_state, fresh_input, weight_stack, bias_stack, mask_0, step_fn, perm):
    """Compute → shift → extract result → inject fresh input."""
    computed = step_fn(current_state, weight_stack, bias_stack)
    shifted = communication.ppermute(computed, perm)
    res_part = ops.where(mask_0, shifted, ops.zeros_like(shifted))
    result = ops.reduce_sum(res_part, axis=0)
    next_state = ops.where(mask_0, fresh_input, shifted)
    return next_state, result


def pipeline_loop(padded_inputs, weight_stack, bias_stack, current_state, mask_0, step_fn, perm, total_steps):
    """Stream micro-batches through the pipeline."""
    results = []
    for t in range(total_steps):
        start_idx = (t, 0, 0)
        slice_size = (1, MICRO_BATCH_SIZE, DIM)
        fraction = ops.slice_tensor(padded_inputs, start=start_idx, size=slice_size)
        fresh = ops.squeeze(fraction, axis=0)
        current_state, res = pipeline_step(current_state, fresh, weight_stack, bias_stack, mask_0, step_fn, perm)
        results.append(res)
    return ops.stack(results, axis=0), current_state
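A note on the extraction step above: `ops.where(mask_0, ...)` followed by `reduce_sum` over the stage axis is simply a select of the stage-0 slot, which is where a finished micro-batch lands after the shift. A plain NumPy analogue (illustrative only, reusing the constants from the configuration cell):

state = np.random.randn(PP_SIZE, MICRO_BATCH_SIZE, DIM)          # (stages, micro-batch, dim)
mask0 = np.eye(PP_SIZE, 1).reshape(PP_SIZE, 1, 1).astype(bool)   # True only at stage 0
picked = np.where(mask0, state, np.zeros_like(state))            # zero out every other stage
result = picked.sum(axis=0)                                      # (micro-batch, dim) == state[0]
assert np.allclose(result, state[0])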
3. Shard Data for 2D Parallelism#
The key difference from Example 6: we now specify two-axis sharding.
| Tensor | Sharding | Meaning |
|---|---|---|
| Weights | `["pp", None, None]` | Partitioned across pipeline stages, replicated across DP |
| Biases | `["pp", None]` | Same as weights |
| Inputs | `[None, "dp", None]` | Replicated across PP, partitioned across DP |
| State | `["pp", "dp", None]` | Partitioned on both axes |
[10]:
np.random.seed(42)
total_steps = MICRO_BATCHES + PP_SIZE
w_np = np.random.randn(PP_SIZE, DIM, DIM).astype(np.float32)
b_np = np.random.randn(PP_SIZE, DIM).astype(np.float32)
x_np = np.random.randn(MICRO_BATCHES, MICRO_BATCH_SIZE, DIM).astype(np.float32)
y_np = np.random.randn(MICRO_BATCHES, MICRO_BATCH_SIZE, DIM).astype(np.float32)
# Weights: sharded on "pp", replicated on "dp"
w_spec = [DimSpec.from_raw("pp"), None, None]
b_spec = [DimSpec.from_raw("pp"), None]
w_sharded = ops.shard(nb.Tensor.from_dlpack(w_np), mesh, w_spec).realize()
b_sharded = ops.shard(nb.Tensor.from_dlpack(b_np), mesh, b_spec).realize()
# Data: sharded on "dp" (batch dim), replicated on "pp"
x_padded_np = np.concatenate(
    [x_np, np.zeros((PP_SIZE, MICRO_BATCH_SIZE, DIM), dtype=np.float32)], axis=0
)
x_spec = [None, DimSpec.from_raw("dp"), None]
x_sharded = ops.shard(nb.Tensor.from_dlpack(x_padded_np), mesh, x_spec).realize()
y_sharded = ops.shard(nb.Tensor.from_dlpack(y_np), mesh, x_spec).realize()
# Pipeline state: sharded on both axes
state_spec = [DimSpec.from_raw("pp"), DimSpec.from_raw("dp"), None]
state_sharded = ops.shard(
    nb.zeros((PP_SIZE, MICRO_BATCH_SIZE, DIM)), mesh, state_spec
).realize()
# Stage-0 mask
mask_np = np.eye(PP_SIZE, 1).reshape(PP_SIZE, 1, 1).astype(bool)
mask_spec = [DimSpec.from_raw("pp"), None, None]
mask_sharded = ops.shard(nb.Tensor.from_dlpack(mask_np), mesh, mask_spec).realize()
print(f"Weights: {w_sharded.shape} sharded on 'pp'")
print(f"Inputs: {x_sharded.shape} sharded on 'dp'")
Weights: [Dim(4), Dim(16), Dim(16)] sharded on 'pp'
Inputs: [Dim(8), Dim(4), Dim(16)] sharded on 'dp'
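These are global shapes; conceptually each device only materializes its own shard. Assuming even splits along the named axes, the local shapes work out as below (a small arithmetic sketch, not a Nabla API call):

# weights: global (4, 16, 16), sharded on "pp" (axis 0) across 4 stages   -> (1, 16, 16) per device
# inputs:  global (8, 4, 16),  sharded on "dp" (axis 1) across 2 replicas -> (8, 2, 16) per device
local_w = (PP_SIZE // PP_SIZE, DIM, DIM)
local_x = (MICRO_BATCHES + PP_SIZE, MICRO_BATCH_SIZE // DP_SIZE, DIM)
print(f"per-device weight shard: {local_w}, per-device input shard: {local_x}")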
4. 2D Communication Setup#
With a 2D mesh, ppermute needs device-level permutations that shift only within each DP replica’s pipeline. For DP=2, PP=4, the 8 devices are numbered 0..7 where device dp*PP_SIZE + pp is at position (dp, pp). Each DP replica independently shifts its pipeline stages:
[11]:
# Build per-DP-replica pipeline permutations
idx = mesh.axis_names.index("pp")
size = mesh.shape[idx]
perm = []
for dp in range(DP_SIZE):
    for src_pp in range(PP_SIZE):
        src = dp * PP_SIZE + src_pp
        dst = dp * PP_SIZE + (src_pp + 1) % size
        perm.append((src, dst))
print(f"2D ppermute: {perm}")

# Vectorize stage_compute over the "pp" axis
step_fn = vmap(
    stage_compute, in_axes=(0, 0, 0), out_axes=0, spmd_axis_name="pp", mesh=mesh
)
2D ppermute: [(0, 1), (1, 2), (2, 3), (3, 0), (4, 5), (5, 6), (6, 7), (7, 4)]
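A quick sanity check on this permutation (plain Python, illustrative): every edge should stay inside its DP replica's block of PP_SIZE consecutive device ids, so the two pipeline rings never exchange data.

for dp in range(DP_SIZE):
    block = set(range(dp * PP_SIZE, (dp + 1) * PP_SIZE))
    ring = [(s, d) for s, d in perm if s in block]
    assert all(d in block for _, d in ring), "cross-replica edge found"
    print(f"DP replica {dp} ring: {ring}")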
5. Loss Function and Gradient Computation#
We use `nb.value_and_grad` to get both the loss value and the weight/bias gradients in one pass. `argnums=(1, 2)` differentiates w.r.t. weights and biases (the arguments at positions 1 and 2). One detail worth spelling out: the pipeline takes PP_SIZE steps to fill, so the first PP_SIZE entries of the streamed outputs are warm-up values, and the valid predictions for the MICRO_BATCHES micro-batches sit at indices PP_SIZE through PP_SIZE + MICRO_BATCHES - 1; that is exactly what the gather over `ops.arange(PP_SIZE, PP_SIZE + MICRO_BATCHES)` selects.
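A minimal sketch of the calling convention, assuming `nb.value_and_grad` also accepts ordinary (unsharded) tensors; `toy` and `ones` are illustrative names, not part of this example:

def toy(a, b, c):
    return ops.mean(a * b + c)

# value_and_grad returns the loss value plus a tuple of gradients,
# one per selected argument position.
toy_vg = nb.value_and_grad(toy, argnums=(1, 2))
ones = nb.Tensor.from_dlpack(np.ones((2, 2), dtype=np.float32))
val, (grad_b, grad_c) = toy_vg(ones, ones, ones)   # gradients only w.r.t. b and c

With that convention in mind, here is the full pipeline loss: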
[12]:
def pipeline_loss(inputs, weights, biases, state, mask, targets):
    """MSE through the full 2D-parallel pipeline."""
    stream_outputs, _ = pipeline_loop(
        inputs, weights, biases, state, mask, step_fn, perm, total_steps
    )
    indices = ops.arange(PP_SIZE, PP_SIZE + MICRO_BATCHES, dtype=nb.DType.int64)
    valid_preds = ops.gather(stream_outputs, indices, axis=0)
    diff = valid_preds - targets
    return ops.mean(diff * diff)


grad_fn = nb.value_and_grad(pipeline_loss, argnums=(1, 2))
loss_nb, (w_grad, b_grad) = grad_fn(
    x_sharded, w_sharded, b_sharded, state_sharded, mask_sharded, y_sharded
)
print(f"Loss: {loss_nb.item():.6f}")
print(f"Weight gradient shape: {w_grad.shape}")
print(f"Bias gradient shape: {b_grad.shape}")

w_grad_np = w_grad.to_numpy()
b_grad_np = b_grad.to_numpy()
Loss: 3828.785156
Weight gradient shape: [Dim(4), Dim(16), Dim(16)]
Bias gradient shape: [Dim(4), Dim(16)]
6. Verify Against JAX Reference#
We compare the 2D-parallel gradients against JAX’s sequential computation to confirm that sharding doesn’t affect numerical results:
[13]:
try:
    import jax
    import jax.numpy as jnp

    def jax_ref(pw, pb, px, py):
        def apply(curr, w, b):
            return jax.nn.relu(curr @ w + b)

        preds = []
        for i in range(MICRO_BATCHES):
            a = px[i]
            for w, b in zip(pw, pb, strict=False):
                a = apply(a, w, b)
            preds.append(a)
        preds = jnp.stack(preds)
        return jnp.mean((preds - py) ** 2)

    jax_vg = jax.value_and_grad(jax_ref, argnums=(0, 1))
    loss_jax, (w_ref, b_ref) = jax_vg(w_np, b_np, x_np, y_np)

    print(f"JAX loss: {loss_jax:.6f}")
    print(f"Nabla loss: {loss_nb.item():.6f}")

    w_diff = np.max(np.abs(w_grad_np - w_ref))
    b_diff = np.max(np.abs(b_grad_np - b_ref))
    print(f"Max weight grad diff: {w_diff:.6f}")
    print(f"Max bias grad diff: {b_diff:.6f}")

    assert w_diff < 5e-4 and b_diff < 5e-4, "Gradient mismatch!"
    print("✅ 2D parallel gradients match JAX sequential reference")
except ImportError:
    print("JAX not installed — skipping reference comparison")
JAX loss: 3828.785645
Nabla loss: 3828.785156
Max weight grad diff: 0.000092
Max bias grad diff: 0.000031
✅ 2D parallel gradients match JAX sequential reference
Key takeaways:

- A 2D mesh partitions tensors along both pipeline and data axes
- Weights are replicated across DP and sharded across PP — no all-reduce is needed in the forward pass
- `ppermute` permutations are constructed per DP replica to keep the pipelines independent
- `nb.value_and_grad` differentiates through the full 2D-parallel pipeline