# Convolution AD Notes (Phase 1)
This note documents the mathematics implemented for Nabla's `conv2d` in `nabla/ops/convolution.py`, including the VJP and JVP rules in NHWC/RSCF form.
## 1) Forward Definition (conv2d)
Let:

- input: X ∈ ℝ^{N×H×W×C_in} (NHWC)
- filter: W ∈ ℝ^{K_h×K_w×C_in×C_out} (RSCF/HWIO for groups=1)
- optional bias: b ∈ ℝ^{C_out}
- stride: (s_h, s_w), dilation: (d_h, d_w)
- padding: (p_t, p_b, p_l, p_r)
Output size:
H_out = floor((H + p_t + p_b - d_h·(K_h - 1) - 1) / s_h) + 1
W_out = floor((W + p_l + p_r - d_w·(K_w - 1) - 1) / s_w) + 1
Elementwise:
Y[n,h,w,c_o] = Σ_{k_h,k_w,c_i} X_pad[n, h*s_h + k_h*d_h, w*s_w + k_w*d_w, c_i] * W[k_h,k_w,c_i,c_o] + b[c_o]
where X_pad is zero-padded in spatial dimensions.
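As a concrete reference, here is a minimal NumPy sketch of this forward definition. The name `conv2d_ref` and its signature are illustrative, not Nabla's API; the sketch is useful for spot-checking the AD rules below:

```python
import numpy as np

def conv2d_ref(x, w, b=None, stride=(1, 1), dilation=(1, 1),
               padding=(0, 0, 0, 0)):
    """Reference NHWC/RSCF conv2d (sketch for checking the math, not Nabla's code).

    x: (N, H, W, C_in), w: (K_h, K_w, C_in, C_out), b: (C_out,) or None,
    padding = (p_t, p_b, p_l, p_r).
    """
    s_h, s_w = stride
    d_h, d_w = dilation
    p_t, p_b, p_l, p_r = padding
    # Zero-pad the spatial dimensions only.
    x_pad = np.pad(x, ((0, 0), (p_t, p_b), (p_l, p_r), (0, 0)))
    n, h_p, w_p, c_in = x_pad.shape
    k_h, k_w, _, c_out = w.shape
    h_out = (h_p - d_h * (k_h - 1) - 1) // s_h + 1
    w_out = (w_p - d_w * (k_w - 1) - 1) // s_w + 1
    y = np.zeros((n, h_out, w_out, c_out), dtype=x.dtype)
    for kh in range(k_h):
        for kw in range(k_w):
            # Strided spatial window aligned with filter tap (kh, kw).
            patch = x_pad[:, kh * d_h : kh * d_h + s_h * h_out : s_h,
                             kw * d_w : kw * d_w + s_w * w_out : s_w, :]
            # Contract over C_in for this tap and accumulate into Y.
            y += np.einsum("nhwc,co->nhwo", patch, w[kh, kw])
    if b is not None:
        y = y + b
    return y
```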
## 2) VJP for conv2d
Given cotangent G = ∂L/∂Y:
### 2.1 Input gradient
∂L/∂X is the transposed-convolution of G with W under the same stride/dilation/padding convention:
dX = conv2d_transpose(G, W; stride, dilation, padding, output_paddings)
output_paddings are solved from shape consistency:
base_h = (H_out - 1)·s_h - p_t - p_b + d_h·(K_h - 1) + 1
base_w = (W_out - 1)·s_w - p_l - p_r + d_w·(K_w - 1) + 1
out_pad_h = H - base_h
out_pad_w = W - base_w
If either output padding is negative, compute with max(·, 0) and crop the result back to the input's spatial shape.
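A sketch of that shape algebra, under the conventions above (the function name and signature are hypothetical, not Nabla's internals):

```python
def solve_output_paddings(in_shape, out_shape, kernel, stride, dilation, padding):
    """Solve out_pad so conv2d_transpose(G, W) recovers the input's (H, W).

    in_shape=(H, W), out_shape=(H_out, W_out), kernel=(K_h, K_w),
    padding=(p_t, p_b, p_l, p_r). Illustrative sketch only.
    """
    (h, w), (h_out, w_out) = in_shape, out_shape
    (k_h, k_w), (s_h, s_w), (d_h, d_w) = kernel, stride, dilation
    p_t, p_b, p_l, p_r = padding
    base_h = (h_out - 1) * s_h - p_t - p_b + d_h * (k_h - 1) + 1
    base_w = (w_out - 1) * s_w - p_l - p_r + d_w * (k_w - 1) + 1
    out_pad_h, out_pad_w = h - base_h, w - base_w
    # A negative value means the transposed conv already overshoots the input
    # size: clamp to 0 here and crop the result back to (H, W) afterwards.
    return max(out_pad_h, 0), max(out_pad_w, 0)
```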
### 2.2 Filter gradient
By the chain rule,
dW[k_h,k_w,c_i,c_o] = Σ_{n,h,w} X_pad[n, h*s_h + k_h*d_h, w*s_w + k_w*d_w, c_i] * G[n,h,w,c_o]
The implementation uses the standard trick of permuting dimensions and reusing conv2d:
X_perm = permute(X_pad, (C_in, H, W, N))
G_perm = permute(G, (H_out, W_out, N, C_out))
dW_perm = conv2d(X_perm, G_perm; stride=dilation, dilation=stride, padding=0)
dW = permute(dW_perm, (K_h, K_w, C_in, C_out))
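For reference, the filter gradient can also be computed by the explicit sum in the formula above. This NumPy sketch (illustrative names, same conventions as the forward sketch) is handy for validating the permute-and-reuse trick:

```python
import numpy as np

def conv2d_filter_grad_ref(x, g, kernel, stride=(1, 1), dilation=(1, 1),
                           padding=(0, 0, 0, 0)):
    """Reference dW via the explicit Section 2.2 sum (sketch, not Nabla's code).

    x: (N, H, W, C_in), g: cotangent (N, H_out, W_out, C_out), kernel=(K_h, K_w).
    """
    s_h, s_w = stride
    d_h, d_w = dilation
    p_t, p_b, p_l, p_r = padding
    x_pad = np.pad(x, ((0, 0), (p_t, p_b), (p_l, p_r), (0, 0)))
    k_h, k_w = kernel
    _, h_out, w_out, c_out = g.shape
    c_in = x.shape[-1]
    dw = np.zeros((k_h, k_w, c_in, c_out), dtype=x.dtype)
    for kh in range(k_h):
        for kw in range(k_w):
            # Same strided window as the forward pass, tap (kh, kw).
            patch = x_pad[:, kh * d_h : kh * d_h + s_h * h_out : s_h,
                             kw * d_w : kw * d_w + s_w * w_out : s_w, :]
            # dW[kh, kw, c_i, c_o] = Σ_{n,h,w} X_pad * G
            dw[kh, kw] = np.einsum("nhwc,nhwo->co", patch, g)
    return dw
```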
### 2.3 Bias gradient
db[c_o] = Σ_{n,h,w} G[n,h,w,c_o]
i.e. reduce_sum(G, axis=[0,1,2]).
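Putting 2.2 and 2.3 together, here is a finite-difference spot check of dW and db against these closed forms, reusing `conv2d_ref` and `conv2d_filter_grad_ref` from the sketches above (illustrative code, not Nabla's tests):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 6, 3))
w = rng.standard_normal((3, 3, 3, 4))
b = rng.standard_normal(4)
g = rng.standard_normal(conv2d_ref(x, w, b).shape)  # arbitrary cotangent

db = g.sum(axis=(0, 1, 2))                 # Section 2.3
dw = conv2d_filter_grad_ref(x, g, (3, 3))  # Section 2.2

# Central differences on <G, Y(w)> entry by entry.
eps = 1e-6
dw_fd = np.zeros_like(w)
for idx in np.ndindex(*w.shape):
    w_p = w.copy(); w_p[idx] += eps
    w_m = w.copy(); w_m[idx] -= eps
    dw_fd[idx] = np.sum(g * (conv2d_ref(x, w_p, b) - conv2d_ref(x, w_m, b))) / (2 * eps)

assert np.allclose(dw, dw_fd, atol=1e-4)
```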
## 3) JVP for conv2d
Using linearization of a bilinear map in (X, W):
dY = conv2d(dX, W; θ) + conv2d(X, dW; θ) + dB
where θ = (stride, dilation, padding, groups, layout kwargs) is held fixed.
This is exactly what the implementation computes.
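A quick numerical sanity check of this rule, using `conv2d_ref` from the Section 1 sketch. Central differences cancel the single second-order cross term of a bilinear map, so the match is exact up to rounding:

```python
import numpy as np

rng = np.random.default_rng(1)
x  = rng.standard_normal((1, 5, 5, 2));  dx = rng.standard_normal(x.shape)
w  = rng.standard_normal((3, 3, 2, 3));  dw = rng.standard_normal(w.shape)
b  = rng.standard_normal(3);             db = rng.standard_normal(b.shape)

# JVP as two conv2d calls plus the bias tangent (theta = default hyperparams).
dy = conv2d_ref(dx, w) + conv2d_ref(x, dw) + db

# First-order finite-difference check of the linearization.
eps = 1e-6
y_p = conv2d_ref(x + eps * dx, w + eps * dw, b + eps * db)
y_m = conv2d_ref(x - eps * dx, w - eps * dw, b - eps * db)
assert np.allclose(dy, (y_p - y_m) / (2 * eps), atol=1e-5)
```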
## 4) Why this is correct (brief proof sketch)
Once hyperparameters are fixed, the map (X, W) ↦ conv(X, W) is bilinear (linear in each argument separately), and the bias enters affinely.
VJP follows from adjointness of correlation/convolution and linearity of summation.
JVP follows from first-order expansion:
conv(X + εdX, W + εdW) = conv(X,W) + ε[conv(dX,W) + conv(X,dW)] + O(ε²)
and bias contributes additively as + ε dB.
Therefore:

- reverse-mode returns the unique linear adjoint map w.r.t. X, W, b;
- forward-mode returns the directional derivative J·v.
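The adjointness claim can be checked numerically in the W argument: ⟨G, conv2d(X, V)⟩ = ⟨V, dW(X, G)⟩ for any direction V. A sketch using the reference functions from Sections 1 and 2.2:

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.standard_normal((2, 6, 6, 3))
v = rng.standard_normal((3, 3, 3, 5))        # arbitrary filter direction
g = rng.standard_normal(conv2d_ref(x, v).shape)

lhs = np.sum(g * conv2d_ref(x, v))                       # <G, conv(X, V)>
rhs = np.sum(v * conv2d_filter_grad_ref(x, g, (3, 3)))   # <V, dW(X, G)>
assert np.allclose(lhs, rhs)
```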
## 5) Backend constraints in this phase
Current MAX backend behavior in this environment requires:
- conv2d: dilation=(1,1) only
- conv2d: grouped mode not yet usable without prepacked filters (groups=1 enforced)
- conv2d_transpose: output_paddings=(0,0) only
Nabla validates these explicitly and raises clear frontend errors.
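A minimal sketch of what such frontend validation can look like; the function name and error messages here are hypothetical, not Nabla's actual code:

```python
def validate_backend_constraints(dilation, groups, output_paddings=(0, 0)):
    """Hypothetical Phase-1 checks mirroring the constraints listed above."""
    if tuple(dilation) != (1, 1):
        raise ValueError(f"conv2d: MAX backend requires dilation=(1, 1), got {dilation}")
    if groups != 1:
        raise ValueError(f"conv2d: grouped conv needs prepacked filters; "
                         f"groups=1 enforced, got {groups}")
    if tuple(output_paddings) != (0, 0):
        raise ValueError(f"conv2d_transpose: only output_paddings=(0, 0) "
                         f"supported, got {output_paddings}")
```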