Value-and-Grads (CPU)#

This notebook demonstrates automatic differentiation with Nabla, which enables efficient gradient computation for optimization algorithms. The nb.vjp() function runs the forward pass to produce a value and returns a function that performs the backward pass (the vector-Jacobian product).

Note: Check out the next tutorial to see how to make this work on a GPU.

Setup and Imports#

[1]:
# Installation
import sys

IN_COLAB = "google.colab" in sys.modules  # True when running inside Google Colab

try:
    import nabla as nb
except ImportError:
    # Nabla is missing: install the Modular platform first, then nabla-ml.
    import subprocess

    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "modular",
            "--extra-index-url",
            "https://download.pytorch.org/whl/cpu",
            "--index-url",
            "https://dl.modular.com/public/nightly/python/simple/",
        ],
        check=True,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "nabla-ml", "--upgrade"], check=True
    )
    import nabla as nb

print(
    f"🎉 Nabla is ready! Running on Python {sys.version_info.major}.{sys.version_info.minor}"
)
🎉 Nabla is ready! Running on Python 3.12
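
With Nabla installed, here is a minimal single-input sketch of the vjp pattern used throughout this notebook. This is a sketch only: it assumes nb.ndarange accepts a 1-D shape and that the pullback returns a tuple with one gradient per positional argument, mirroring the two-argument example below.

def f(x):
    return nb.sin(x) * 2

x0 = nb.ndarange((3,))
val, pullback = nb.vjp(f, x0)
# Seeding with ones gives the gradient of 2*sin(x), which is 2*cos(x).
grads = pullback(nb.ones_like(val))
print(val, grads)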

Define Function with Automatic Differentiation#

Define a function that computes both the value of f(x, y) = sin(x·y)·2 and its gradients via nb.vjp.

[2]:
def compute_with_gradients(x, y):
    def computation(x, y):
        return nb.sin(x * y) * 2

    # Forward pass: returns the value and a pullback for the backward pass.
    value, vjp_fn = nb.vjp(computation, x, y)
    # Seeding the pullback with ones yields the gradients (df/dx, df/dy).
    return value, vjp_fn(nb.ones_like(value))
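
For reference, the closed-form gradients of f(x, y) = 2·sin(x·y) are ∂f/∂x = 2y·cos(x·y) and ∂f/∂y = 2x·cos(x·y), so seeding the pullback with ones should reproduce exactly these values. A quick standalone check at a single point, in plain Python with no Nabla involved:

import math

def f_scalar(x, y):
    return math.sin(x * y) * 2

def grad_scalar(x, y):
    # Closed-form partials of 2*sin(x*y)
    return 2 * y * math.cos(x * y), 2 * x * math.cos(x * y)

print(f_scalar(1.0, 2.0))     # ≈ 1.818595
print(grad_scalar(1.0, 2.0))  # ≈ (-1.664587, -0.832294)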

Create Input Tensors#

Generate two identical 2×3 input tensors. This notebook runs on CPU, so no device transfer is needed.

[3]:
# Create two identical input tensors
a = nb.ndarange((2, 3))
b = nb.ndarange((2, 3))
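
nb.ndarange fills a tensor of the given shape with consecutive values starting at 0, analogous to a reshaped numpy.arange. The expected contents here, inferred from the outputs below, can be confirmed by inspection:

print(a)
# Expected (inferred from the results in the next cell): values 0..5
# in row-major order, i.e. [[0, 1, 2], [3, 4, 5]]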

Compute Values and Gradients#

Execute the function to get both the computed values and their gradients.

The output shows:

  • The first tensor contains the function values for sin(a * b) * 2.

  • The second element is a tuple with the gradients (∂f/∂a, ∂f/∂b) with respect to the inputs a and b. Because a and b are identical, the two gradient tensors match (each is 2·cos(a·b) times the other input); a numeric spot check follows the output below.

[4]:
# Compute and print results
value, grads = compute_with_gradients(a, b)
value, grads
[4]:
([[ 0.         1.682942  -1.513605 ]
  [ 0.8242369 -0.5758067 -0.2647035]]:f32[2,3],
 ([[ 0.         1.0806046 -2.6145744]
   [-5.4667816 -7.661276   9.912028 ]]:f32[2,3],
  [[ 0.         1.0806046 -2.6145744]
   [-5.4667816 -7.661276   9.912028 ]]:f32[2,3]))

Note

💡 Want to run this yourself?

  • 🚀 Google Colab: No setup required, runs in your browser

  • 📥 Local Jupyter: Download and run with your own Python environment