Program Transformations Uncovered#

This notebook demonstrates how transformations like vmap, grad or jit modify a Python program under the hood.

In order to visualize how Nabla works, we need two things:

  • nabla.xpr(<function>, *<args>) - Shows the intermediate representation of a traced program: inputs → operations → outputs

  • nabla.jit(<function>, show_graph=True) - Shows the compiled MAX graph (JIT only). The JIT transformation compiles the intermediate representation into optimized machine code. (A minimal usage sketch of both entry points follows below.)
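A minimal sketch of the call pattern (the function f and array x here are only illustrative; the sections below walk through the same pattern step by step):

[ ]:
import nabla as nb

def f(x):
    return nb.sum(x * x, axes=0)

x = nb.randn((4,))
print(nb.xpr(f, x))                    # traced intermediate representation of f
jitted_f = nb.jit(f, show_graph=True)  # show_graph=True makes the compiled MAX graph visible
print(jitted_f(x))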

1. Defining and Visualizing a Python Function#

[1]:
import sys

try:
    import nabla as nb
except ImportError:
    import subprocess
    packages = ["nabla-ml"]
    subprocess.run([sys.executable, "-m", "pip", "install"] + packages, check=True)
    import nabla as nb

print(
    f"🎉 All libraries loaded successfully! Python {sys.version_info.major}.{sys.version_info.minor}"
)
🎉 All libraries loaded successfully! Python 3.10
[2]:
def function(input):
    return nb.sum(input * 2 * input, axes=0)


input = nb.randn((5,))
print("Base XPR:", nb.xpr(function, input))
print("\nres:", function(input))
Base XPR: { lambda (a:f32[5]:cpu(0)) ;
  let
    b:f32[1]:cpu(0) = unsqueeze[axes=[-1]]
    c:f32[5]:cpu(0) = mul a b
    d:f32[5]:cpu(0) = mul c a
    e:f32[1]:cpu(0) = sum[axes=[-1]] d
    f:f32[]:cpu(0) = squeeze[axes=[-1]] e
  in f }

res: 25.47862:f32[]:cpu(0)

2. Gradient Transformation#

nb.grad() transforms the program by adding VJP (vector-Jacobian product) nodes for the backward pass.

[3]:
grad_function = nb.grad(function)
print("Gradient XPR:", nb.xpr(grad_function, input))
print("\nGradient res:", grad_function(input))
Gradient XPR: { lambda (a:f32[5]:cpu(0)) ;
  let
    b:f32[]:cpu(0) = 1.0
    c:f32[]:cpu(0) = shallow_copy b
    d:f32[1]:cpu(0) = unsqueeze[axes=[-1]] c
    e:f32[5]:cpu(0) = broadcast[shape=(5,)] d
    f:f32[5]:cpu(0) = shallow_copy a
    g:f32[1]:cpu(0) = unsqueeze[axes=[-1]]
    h:f32[5]:cpu(0) = mul f g
    i:f32[5]:cpu(0) = mul e h
    j:f32[5]:cpu(0) = mul e f
    k:f32[5]:cpu(0) = mul j g
    l:f32[5]:cpu(0) = add i k
  in l }

Gradient res: [7.0562096 1.6006289 3.914952  8.9635725 7.470232 ]:f32[5]:cpu(0)
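The two mul branches in the trace (i and k) are the product-rule terms of d/dx [x * 2 * x]; the final add sums them to 4 * x. A quick sanity check against that closed form, reusing only the operator overloading already shown above:

[ ]:
# f(x) = sum(2 * x * x), so the analytical gradient is 4 * x
manual_grad = input * 4
print("autodiff gradient:", grad_function(input))
print("manual gradient:  ", manual_grad)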

3. Vectorization Transformation#

nb.vmap() adds batch processing: instead of looping in Python, the transformed program maps over an extra batch dimension. In the rendered trace, blue numbers in the shapes mark batched dimensions (pink marks regular dimensions).

[4]:
vmapped_grad_function = nb.vmap(nb.grad(function), in_axes=0)
batched_input = nb.randn((3, 5))
print("Vectorized XPR:", nb.xpr(vmapped_grad_function, batched_input))
print("\nVectorized res:", vmapped_grad_function(batched_input))
Vectorized XPR: { lambda (a:f32[3,5]:cpu(0)) ;
  let
    b:f32[]:cpu(0) = 1.0
    c:f32[1]:cpu(0) = unsqueeze_batch_dims[axes=[-1]] b
    d:f32[3]:cpu(0) = broadcast_batch_dims[shape=(3,)] c
    e:f32[3]:cpu(0) = shallow_copy d
    f:f32[3,1]:cpu(0) = unsqueeze[axes=[-1]] e
    g:f32[3,5]:cpu(0) = broadcast[shape=(5,)] f
    h:f32[3,5]:cpu(0) = incr_batch_dim_ctr a
    i:f32[3,5]:cpu(0) = shallow_copy h
    j:f32[1]:cpu(0) = unsqueeze[axes=[-1]]
    k:f32[3,5]:cpu(0) = mul i j
    l:f32[3,5]:cpu(0) = mul g k
    m:f32[3,5]:cpu(0) = mul g i
    n:f32[3,5]:cpu(0) = mul m j
    o:f32[3,5]:cpu(0) = add l n
    p:f32[3,5]:cpu(0) = decr_batch_dim_ctr o
  in p }

Vectorized res: [[ 7.0562096   1.6006289   3.914952    8.9635725   7.470232  ]
 [-3.9091115   3.8003538  -0.6054288  -0.4128754   1.6423941 ]
 [ 0.57617426  5.817094    3.0441508   0.48670006  1.775453  ]]:f32[3,5]:cpu(0)
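For comparison, a minimal sketch that vmaps the forward function alone (without grad): it turns the f32[5] → scalar program into a batched f32[3,5] → f32[3] program, producing one sum per row of batched_input.

[ ]:
vmapped_function = nb.vmap(function, in_axes=0)
print("Vectorized forward XPR:", nb.xpr(vmapped_function, batched_input))
print("\nVectorized forward res:", vmapped_function(batched_input))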

4. Compilation Transformation with MAX#

nb.jit() compiles the traced program into an optimized MAX graph; show_graph=True prints the compiled graph.

[5]:
jitted_vmapped_grad_function = nb.jit(nb.vmap(nb.grad(function)), show_graph=True)
res = jitted_vmapped_grad_function(batched_input)
print("\nJitted Vectorized res:", res)
mo.graph @nabla_graph(%arg0: !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32> attributes {_kernel_library_paths = [], argument_names = ["input0"], result_names = ["output0"]} {
  %0 = mo.chain.create()
  %1 = mo.constant {value = #M.dense_array<2.000000e+00> : tensor<1xf32>} : !mo.tensor<[1], f32>
  %2 = mo.constant {value = #M.dense_array<1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00> : tensor<3x5xf32>} : !mo.tensor<[3, 5], f32>
  %3 = rmo.mul(%2, %arg0) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32>
  %4 = rmo.mul(%3, %1) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[1], f32>) -> !mo.tensor<[3, 5], f32>
  %5 = rmo.mul(%arg0, %1) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[1], f32>) -> !mo.tensor<[3, 5], f32>
  %6 = rmo.mul(%2, %5) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32>
  %7 = rmo.add(%6, %4) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32>
  mo.output %7 : !mo.tensor<[3, 5], f32>
} {counter = 24 : i64}

Jitted Vectorized res: [[ 7.0562096   1.6006289   3.914952    8.9635725   7.470232  ]
 [-3.9091115   3.8003538  -0.6054288  -0.4128754   1.6423941 ]
 [ 0.57617426  5.817094    3.0441508   0.48670006  1.775453  ]]:f32[3,5]:cpu(0)
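In practice you would usually drop show_graph=True and simply reuse the jitted callable (fast_grad below is just an illustrative name); a minimal sketch, assuming the usual JIT behavior of tracing and compiling on the first call and reusing the compiled graph afterwards:

[ ]:
fast_grad = nb.jit(nb.vmap(nb.grad(function)))  # same transformation stack, without the graph dump
print(fast_grad(batched_input))                 # later calls reuse the compiled graph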