Nabla SPMD Sharding / Partial-Propagation — Handover#
Date: 2026-03-06
Purpose#
This handover is for the next developer or AI agent continuing work on Nabla’s SPMD sharding, deferred-reduction propagation, and transform composition.
It should be read together with the earlier architecture review, which explains the architectural concerns and long-term direction. This handover focuses on:
what was already done,
what was refactored in this session,
what is still risky or incomplete,
and what to do next.
High-level state#
Nabla’s partial propagation system is working materially better than before.
Confirmed-good areas#
forward numerical partial propagation stress tests,
key sharding-transform interactions,
compile cache identity for sharded/effectful inputs,
all_reduce(sum) JVP support,
removal of sharding-related xfails in the focused transform suite.
Still true architecturally#
The implementation is still transitional. The main architectural concern remains:
the system represents deferred reductions in two forms:
free axis effects via ShardingSpec.partial_sum_axes,
attached effects via DimSpec.partial.
That should still be unified later.
Original background#
The system started from a working but ad hoc partial propagation model:
row-parallel / contraction-style matmuls can yield partial sums,
linear/distributive ops may defer the all_reduce,
nonlinear ops must force reduction before applying the op.
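That policy can be made concrete with a minimal sketch. All names here are hypothetical stand-ins, not Nabla's real types: a `Partial` wraps per-device partial sums, a linear op like `add` may stay deferred because the sum distributes over it, and a nonlinear op like `relu` must materialize first.

```python
import numpy as np

class Partial:
    """Hypothetical stand-in for a value whose per-device shards still
    need an all_reduce(sum) before the full value exists."""
    def __init__(self, shards):
        self.shards = list(shards)

    def materialize(self):
        # Apply the deferred all_reduce(sum).
        return sum(self.shards)

def add(a, b):
    # Linear op: sum distributes over +, so two partials may stay deferred.
    if isinstance(a, Partial) and isinstance(b, Partial):
        return Partial([x + y for x, y in zip(a.shards, b.shards)])
    # Mixed partial/dense: deferring would double-count the dense operand,
    # so materialize anything partial first.
    av = a.materialize() if isinstance(a, Partial) else a
    bv = b.materialize() if isinstance(b, Partial) else b
    return av + bv

def relu(a):
    # Nonlinear op: relu(sum(shards)) != sum(relu(shards)),
    # so force the reduction before applying the op.
    if isinstance(a, Partial):
        a = a.materialize()
    return np.maximum(a, 0.0)
```

With shards `[2.0, -3.0]`, deferring `relu` would give 2.0 per-shard-then-sum, while the correct value is `relu(-1.0) = 0.0`, which is exactly why nonlinear ops must force the reduction.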
Before this session, the following correctness fixes were already in place:
CastOp no longer incorrectly defers narrowing casts,
ConcatenateOp no longer incorrectly defers when only some inputs are partial,
GatherOp correctly defers,
SliceTensorOp correctly defers.
The rigorous forward test baseline was already present in:
What was added/refactored in this session#
2. Partial passthrough logic was de-duplicated#
In:
added shared helpers for common partial-passthrough patterns:
preserve if all inputs carry the axis,
preserve if at most n inputs carry the axis,
generic axis-local predicate filtering.
Why#
Per-op logic was starting to duplicate the same patterns in slightly different forms. This makes the current API less messy while the bigger architectural rewrite is still pending.
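The three shared patterns can be sketched like this (a sketch only; the helper and spec names are hypothetical, not Nabla's actual API):

```python
from typing import Callable, Sequence

class Spec:
    """Hypothetical minimal sharding spec carrying deferred-sum axes."""
    def __init__(self, partial_sum_axes=()):
        self.partial_sum_axes = tuple(partial_sum_axes)

def _carries(axis: str, spec: Spec) -> bool:
    # Does this input's spec carry the deferred-sum axis?
    return axis in spec.partial_sum_axes

def preserve_if_all(axis: str, specs: Sequence[Spec]) -> bool:
    # Preserve the deferred axis only when every input carries it.
    return all(_carries(axis, s) for s in specs)

def preserve_if_at_most(axis: str, specs: Sequence[Spec], n: int) -> bool:
    # Preserve when at most n inputs carry the axis.
    return sum(_carries(axis, s) for s in specs) <= n

def preserve_if(axis: str, specs: Sequence[Spec],
                pred: Callable[[list], bool]) -> bool:
    # Generic axis-local predicate filtering over per-input flags.
    return pred([_carries(axis, s) for s in specs])
```

Per-op rules then reduce to one helper call with the right predicate instead of re-implementing the loop.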
4. vmap SPMD batching was hardened#
In:
_apply_shard() now copies forward source sharding metadata when adding a batch sharding axis.
Why#
The prior logic rebuilt sharding mostly from dim_specs, which made effect metadata fragile.
This is still not the final abstraction, but it is safer and more explicit.
5. Compile cache identity was fixed#
In:
changes:
compile now realizes lazy tensors before both cache-hit and cache-miss execution paths,
sharding cache identity now uses ShardingSpec.effect_signature() instead of a weaker partial view.
Why#
Before this change, compile cache identity could ignore important sharding/effect metadata. That was a real correctness risk.
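The shape of the fix can be illustrated with a sketch (hypothetical names; `effect_signature` here is a stand-in for the real `ShardingSpec.effect_signature()`): two inputs that differ only in deferred-reduction metadata must hash to different cache keys.

```python
import hashlib

def effect_signature(spec: dict) -> tuple:
    # Hypothetical stand-in for ShardingSpec.effect_signature(): an
    # order-stable view of sharding plus deferred-reduction metadata.
    return (tuple(sorted(spec["mesh_axes"])),
            tuple(sorted(spec["partial_sum_axes"])))

def cache_key(shape, dtype, spec) -> str:
    # Inputs that differ only in deferred-reduction metadata must not
    # collide, so the full effect signature participates in the key.
    payload = repr((shape, dtype, effect_signature(spec)))
    return hashlib.sha256(payload.encode()).hexdigest()
```

A weaker key that dropped `partial_sum_axes` would let a partial tensor hit a compiled artifact built for a fully reduced one, which is the correctness risk described above.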
6. all_reduce(sum) JVP was implemented#
In:
added:
AllReduceOp.jvp_rule(),
PMeanOp.jvp_rule().
Also added a crucial guard:
when all_reduce is only materializing an already-deferred partial effect from the primal path, the tangent must not be reduced again, or it double-counts.
Why#
This removed focused sharding-transform xfails and made forward-mode AD materially more robust in sharded paths.
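A sketch of the rule and its guard, under stated assumptions (function and flag names are hypothetical, not the real `AllReduceOp.jvp_rule()` signature):

```python
def all_reduce_sum_jvp(primal_shards, tangent, tangent_is_partial):
    """Sketch of an all_reduce(sum) JVP rule (names hypothetical).

    all_reduce(sum) is linear, so the tangent follows the primal rule,
    with one guard: if the tangent was already materialized upstream,
    reducing it again would double-count the deferred partial effect.
    """
    out_primal = sum(primal_shards)
    if tangent_is_partial:
        # Tangent still carries the deferred sum: reduce it exactly once.
        out_tangent = sum(tangent)
    else:
        # Tangent already materialized: pass it through unreduced.
        out_tangent = tangent
    return out_primal, out_tangent
```

Both branches must produce the same tangent value; the guard only decides who performs the single reduction.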
7. Focused tests were improved, and xfails removed where actually fixed#
Updated or added tests in:
Notably:
removed sharding-related xfails for JVP cases that were actually fixed,
added direct regression coverage for compile sharding cache identity,
kept the stress partial-propagation suite as the baseline numerical oracle.
Targeted test commands that were run successfully#
Use these exact focused commands rather than broad full-suite runs.
Core partial propagation#
venv/bin/python -m pytest -q tests/unit/test_stress_partial_propagation.py
Transform + partial/sharding interactions#
venv/bin/python -m pytest -q tests/unit/test_transforms_sharded.py
Communication rigor#
venv/bin/python -m pytest -q tests/unit/test_communication_rigorous.py
Vmap + sharding + compile regression coverage#
venv/bin/python -m pytest -q tests/unit/test_vmap_sharding.py tests/unit/test_compile.py tests/unit/test_communication_rigorous.py
Broader focused bundle used during this session#
venv/bin/python -m pytest -q \
tests/unit/test_stress_partial_propagation.py \
tests/unit/test_vmap_sharding.py \
tests/unit/test_compile.py \
tests/unit/test_transforms_sharded.py \
tests/unit/test_communication_rigorous.py
All of the above passed at the end of this session.
Files that matter most now#
Core implementation#
Tests#
What is still incomplete / risky#
1. The core architectural split is still there#
This is still the main open issue.
The system still encodes deferred reduction effects in two places:
ShardingSpec.partial_sum_axes,
DimSpec.partial.
That still causes complexity in:
spmd.py,
transform composition,
communication/output-spec logic
test reasoning
Recommendation#
Do the next real architectural step:
introduce one internal deferred-reduction effect object,
adapt current sharding structs to it,
then simplify
spmd.py around that single model.
This remains the highest-value next refactor.
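One possible shape for that single model, as a sketch only (every name here is hypothetical; the real unification must adapt the existing `ShardingSpec`/`DimSpec` structs):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class DeferredReduction:
    """Hypothetical single internal effect object replacing the split
    between ShardingSpec.partial_sum_axes and DimSpec.partial."""
    op: str         # e.g. "sum"
    mesh_axis: str  # mesh axis the deferred reduction runs over

@dataclass(frozen=True)
class UnifiedSpec:
    dim_specs: tuple = ()
    effects: frozenset = frozenset()  # every deferred reduction lives here

    def with_effect(self, e: DeferredReduction) -> "UnifiedSpec":
        return UnifiedSpec(self.dim_specs, self.effects | {e})

    def effect_signature(self) -> tuple:
        # Stable identity, suitable for compile-cache keys.
        return tuple(sorted((e.op, e.mesh_axis) for e in self.effects))
```

With one effect set, "free" and "attached" partials are just the same object referenced from different places, and propagation rules only ever reason about `effects`.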
2. spmd.py still has split reasoning paths and contraction-era bookkeeping#
Read carefully:
Still-important complexity points:
pure-partial fast path,
main factor-propagation path,
_save_multi_input_contracting_dims(),
_restore_cleared_contracting_dims(),
ghost_axes.
These are signs of representation mismatch. They are not necessarily wrong, but they are still not the final architecture.
3. Transform-boundary semantics need more rigorous numerical coverage#
We improved this area, but it is still under-tested relative to the forward partial suite.
Especially worth expanding#
grad(vmap(f)) where f produces/consumes partial tensors,
vmap(grad(f)) with sharded contraction paths,
jit/compile over functions whose inputs or intermediates carry deferred reductions,
nested transforms where batching and deferred reduction both interact.
Important testing rule#
Follow the same standard used in the original stress suite:
numerical equality against JAX/unsharded reference,
negative oracle proving sensitivity,
deterministic seeds.
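The three-part standard can be captured in one reusable checker, sketched here with hypothetical helper names:

```python
import numpy as np

def check_against_reference(sharded_fn, reference_fn, x, rtol=1e-6):
    """Sketch of the stress-suite pattern (helper names hypothetical):
    numerical equality against an unsharded reference, plus a negative
    oracle proving the comparison can actually fail."""
    got = np.asarray(sharded_fn(x))
    want = np.asarray(reference_fn(x))
    np.testing.assert_allclose(got, want, rtol=rtol)
    # Negative oracle: a deliberately wrong reference must trip the check.
    try:
        np.testing.assert_allclose(got, want + 1.0, rtol=rtol)
    except AssertionError:
        return True
    raise RuntimeError("oracle is insensitive: the check cannot fail")

rng = np.random.default_rng(0)  # deterministic seed
```

Without the negative oracle, a comparison with an overly loose tolerance (or against the wrong tensor) passes silently; the oracle converts that into a hard failure.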
4. There may still be sharding-adjacent xfails elsewhere#
In this session, the sharding JVP xfails in the focused transform suite were fixed properly and removed.
But there are still other xfails elsewhere in the repository that were not addressed here.
Those should be inspected case by case rather than normalized away.
Suggested next tasks, in order#
Task 1 — unify deferred-reduction representation#
Goal:
replace the practical split between free and attached partial effects with one internal effect model.
Start from:
Do this incrementally, not as a big-bang rewrite.
Task 2 — create rigorous transform-boundary numerical stress tests#
Create a new focused test file for something like:
tests/unit/test_partial_transform_composition.py
Suggested contents:
grad over row-parallel matmul followed by deferred linear chains,
vmap over functions that take partial tensors and return nonlinear results,
compile on functions whose inputs differ only in deferred-reduction metadata,
nested transform cases.
All with JAX reference + negative oracles where appropriate.
Task 3 — simplify spmd.py#
After introducing a unified effect model:
remove the pure-partial special path,
eliminate ghost_axes as a separate concept if possible,
fold the contraction save/restore logic into first-class effect propagation.
Task 4 — audit communication ops for full forward-mode support#
Now that all_reduce(sum) and pmean have JVP support, continue with any remaining communication ops if needed.
Candidates to inspect:
all_gather,
reduce_scatter,
all_to_all,
ppermute.
Only add JVPs where semantics are clear and test them numerically.
Strong recommendation for the next agent#
Do not start by broadening op coverage again.
Start by improving the abstraction:
unify deferred-reduction representation,
improve transform-boundary tests,
then continue op/communication coverage.
That order is much more likely to produce a durable system.
Minimal “get started” checklist for the next developer#
Read:
this handover
Re-run:
the focused test commands above
Inspect:
Continue with:
unified effect model
new rigorous transform-composition tests
Bottom line#
The system is now in a meaningfully better state:
less duplicated policy logic,
better sharding metadata transport,
stronger compile cache correctness,
better forward-mode support,
fewer ignored sharding failures.
But the deepest architectural issue remains:
deferred reduction is still represented in two different internal forms.
That is the next real step if the goal is to make Nabla’s sharding compiler future-proof, readable, and maintainable.