Example 7: 2D Parallel Training (Pipeline + Data Parallelism)

This notebook extends pipeline parallelism (Example 6) by adding a data-parallel dimension, creating a 2D device mesh:

             Pipeline stages →
            Stage 0  Stage 1  Stage 2  Stage 3
Data  DP 0   [w0]     [w1]     [w2]     [w3]    ← same weights, different data
Par.  DP 1   [w0]     [w1]     [w2]     [w3]    ← same weights, different data

Key idea: Weights are sharded across pipeline stages and replicated across data-parallel replicas. Input batches are sharded across DP replicas.

We’ll:

  1. Build a 2D DeviceMesh with axes "dp" and "pp"

  2. Shard weights on "pp", data on "dp"

  3. Use the same pipeline primitives from Example 6

  4. Compute gradients with nb.value_and_grad

import numpy as np

import nabla as nb
from nabla import ops
from nabla.core.sharding import DeviceMesh, DimSpec
from nabla.ops import communication
from nabla.transforms import vmap

print("Nabla 2D Parallelism example")
Nabla 2D Parallelism example

1. Configuration and Device Mesh

The 2D mesh has shape (DP_SIZE, PP_SIZE) with named axes "dp" and "pp". Each device is identified by a (dp_rank, pp_rank) pair:

# 2D mesh dimensions
DP_SIZE = 2          # Data-parallel replicas
PP_SIZE = 4          # Pipeline stages
MICRO_BATCHES = 4
MICRO_BATCH_SIZE = 4
DIM = 16

mesh = DeviceMesh("2d", (DP_SIZE, PP_SIZE), ("dp", "pp"))
print(f"2D device mesh: {DP_SIZE} DP replicas × {PP_SIZE} PP stages = {DP_SIZE * PP_SIZE} devices")
2D device mesh: 2 DP replicas × 4 PP stages = 8 devices
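
A minimal NumPy sketch of how the (dp_rank, pp_rank) pairs map to flat device IDs, assuming the row-major numbering that section 4 relies on (device dp*PP_SIZE + pp). The helper `coords` is hypothetical and not part of Nabla's API. Rows of the grid are pipelines; columns are the DP replicas of one stage, i.e. the devices holding identical weights.

```python
import numpy as np

DP_SIZE, PP_SIZE = 2, 4

# Assumption: devices are numbered row-major over the (dp, pp) mesh shape.
device_ids = np.arange(DP_SIZE * PP_SIZE).reshape(DP_SIZE, PP_SIZE)

def coords(device_id):
    """Map a flat device ID to its (dp_rank, pp_rank) pair (hypothetical helper)."""
    dp, pp = np.unravel_index(device_id, (DP_SIZE, PP_SIZE))
    return int(dp), int(pp)

# Each row is one pipeline: the stages that pass activations to each other.
pipelines = [device_ids[dp].tolist() for dp in range(DP_SIZE)]
# Each column holds the DP replicas of one stage, which share identical weights.
stage_replicas = [device_ids[:, pp].tolist() for pp in range(PP_SIZE)]

print(coords(5))        # (1, 1)
print(pipelines)        # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(stage_replicas)   # [[0, 4], [1, 5], [2, 6], [3, 7]]
```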

2. Pipeline Primitives (same as Example 6)

The stage compute, step, and loop functions are identical to Example 6. Only their inputs change: the sharding specifications now use both mesh axes, and the ppermute permutation (section 4) is built per DP replica.

def stage_compute(x, w, b):
    """One pipeline stage: linear + ReLU."""
    return ops.relu(ops.matmul(x, w) + b)


def pipeline_step(current_state, fresh_input, weight_stack, bias_stack, mask_0, step_fn, perm):
    """Compute → shift → extract result → inject fresh input."""
    computed = step_fn(current_state, weight_stack, bias_stack)
    shifted = communication.ppermute(computed, perm)
    res_part = ops.where(mask_0, shifted, ops.zeros_like(shifted))
    result = ops.reduce_sum(res_part, axis=0)
    next_state = ops.where(mask_0, fresh_input, shifted)
    return next_state, result


def pipeline_loop(padded_inputs, weight_stack, bias_stack, current_state, mask_0, step_fn, perm, total_steps):
    """Stream micro-batches through the pipeline."""
    results = []
    for t in range(total_steps):
        start_idx = (t, 0, 0)
        slice_size = (1, MICRO_BATCH_SIZE, DIM)
        fraction = ops.slice_tensor(padded_inputs, start=start_idx, size=slice_size)
        fresh = ops.squeeze(fraction, axis=0)
        current_state, res = pipeline_step(current_state, fresh, weight_stack, bias_stack, mask_0, step_fn, perm)
        results.append(res)
    return ops.stack(results, axis=0), current_state
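
To see why the valid predictions come out at steps PP_SIZE through PP_SIZE + MICRO_BATCHES - 1, here is a single-process NumPy sketch of the same schedule. It assumes ppermute acts as a cyclic shift of the stage axis by one (modeled with np.roll), and it uses float64 random data rather than the notebook's arrays. It checks that the pipelined outputs equal a plain sequential forward pass; it simulates the schedule only and does not use Nabla or shard anything.

```python
import numpy as np

PP_SIZE, MICRO_BATCHES, MICRO_BATCH_SIZE, DIM = 4, 4, 4, 16
rng = np.random.default_rng(0)
w = rng.standard_normal((PP_SIZE, DIM, DIM))
b = rng.standard_normal((PP_SIZE, DIM))
x = rng.standard_normal((MICRO_BATCHES, MICRO_BATCH_SIZE, DIM))

def all_stages(state, w, b):
    # What the vmapped stage_compute evaluates: stage s applies relu(state[s] @ w[s] + b[s]).
    return np.maximum(np.einsum("sbi,sio->sbo", state, w) + b[:, None, :], 0.0)

def simulate(x, w, b):
    total_steps = MICRO_BATCHES + PP_SIZE
    padded = np.concatenate([x, np.zeros((PP_SIZE,) + x.shape[1:])], axis=0)
    state = np.zeros((PP_SIZE,) + x.shape[1:])
    results = []
    for t in range(total_steps):
        # Stand-in for ppermute: stage s -> s + 1, last stage wraps to stage 0.
        shifted = np.roll(all_stages(state, w, b), 1, axis=0)
        results.append(shifted[0])  # the mask + reduce_sum picks out stage 0's slot
        state = shifted.copy()
        state[0] = padded[t]        # inject the next (possibly padding) micro-batch
    return np.stack(results)[PP_SIZE:PP_SIZE + MICRO_BATCHES]

def sequential(x, w, b):
    out = []
    for xi in x:
        a = xi
        for s in range(PP_SIZE):
            a = np.maximum(a @ w[s] + b[s], 0.0)
        out.append(a)
    return np.stack(out)

pipelined = simulate(x, w, b)
print(np.allclose(pipelined, sequential(x, w, b)))  # True
```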

3. Shard Data for 2D Parallelism

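The key difference from Example 6 is that the sharding specs now name two mesh axes. Each spec has one entry per tensor dimension: a mesh-axis name splits that dimension across the axis, and None keeps it whole (replicated).

Tensor          Sharding               Meaning
Weights w       ("pp", None, None)     Partitioned across pipeline stages, replicated across DP
Biases b        ("pp", None)           Same as weights
Inputs x        (None, "dp", None)     Replicated across PP, partitioned across DP
Targets y       (None, "dp", None)     Same as inputs
Stage-0 mask    ("pp", None, None)     One flag per stage, True only on stage 0
State           ("pp", "dp", None)     Partitioned on both axes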

np.random.seed(42)
total_steps = MICRO_BATCHES + PP_SIZE

w_np = np.random.randn(PP_SIZE, DIM, DIM).astype(np.float32)
b_np = np.random.randn(PP_SIZE, DIM).astype(np.float32)
x_np = np.random.randn(MICRO_BATCHES, MICRO_BATCH_SIZE, DIM).astype(np.float32)
y_np = np.random.randn(MICRO_BATCHES, MICRO_BATCH_SIZE, DIM).astype(np.float32)

# Weights: sharded on "pp", replicated on "dp"
w_spec = [DimSpec.from_raw("pp"), None, None]
b_spec = [DimSpec.from_raw("pp"), None]
w_sharded = ops.shard(nb.Tensor.from_dlpack(w_np), mesh, w_spec).realize()
b_sharded = ops.shard(nb.Tensor.from_dlpack(b_np), mesh, b_spec).realize()

# Data: sharded on "dp" (batch dim), replicated on "pp"
x_padded_np = np.concatenate(
    [x_np, np.zeros((PP_SIZE, MICRO_BATCH_SIZE, DIM), dtype=np.float32)], axis=0
)
x_spec = [None, DimSpec.from_raw("dp"), None]
x_sharded = ops.shard(nb.Tensor.from_dlpack(x_padded_np), mesh, x_spec).realize()
y_sharded = ops.shard(nb.Tensor.from_dlpack(y_np), mesh, x_spec).realize()

# Pipeline state: sharded on both axes
state_spec = [DimSpec.from_raw("pp"), DimSpec.from_raw("dp"), None]
state_sharded = ops.shard(
    nb.zeros((PP_SIZE, MICRO_BATCH_SIZE, DIM)), mesh, state_spec
).realize()

# Stage-0 mask
mask_np = np.eye(PP_SIZE, 1).reshape(PP_SIZE, 1, 1).astype(bool)
mask_spec = [DimSpec.from_raw("pp"), None, None]
mask_sharded = ops.shard(nb.Tensor.from_dlpack(mask_np), mesh, mask_spec).realize()

print(f"Weights: {w_sharded.shape} sharded on 'pp'")
print(f"Inputs:  {x_sharded.shape} sharded on 'dp'")
Weights: [Dim(4), Dim(16), Dim(16)] sharded on 'pp'
Inputs:  [Dim(8), Dim(4), Dim(16)] sharded on 'dp'
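
The padding and total_steps = MICRO_BATCHES + PP_SIZE follow from the schedule's latency. A back-of-the-envelope sketch, assuming the step timing of pipeline_step (one compute plus one shift per step): micro-batch i is injected at step i and its output reaches stage 0 at step i + PP_SIZE. The same arithmetic gives the fraction of stage-steps doing useful work in this particular schedule.

```python
PP_SIZE, MICRO_BATCHES = 4, 4
total_steps = MICRO_BATCHES + PP_SIZE  # as in the cell above

# The last micro-batch (i = MICRO_BATCHES - 1) finishes on the final step.
last_output_step = (MICRO_BATCHES - 1) + PP_SIZE
assert last_output_step == total_steps - 1

# One input slice is read per step, so the padded input needs total_steps entries.
padded_len = MICRO_BATCHES + PP_SIZE
assert padded_len == total_steps

# Every stage computes at every step, but each micro-batch needs one step per stage.
useful = MICRO_BATCHES * PP_SIZE
scheduled = total_steps * PP_SIZE
print(useful, scheduled, useful / scheduled)  # 16 32 0.5
```

Raising MICRO_BATCHES relative to PP_SIZE increases this fraction, since the fill and drain steps are amortized over more micro-batches.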

4. 2D Communication Setup

With a 2D mesh, ppermute takes device-level (source, destination) pairs, and every shift must stay within one DP replica's pipeline. For DP_SIZE=2 and PP_SIZE=4, the 8 devices are numbered 0..7 in row-major order, so device dp*PP_SIZE + pp sits at mesh position (dp, pp). Each stage sends its activations to the next stage of the same replica, and the last stage wraps around to stage 0:

# Build per-DP-replica pipeline permutations.
# Devices are numbered row-major: (dp, pp) -> dp * PP_SIZE + pp.
pp_size = mesh.shape[mesh.axis_names.index("pp")]  # equals PP_SIZE
perm = []
for dp in range(DP_SIZE):
    for src_pp in range(pp_size):
        src = dp * pp_size + src_pp
        dst = dp * pp_size + (src_pp + 1) % pp_size  # next stage; last wraps to first
        perm.append((src, dst))
print(f"2D ppermute: {perm}")

# Vectorize stage_compute over the "pp" axis
step_fn = vmap(
    stage_compute, in_axes=(0, 0, 0), out_axes=0, spmd_axis_name="pp", mesh=mesh
)
2D ppermute: [(0, 1), (1, 2), (2, 3), (3, 0), (4, 5), (5, 6), (6, 7), (7, 4)]
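
The same permutation can be written without explicit loops by rolling the device grid along its "pp" axis. A NumPy sketch, again assuming row-major device numbering; it reproduces the printed list and checks that no pair crosses DP replicas.

```python
import numpy as np

DP_SIZE, PP_SIZE = 2, 4

# Assumption: row-major device numbering, device = dp * PP_SIZE + pp.
ids = np.arange(DP_SIZE * PP_SIZE).reshape(DP_SIZE, PP_SIZE)
# Destination of (dp, pp) is (dp, (pp + 1) % PP_SIZE): roll each row left by one.
dst = np.roll(ids, -1, axis=1)
perm = [(int(s), int(d)) for s, d in zip(ids.ravel(), dst.ravel())]
print(perm)
# [(0, 1), (1, 2), (2, 3), (3, 0), (4, 5), (5, 6), (6, 7), (7, 4)]

# No pair crosses DP replicas, so the two pipelines never exchange activations.
assert all(s // PP_SIZE == d // PP_SIZE for s, d in perm)
```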

5. Loss Function and Gradient Computation

We use nb.value_and_grad to get both the loss value and the weight and bias gradients from a single call. argnums=(1, 2) selects the arguments to differentiate: the weights and biases, at positions 1 and 2 of pipeline_loss:

def pipeline_loss(inputs, weights, biases, state, mask, targets):
    """MSE through the full 2D-parallel pipeline."""
    stream_outputs, _ = pipeline_loop(
        inputs, weights, biases, state, mask, step_fn, perm, total_steps
    )
    indices = ops.arange(PP_SIZE, PP_SIZE + MICRO_BATCHES, dtype=nb.DType.int64)
    valid_preds = ops.gather(stream_outputs, indices, axis=0)
    diff = valid_preds - targets
    return ops.mean(diff * diff)


grad_fn = nb.value_and_grad(pipeline_loss, argnums=(1, 2))
loss_nb, (w_grad, b_grad) = grad_fn(
    x_sharded, w_sharded, b_sharded, state_sharded, mask_sharded, y_sharded
)

print(f"Loss: {loss_nb.item():.6f}")
print(f"Weight gradient shape: {w_grad.shape}")
print(f"Bias gradient shape:   {b_grad.shape}")

w_grad_np = w_grad.to_numpy()
b_grad_np = b_grad.to_numpy()
Loss: 3828.785156
Weight gradient shape: [Dim(4), Dim(16), Dim(16)]
Bias gradient shape:   [Dim(4), Dim(16)]
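
Because the weights are replicated across "dp" while each replica sees only half the rows of every micro-batch, one replica's local gradient covers only its own rows; the full-batch weight gradient is the sum of the per-replica contributions. The match against the sequential reference in section 6 indicates that this combination happens when differentiating through the sharded program, though this sketch does not show how Nabla implements it. A NumPy illustration of the identity for one linear layer with an MSE loss (no Nabla involved; shapes chosen to match one micro-batch, and `grad_sum_sq` is a hypothetical helper):

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 4, 16                      # rows in one micro-batch, feature size
x = rng.standard_normal((N, D))
y = rng.standard_normal((N, D))
w = rng.standard_normal((D, D))

def grad_sum_sq(x, y, w, denom):
    """Gradient of sum((x @ w - y)**2) / denom with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / denom

# Full-batch gradient of the mean squared error (mean over N * D elements).
full = grad_sum_sq(x, y, w, N * D)

# Two DP replicas, each holding half the rows (the "dp" split of the batch dim),
# normalize by the global element count; their gradients then sum to the full one.
shards = np.split(np.arange(N), 2)
per_replica = [grad_sum_sq(x[r], y[r], w, N * D) for r in shards]
print(np.allclose(sum(per_replica), full))  # True
```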

6. Verify Against JAX Reference

We compare the 2D-parallel loss and gradients against a sequential JAX computation to confirm that sharding changes the results only by floating-point rounding (the check below uses an absolute tolerance of 5e-4):

try:
    import jax
    import jax.numpy as jnp

    def jax_ref(pw, pb, px, py):
        def apply(curr, w, b):
            return jax.nn.relu(curr @ w + b)
        preds = []
        for i in range(MICRO_BATCHES):
            a = px[i]
            for w, b in zip(pw, pb, strict=False):
                a = apply(a, w, b)
            preds.append(a)
        preds = jnp.stack(preds)
        return jnp.mean((preds - py) ** 2)

    jax_vg = jax.value_and_grad(jax_ref, argnums=(0, 1))
    loss_jax, (w_ref, b_ref) = jax_vg(w_np, b_np, x_np, y_np)

    print(f"JAX loss:   {loss_jax:.6f}")
    print(f"Nabla loss: {loss_nb.item():.6f}")

    w_diff = np.max(np.abs(w_grad_np - w_ref))
    b_diff = np.max(np.abs(b_grad_np - b_ref))
    print(f"Max weight grad diff: {w_diff:.6f}")
    print(f"Max bias grad diff:   {b_diff:.6f}")
    assert w_diff < 5e-4 and b_diff < 5e-4, "Gradient mismatch!"
    print("✅ 2D parallel gradients match JAX sequential reference")

except ImportError:
    print("JAX not installed — skipping reference comparison")
JAX loss:   3828.785645
Nabla loss: 3828.785156
Max weight grad diff: 0.000092
Max bias grad diff:   0.000031
✅ 2D parallel gradients match JAX sequential reference
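
The two printed losses differ by about 0.0005, which is consistent with the frameworks accumulating float32 sums in a different order. A quick NumPy check on the printed values shows they are exactly two float32 steps (units in the last place) apart at this magnitude:

```python
import numpy as np

loss_jax = np.float32(3828.785645)    # printed JAX loss
loss_nabla = np.float32(3828.785156)  # printed Nabla loss

# Gap to the next representable float32; for values in [2048, 4096) it is 2**-12.
ulp = np.spacing(loss_nabla)
print(ulp == 2.0**-12, (loss_jax - loss_nabla) / ulp)  # True 2.0
```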

Key takeaways:

  • A 2D mesh partitions tensors along both pipeline and data axes

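  • Weights are replicated across DP and sharded across PP, so the forward pass needs no all-reduce; because each DP replica sees only part of every micro-batch, the weight gradients must still be summed across DP replicas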

  • ppermute permutations are constructed per-DP-replica to keep pipelines independent

  • nb.value_and_grad differentiates through the full 2D-parallel pipeline