Program Transformations Uncovered#
This notebook demonstrates how transformations like vmap, grad, or jit modify a Python program under the hood.
In order to visualize how Nabla works, we need two things:
nabla.xpr(<function>, *<args>) - Shows the intermediate representation of a traced program: inputs → operations → outputs
nabla.jit(<function>, show_graph=True) - Shows the compiled MAX graph (JIT only). The JIT transformation turns the intermediate representation into optimized machine code.
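Both tools follow the same calling pattern. As a compact, minimal sketch of how they are invoked (the function f and array x here are placeholders, not part of the notebook below):

import nabla as nb

def f(x):
    return nb.sum(x * x, axes=0)

x = nb.randn((4,))
print(nb.xpr(f, x))                  # traced intermediate representation of f
f_jit = nb.jit(f, show_graph=True)
f_jit(x)                             # first call compiles and prints the MAX graph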
1. Defining and Visualizing a Python Function#
[1]:
import sys

# Install nabla-ml on the fly if it is not already available.
try:
    import nabla as nb
except ImportError:
    import subprocess

    packages = ["nabla-ml"]
    subprocess.run([sys.executable, "-m", "pip", "install"] + packages, check=True)
    import nabla as nb

print(
    f"🎉 All libraries loaded successfully! Python {sys.version_info.major}.{sys.version_info.minor}"
)
🎉 All libraries loaded successfully! Python 3.10
[2]:
def function(input):
    # f(x) = sum(2 * x**2) over axis 0
    return nb.sum(input * 2 * input, axes=0)

input = nb.randn((5,))
print("Base XPR:", nb.xpr(function, input))
print("\nres:", function(input))
Base XPR: { lambda (a:f32[5]:cpu(0)) ;
let
b:f32[1]:cpu(0) = unsqueeze[axes=[-1]]
c:f32[5]:cpu(0) = mul a b
d:f32[5]:cpu(0) = mul c a
e:f32[1]:cpu(0) = sum[axes=[-1]] d
f:f32[]:cpu(0) = squeeze[axes=[-1]] e
in f }
res: 25.47862:f32[]:cpu(0)
3. Gradient Transformation#
nb.grad() transforms the program by adding VJP (vector-Jacobian product) nodes that compute the backward pass.
[3]:
grad_function = nb.grad(function)
print("Gradient XPR:", nb.xpr(grad_function, input))
print("\nGradient res:", grad_function(input))
Gradient XPR: { lambda (a:f32[5]:cpu(0)) ;
let
b:f32[]:cpu(0) = 1.0
c:f32[]:cpu(0) = shallow_copy b
d:f32[1]:cpu(0) = unsqueeze[axes=[-1]] c
e:f32[5]:cpu(0) = broadcast[shape=(5,)] d
f:f32[5]:cpu(0) = shallow_copy a
g:f32[1]:cpu(0) = unsqueeze[axes=[-1]]
h:f32[5]:cpu(0) = mul f g
i:f32[5]:cpu(0) = mul e h
j:f32[5]:cpu(0) = mul e f
k:f32[5]:cpu(0) = mul j g
l:f32[5]:cpu(0) = add i k
in l }
Gradient res: [7.0562096 1.6006289 3.914952 8.9635725 7.470232 ]:f32[5]:cpu(0)
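Since function computes sum(2 * x**2), the analytic gradient is 4 * x. A quick sanity check, as a sketch using only operations already shown in this notebook:

manual_grad = input * 4
print("Analytic 4*x:", manual_grad)  # should match the VJP-derived gradient above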
4. Vectorization Transformation#
nb.vmap() adds batch processing. Blue numbers in the shapes indicate batched dimensions (pink numbers mark regular dimensions).
[4]:
vmapped_grad_function = nb.vmap(nb.grad(function), in_axes=0)
batched_input = nb.randn((3, 5))
print("Vectorized XPR:", nb.xpr(vmapped_grad_function, batched_input))
print("\nVectorized res:", vmapped_grad_function(batched_input))
Vectorized XPR: { lambda (a:f32[3,5]:cpu(0)) ;
let
b:f32[]:cpu(0) = 1.0
c:f32[1]:cpu(0) = unsqueeze_batch_dims[axes=[-1]] b
d:f32[3]:cpu(0) = broadcast_batch_dims[shape=(3,)] c
e:f32[3]:cpu(0) = shallow_copy d
f:f32[3,1]:cpu(0) = unsqueeze[axes=[-1]] e
g:f32[3,5]:cpu(0) = broadcast[shape=(5,)] f
h:f32[3,5]:cpu(0) = incr_batch_dim_ctr a
i:f32[3,5]:cpu(0) = shallow_copy h
j:f32[1]:cpu(0) = unsqueeze[axes=[-1]]
k:f32[3,5]:cpu(0) = mul i j
l:f32[3,5]:cpu(0) = mul g k
m:f32[3,5]:cpu(0) = mul g i
n:f32[3,5]:cpu(0) = mul m j
o:f32[3,5]:cpu(0) = add l n
p:f32[3,5]:cpu(0) = decr_batch_dim_ctr o
in p }
Vectorized res: [[ 7.0562096 1.6006289 3.914952 8.9635725 7.470232 ]
[-3.9091115 3.8003538 -0.6054288 -0.4128754 1.6423941 ]
[ 0.57617426 5.817094 3.0441508 0.48670006 1.775453 ]]:f32[3,5]:cpu(0)
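Conceptually, vmap(grad(function)) applies grad(function) to every row of the batch at once. A rough cross-check, as a sketch that assumes Nabla arrays support integer row indexing (an assumption, not demonstrated elsewhere in this notebook):

grad_fn = nb.grad(function)
for i in range(3):
    print(grad_fn(batched_input[i]))  # each result should match the corresponding row above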
5. Compilation Transformation with MAX#
[5]:
jitted_vmapped_grad_function = nb.jit(nb.vmap(nb.grad(function)), show_graph=True)
res = jitted_vmapped_grad_function(batched_input)
print("\nJitted Vectorized res:", res)
mo.graph @nabla_graph(%arg0: !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32> attributes {_kernel_library_paths = [], argument_names = ["input0"], result_names = ["output0"]} {
%0 = mo.chain.create()
%1 = mo.constant {value = #M.dense_array<2.000000e+00> : tensor<1xf32>} : !mo.tensor<[1], f32>
%2 = mo.constant {value = #M.dense_array<1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00> : tensor<3x5xf32>} : !mo.tensor<[3, 5], f32>
%3 = rmo.mul(%2, %arg0) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32>
%4 = rmo.mul(%3, %1) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[1], f32>) -> !mo.tensor<[3, 5], f32>
%5 = rmo.mul(%arg0, %1) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[1], f32>) -> !mo.tensor<[3, 5], f32>
%6 = rmo.mul(%2, %5) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32>
%7 = rmo.add(%6, %4) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32>
mo.output %7 : !mo.tensor<[3, 5], f32>
} {counter = 24 : i64}
Jitted Vectorized res: [[ 7.0562096 1.6006289 3.914952 8.9635725 7.470232 ]
[-3.9091115 3.8003538 -0.6054288 -0.4128754 1.6423941 ]
[ 0.57617426 5.817094 3.0441508 0.48670006 1.775453 ]]:f32[3,5]:cpu(0)
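The compiled MAX graph above is what executes on every subsequent call. A rough timing sketch (not a rigorous benchmark) comparing the eager and jitted versions, using only functions introduced earlier:

import time

eager_fn = nb.vmap(nb.grad(function))
jit_fn = nb.jit(nb.vmap(nb.grad(function)))
jit_fn(batched_input)  # warm-up: the first call triggers compilation

for name, fn in [("eager", eager_fn), ("jitted", jit_fn)]:
    start = time.perf_counter()
    for _ in range(100):
        fn(batched_input)
    print(f"{name}: {time.perf_counter() - start:.4f} s for 100 calls")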