Hessian Debugging Status Report - Physical Ops
Overview
This document serves as a status report and handover for the debugging efforts regarding Hessian computations in Nabla, specifically focusing on physical operations and nested forward-mode autodiff (fwd_fwd and rev_fwd).
Recent Progress
- Refactored Base Hierarchy:
  - BroadcastToPhysicalOp now inherits from ShapeOp.
  - PhysicalReduceOp (sum, mean, max, min) now inherits from AxisOp.
  - UnsqueezePhysicalOp and SqueezePhysicalOp now inherit from AxisOp.
  - Benefit: These ops now leverage the centralized adapt_kwargs to automatically handle batch_dims during physical execution.
- Nested JVP Fix: Resolved issues with nested tangent propagation. Tangent-parent relationships are now correctly preserved and cleared using _clear_jvp_cache.
- Debug Infrastructure: Integrated the NABLA_DEBUG_PHYS flag across base.py, shape.py, and axes.py for tracing physical shape transformations.
The Core Challenge: Batch Dimension Discrepancy
In higher-order differentiation (like jacfwd or jacrev), Nabla uses vmap internally to calculate derivatives with respect to basis vectors. This introduces a “Batch Dimension Discrepancy”:
- The Primal Situation: A primal operation (e.g., broadcast_to_physical) is called on a tensor with batch_dims=B.
- The AD Situation: The jvp_rule or vjp_rule receives a tangent or cotangent that has been vmapped, so it has batch_dims=B + E (where E is the number of extra differentiation axes).
- The Failure: The AD rule typically re-uses the kwargs (like shape or axis) from the primal call. However, these kwargs are relative to the primal’s batch_dims. When applied to the tangent (whose rank is higher by E), the physical operation fails because it doesn’t account for the E extra dimensions at the front.
Example:
If a primal broadcast_to_physical(x, shape=(2,3)) is called on x with batch_dims=0, and the JVP tangent has batch_dims=1 (shape [N, 2, 3]), the JVP rule calling broadcast_to_physical(tangent, shape=(2,3)) will fail because it tries to broadcast a rank-3 tensor to a rank-2 shape.
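The rank mismatch described in this example can be reproduced outside Nabla. The sketch below uses NumPy's broadcast_to as a stand-in for broadcast_to_physical (an assumption made purely for illustration; Nabla's actual op and error type may differ). The shapes and the names N, target_shape, and lifted_shape are hypothetical.

```python
import numpy as np

N = 4                             # size of the extra axis introduced by vmap (E = 1)
x = np.ones((1, 3))               # primal input, batch_dims = 0
tangent = np.ones((N, 1, 3))      # vmapped tangent, batch_dims = 1
target_shape = (2, 3)             # shape kwarg captured from the primal call

# The primal call works: rank-2 input, rank-2 target.
primal_out = np.broadcast_to(x, target_shape)

# Re-using the primal kwargs on the tangent fails: rank-3 input, rank-2 target.
try:
    np.broadcast_to(tangent, target_shape)
    reused_kwargs_ok = True
except ValueError:
    reused_kwargs_ok = False

# Prepending the tangent's extra leading dimension to the target shape fixes it.
lifted_shape = tangent.shape[:1] + target_shape
tangent_out = np.broadcast_to(tangent, lifted_shape)
```

Here the extra prefix is read directly off the tangent's leading dimension; in an AD rule it would presumably come from the difference between the tangent's and the primal's batch_dims.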
Current Status (Failing Tests)
Target file: tests/unit/test_hessian_physical_ops.py
Status: 6 failures, 18 passed.
Key Failures & Blockers:
- test_hessian_broadcast_to_physical: fwd_fwd
  - Error: Still hitting rank mismatches in the JVP rule.
  - Diagnosis: The jvp_rule needs to “lift” the target_shape by prepending the extra batch dimensions from the tangent.
- Implicit Physical Paths: test_hessian_implicit_broadcast_batch_dims_chain is failing with broadcast_to rank errors.
  - Diagnosis: Operations like broadcast_batch_dims (which uses reshape and broadcast_to) are sensitive to being nested. The ShapeOp.adapt_kwargs helper is prepending batch dims, but if it double-prepends or misses the “extra” ones, it fails.
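One way to make the double-prepend and missed-prefix cases visible is to build the physical shape in a single place and check its rank against the expected total. The following is only a sketch in plain Python: lift_physical_shape and check_physical_rank are hypothetical helpers, not part of Nabla, and the real ShapeOp.adapt_kwargs may be structured differently.

```python
def lift_physical_shape(logical_shape, batch_shape=(), extra_shape=()):
    """Return extra_shape + batch_shape + logical_shape (hypothetical helper)."""
    return tuple(extra_shape) + tuple(batch_shape) + tuple(logical_shape)


def check_physical_rank(physical_shape, logical_rank, batch_dims, extra_dims=0):
    """Raise if the shape's rank differs from extra + batch + logical dims."""
    expected = extra_dims + batch_dims + logical_rank
    if len(physical_shape) != expected:
        raise ValueError(
            f"physical rank {len(physical_shape)} != expected {expected}"
        )
    return physical_shape


logical_shape = (2, 3)
batch_shape = (5,)     # the primal's batch dims (B = 1)
extra_shape = (4,)     # the extra dims added by the AD transform (E = 1)

# Correct lifting: each prefix is prepended exactly once.
lifted = lift_physical_shape(logical_shape, batch_shape, extra_shape)
check_physical_rank(lifted, logical_rank=2, batch_dims=1, extra_dims=1)

# Double-prepending the batch prefix produces a rank the check rejects.
once = lift_physical_shape(logical_shape, batch_shape)
twice = lift_physical_shape(once, batch_shape)
try:
    check_physical_rank(twice, logical_rank=2, batch_dims=1)
    double_prepend_detected = False
except ValueError:
    double_prepend_detected = True
```

A rank check like this cannot tell which prefix was duplicated or dropped, but it turns a downstream broadcast_to failure into an error at the point where the shape was assembled.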
Strategic Guidelines for the Next Agent
1. Mathematical Tracing & Expected Behavior
- Trace the Nabla ops using NABLA_DEBUG_OP_CALL=1.
- A Hessian is a derivative of a derivative. If you see reshape or broadcast_to appearing in the trace, verify that their input/output shapes are exactly what you’d expect if you were manually differentiating the chain.
2. Negative Axis & KWarg Re-use
- Strategy: Leverage negative axes (axis=-1) wherever possible. They are naturally “batch-dim-prefix-agnostic”.
- AD Rules: If a JVP/VJP rule uses a physical op, it must calculate the extra_prefix (e.g., tangent.batch_dims - primal.batch_dims) and either shift non-negative axes by that amount or prepend the extra dimensions to the target shape.
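The two options in this guideline can be compared with a NumPy reduction standing in for a physical reduce op (an assumption for illustration only). The helper shift_axis and the literal batch_dims values are hypothetical; in Nabla they would be read from the tangent and primal tensors.

```python
import numpy as np


def shift_axis(axis, extra_prefix):
    """Shift a non-negative axis past the extra prefix; leave negative axes alone."""
    return axis + extra_prefix if axis >= 0 else axis


primal = np.arange(6.0).reshape(2, 3)                      # batch_dims = 0
tangent = np.stack([primal * k for k in range(1, 5)])      # shape (4, 2, 3), batch_dims = 1

primal_batch_dims, tangent_batch_dims = 0, 1
extra_prefix = tangent_batch_dims - primal_batch_dims      # E = 1

primal_out = np.sum(primal, axis=1)                        # reduces the last axis, shape (2,)

# Re-using the primal's axis silently reduces the wrong dimension of the tangent.
wrong = np.sum(tangent, axis=1)                            # shape (4, 3)

# Option A: shift the non-negative axis by the extra prefix.
shifted = np.sum(tangent, axis=shift_axis(1, extra_prefix))   # shape (4, 2)

# Option B: a negative axis needs no adjustment at all.
negative = np.sum(tangent, axis=shift_axis(-1, extra_prefix)) # shape (4, 2)
```

Note that for axis kwargs the unadjusted call does not raise; it returns a result of the wrong shape, which is why these bugs can surface far from the rule that caused them.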
3. Cleanup is Mandatory
- Always use cleanup_caches() between test runs. Stale JVP/VJP caches will cause “tangent already has a parent” or “tangent mismatch” errors that are unrelated to your current code changes.
4. Avoid “Manual Shifts” in physical kernels
- The kernel methods should be pure. All adaptation should happen in __call__ (for logical-to-physical setup) or adapt_kwargs (for sharding/batching).
- Inheritance from AxisOp or ShapeOp is the preferred way to handle standard batch dimension propagation.
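The separation between a pure kernel and a centralized adaptation step might look roughly like the toy below. ToyAxisOp and ToySum are hypothetical classes written for illustration; they borrow the method names mentioned above, but the signatures are assumptions and do not reflect Nabla's real AxisOp.

```python
import numpy as np


class ToyAxisOp:
    """Hypothetical base class: all batch-dim handling lives in adapt_kwargs."""

    def adapt_kwargs(self, kwargs, batch_dims):
        # Shift a non-negative logical axis past the batch prefix.
        axis = kwargs["axis"]
        if axis >= 0:
            axis += batch_dims
        return {**kwargs, "axis": axis}

    def kernel(self, x, **kwargs):
        raise NotImplementedError

    def __call__(self, x, batch_dims=0, **kwargs):
        return self.kernel(x, **self.adapt_kwargs(kwargs, batch_dims))


class ToySum(ToyAxisOp):
    def kernel(self, x, *, axis):
        # Pure: no knowledge of batch dims, no manual axis shifting.
        return np.sum(x, axis=axis)


op = ToySum()
unbatched = op(np.ones((2, 3)), axis=0)                   # reduces axis 0, shape (3,)
batched = op(np.ones((4, 2, 3)), batch_dims=1, axis=0)    # reduces axis 1, shape (4, 3)
```

Because the subclass only supplies kernel, any fix to how the prefix is computed (for example, including the extra AD dimensions) would need to be made in one place rather than in every op.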
Final Note
The goal is to make Nabla’s physical operations as “transparent” to vmap as the logical ones. This requires a robust mechanism for physical ops to detect and respect the extra batch dimension prefix introduced by AD transforms.