JAX vs. Nabla: Training an MLP (CPU)#

This notebook provides a side-by-side comparison of Nabla and JAX for training an MLP to learn a complex 8-period sine function. We look at:

  1. Different execution modes (eager vs JIT)

  2. Performance characteristics

  3. Final accuracy

  4. Numerical precision handling

Problem: Train an MLP to learn f(x) = sin(8 · 2π · x) / 2 + 0.5, a complex high-frequency sine function with eight full periods on [0, 1].

Key Focus: Ensure a fair comparison by using consistent float32 precision across both frameworks.
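
The target itself can be previewed with plain NumPy before either framework enters the picture. The short sketch below is illustrative only (it is not part of the benchmark); it samples f at a few points and confirms the float32 dtype used throughout.

# Quick preview of f(x) = sin(8 · 2π · x) / 2 + 0.5 (illustrative only)
import numpy as np

x_preview = np.linspace(0.0, 1.0, 5, dtype=np.float32)
y_preview = np.sin(8 * 2.0 * np.pi * x_preview) / 2.0 + 0.5
print(y_preview.dtype, y_preview)  # float32, values stay within [0, 1]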

[ ]:
# Installation
import sys

IN_COLAB = "google.colab" in sys.modules

try:
    import jax
    import jax.numpy as jnp
    from jax import jit, value_and_grad

    import nabla as nb
except ImportError:
    import subprocess

    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "modular",
            "--extra-index-url",
            "https://download.pytorch.org/whl/cpu",
            "--index-url",
            "https://dl.modular.com/public/nightly/python/simple/",
        ],
        check=True,
    )
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "nabla-ml",
            "jax",
            "jaxlib",
            "--upgrade",
        ],
        check=True,
    )
    import jax
    import jax.numpy as jnp
    from jax import jit, value_and_grad

    import nabla as nb

# Import other required libraries
import time

import matplotlib.pyplot as plt
import numpy as np

# Set JAX to use float32 consistently
from jax import config

config.update("jax_enable_x64", False)

print(
    f"🎉 Nabla and JAX are ready! Running on Python {sys.version_info.major}.{sys.version_info.minor}"
)
🎉 Nabla and JAX are ready! Running on Python 3.12
[3]:
# Configuration parameters
BATCH_SIZE = 128
LAYERS = [1, 64, 128, 256, 128, 64, 1]  # MLP architecture
LEARNING_RATE = 0.001
NUM_EPOCHS = 5000
PRINT_INTERVAL = 1000
SIN_PERIODS = 8
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)

2. Core Model Functions (Nabla and JAX)#

[4]:
# ===== Shared Data Generation =====
def create_sin_dataset(batch_size: int = BATCH_SIZE) -> tuple[np.ndarray, np.ndarray]:
    """Generate training data for the sinusoidal function using NumPy."""
    x_np = np.random.uniform(0.0, 1.0, size=(batch_size, 1)).astype(np.float32)
    targets_np = (np.sin(SIN_PERIODS * 2.0 * np.pi * x_np) / 2.0 + 0.5).astype(
        np.float32
    )
    return x_np, targets_np


# ===== Nabla Implementation =====
def nb_mlp_forward(x: nb.Array, params: list[nb.Array]) -> nb.Array:
    """Forward pass of the MLP with ReLU activations (Nabla)."""
    output = x
    for i in range(0, len(params) - 1, 2):
        w, b = params[i], params[i + 1]
        output = nb.matmul(output, w) + b
        if i < len(params) - 2:  # No ReLU for the output layer
            output = nb.relu(output)
    return output


def nb_mean_squared_error(predictions: nb.Array, targets: nb.Array) -> nb.Array:
    """Compute mean squared error loss (Nabla)."""
    diff = predictions - targets
    squared_errors = diff * diff
    batch_size = nb.array(predictions.shape[0], dtype=nb.DType.float32)
    return nb.sum(squared_errors) / batch_size


def nb_create_sin_dataset(batch_size: int = BATCH_SIZE) -> tuple[nb.Array, nb.Array]:
    """Generate training data for the sinusoidal function (Nabla)."""
    x_np, targets_np = create_sin_dataset(batch_size)
    x = nb.Array.from_numpy(x_np)
    targets = nb.Array.from_numpy(targets_np)
    return x, targets


def nb_initialize_for_complex_function(
    layers: list[int], seed: int = GLOBAL_SEED
) -> list[nb.Array]:
    """Initialize weights using He Normal initialization (Nabla)."""
    np.random.seed(seed)
    params = []
    for i in range(len(layers) - 1):
        fan_in, fan_out = layers[i], layers[i + 1]
        std = (2.0 / fan_in) ** 0.5
        w_np = std * np.random.normal(size=(fan_in, fan_out)).astype(np.float32)
        b_np = np.zeros((fan_out,), dtype=np.float32)
        w = nb.Array.from_numpy(w_np)
        b = nb.Array.from_numpy(b_np)
        params.extend([w, b])
    return params


# ===== JAX Implementation =====
def jax_mlp_forward(x: jnp.ndarray, params: list[jnp.ndarray]) -> jnp.ndarray:
    """Forward pass of the MLP with ReLU activations (JAX)."""
    output = x
    for i in range(0, len(params) - 1, 2):
        w, b = params[i], params[i + 1]
        output = jnp.matmul(output, w) + b
        if i < len(params) - 2:  # No ReLU for the output layer
            output = jax.nn.relu(output)
    return output


def jax_mean_squared_error(
    predictions: jnp.ndarray, targets: jnp.ndarray
) -> jnp.ndarray:
    """Compute mean squared error loss (JAX)."""
    diff = predictions - targets
    squared_errors = diff * diff
    batch_size = jnp.array(predictions.shape[0], dtype=jnp.float32)
    return jnp.sum(squared_errors) / batch_size


def jax_create_sin_dataset(
    batch_size: int = BATCH_SIZE,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Generate training data for the sinusoidal function (JAX)."""
    x_np, targets_np = create_sin_dataset(batch_size)
    x = jnp.array(x_np)
    targets = jnp.array(targets_np)
    return x, targets


def jax_initialize_for_complex_function(
    layers: list[int], seed: int = GLOBAL_SEED
) -> list[jnp.ndarray]:
    """Initialize weights using He Normal initialization (JAX)."""
    np.random.seed(seed)
    params = []
    for i in range(len(layers) - 1):
        fan_in, fan_out = layers[i], layers[i + 1]
        std = (2.0 / fan_in) ** 0.5
        w_np = std * np.random.normal(size=(fan_in, fan_out)).astype(np.float32)
        b_np = np.zeros((fan_out,), dtype=np.float32)
        w = jnp.array(w_np)
        b = jnp.array(b_np)
        params.extend([w, b])
    return params
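
As a quick sanity check of the functions above, initializing with LAYERS and running a single forward pass should produce a (BATCH_SIZE, 1) output in both frameworks. This is an illustrative sketch that assumes the definitions above are in scope; it is not part of the benchmark.

# Illustrative shape check (assumes the cells above have been executed)
x_nb, _ = nb_create_sin_dataset(BATCH_SIZE)
params_nb = nb_initialize_for_complex_function(LAYERS)
print(nb_mlp_forward(x_nb, params_nb).shape)     # (128, 1)

x_jax, _ = jax_create_sin_dataset(BATCH_SIZE)
params_jax = jax_initialize_for_complex_function(LAYERS)
print(jax_mlp_forward(x_jax, params_jax).shape)  # (128, 1)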

3. Optimizer Implementations (Nabla and JAX)#

[5]:
# ===== Shared Optimizer Implementation =====
def adamw_step(
    params: list,
    gradients: list,
    m_states: list,
    v_states: list,
    step: int,
    learning_rate: float = LEARNING_RATE,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.01,
) -> tuple[list, list, list]:
    """Shared AdamW optimization step with weight decay."""
    updated_params = []
    updated_m = []
    updated_v = []
    bc1, bc2 = 1.0 - beta1**step, 1.0 - beta2**step

    for param, grad, m, v in zip(params, gradients, m_states, v_states, strict=False):
        new_m = beta1 * m + (1.0 - beta1) * grad
        new_v = beta2 * v + (1.0 - beta2) * (grad * grad)
        m_corrected = new_m / bc1
        v_corrected = new_v / bc2
        update = m_corrected / (v_corrected**0.5 + eps) + weight_decay * param
        new_param = param - learning_rate * update
        updated_params.append(new_param)
        updated_m.append(new_m)
        updated_v.append(new_v)

    return updated_params, updated_m, updated_v


def init_adamw_state(params: list) -> tuple[list, list]:
    """Shared function to initialize AdamW optimizer states."""
    # Zero-initialized moment buffers matching the parameter shapes, dispatching on array type
    if isinstance(params[0], nb.Array):
        m_states = [nb.zeros_like(param) for param in params]
        v_states = [nb.zeros_like(param) for param in params]
    else:
        m_states = [jnp.zeros_like(param) for param in params]
        v_states = [jnp.zeros_like(param) for param in params]
    return m_states, v_states


# ===== Nabla Optimizer =====
def nb_adamw_step(
    params: list[nb.Array],
    gradients: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
    learning_rate: float = LEARNING_RATE,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.01,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array]]:
    """AdamW optimization step with weight decay (Nabla)."""
    return adamw_step(
        params,
        gradients,
        m_states,
        v_states,
        step,
        learning_rate,
        beta1,
        beta2,
        eps,
        weight_decay,
    )


def nb_init_adamw_state(
    params: list[nb.Array],
) -> tuple[list[nb.Array], list[nb.Array]]:
    """Initialize AdamW optimizer states (Nabla)."""
    return init_adamw_state(params)


# ===== JAX Optimizer =====
def jax_adamw_step(
    params: list[jnp.ndarray],
    gradients: list[jnp.ndarray],
    m_states: list[jnp.ndarray],
    v_states: list[jnp.ndarray],
    step: int,
    learning_rate: float = LEARNING_RATE,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.01,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], list[jnp.ndarray]]:
    """AdamW optimization step with weight decay (JAX)."""
    return adamw_step(
        params,
        gradients,
        m_states,
        v_states,
        step,
        learning_rate,
        beta1,
        beta2,
        eps,
        weight_decay,
    )


def jax_init_adamw_state(
    params: list[jnp.ndarray],
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
    """Initialize AdamW optimizer states (JAX)."""
    return init_adamw_state(params)
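
The shared adamw_step implements the standard AdamW update: the first and second moments are updated with beta1 and beta2, bias-corrected by 1 - beta1^t and 1 - beta2^t, and the parameter update adds decoupled weight decay before scaling by the learning rate. The sketch below runs one toy update on small JAX arrays purely to illustrate the calling convention; it assumes the cell above has been executed and is not part of the benchmark.

# Toy AdamW update (illustrative only)
toy_params = [jnp.ones((2, 2)), jnp.zeros((2,))]
toy_grads = [0.1 * jnp.ones((2, 2)), 0.1 * jnp.ones((2,))]
toy_m, toy_v = init_adamw_state(toy_params)
new_params, new_m, new_v = adamw_step(toy_params, toy_grads, toy_m, toy_v, step=1)
print(new_params[0].shape, new_m[0].shape, new_v[0].shape)  # (2, 2) (2, 2) (2, 2)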

4. Training Step Implementations#

[6]:
# ===== Nabla Training Steps =====
def nb_train_step_no_jit(
    x: nb.Array,
    targets: nb.Array,
    params: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
    learning_rate: float,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array], nb.Array]:
    """Training step in eager mode (Nabla - no JIT)."""
    loss_value, param_gradients = nb.value_and_grad(
        lambda *p: nb_mean_squared_error(nb_mlp_forward(x, p), targets),
        argnums=list(range(len(params))),
    )(*params)

    updated_params, updated_m, updated_v = nb_adamw_step(
        params, param_gradients, m_states, v_states, step, learning_rate
    )

    return updated_params, updated_m, updated_v, loss_value


@nb.djit
def nb_train_step_djit(
    x: nb.Array,
    targets: nb.Array,
    params: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
    learning_rate: float,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array], nb.Array]:
    """Training step with dynamic JIT (Nabla)."""
    loss_value, param_gradients = nb.value_and_grad(
        lambda *p: nb_mean_squared_error(nb_mlp_forward(x, p), targets),
        argnums=list(range(len(params))),
    )(*params)

    updated_params, updated_m, updated_v = nb_adamw_step(
        params, param_gradients, m_states, v_states, step, learning_rate
    )

    return updated_params, updated_m, updated_v, loss_value


@nb.jit
def nb_train_step_jit(
    x: nb.Array,
    targets: nb.Array,
    params: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
    learning_rate: float,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array], nb.Array]:
    """Training step with static JIT (Nabla)."""
    loss_value, param_gradients = nb.value_and_grad(
        lambda *p: nb_mean_squared_error(nb_mlp_forward(x, p), targets),
        argnums=list(range(len(params))),
    )(*params)

    updated_params, updated_m, updated_v = nb_adamw_step(
        params, param_gradients, m_states, v_states, step, learning_rate
    )

    return updated_params, updated_m, updated_v, loss_value


# ===== JAX Training Steps =====
def jax_train_step_no_jit(
    x: jnp.ndarray,
    targets: jnp.ndarray,
    params: list[jnp.ndarray],
    m_states: list[jnp.ndarray],
    v_states: list[jnp.ndarray],
    step: int,
    learning_rate: float,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], list[jnp.ndarray], jnp.ndarray]:
    """Training step in eager mode (JAX - no JIT)."""

    def loss_fn(params_inner):
        predictions = jax_mlp_forward(x, params_inner)
        return jax_mean_squared_error(predictions, targets)

    loss_value, param_gradients = value_and_grad(loss_fn)(params)

    updated_params, updated_m, updated_v = jax_adamw_step(
        params, param_gradients, m_states, v_states, step, learning_rate
    )

    return updated_params, updated_m, updated_v, loss_value


@jit
def jax_train_step_jit(
    x: jnp.ndarray,
    targets: jnp.ndarray,
    params: list[jnp.ndarray],
    m_states: list[jnp.ndarray],
    v_states: list[jnp.ndarray],
    step: int,
    learning_rate: float,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], list[jnp.ndarray], jnp.ndarray]:
    """Training step with JIT (JAX)."""

    def loss_fn(params_inner):
        predictions = jax_mlp_forward(x, params_inner)
        return jax_mean_squared_error(predictions, targets)

    loss_value, param_gradients = value_and_grad(loss_fn)(params)

    updated_params, updated_m, updated_v = jax_adamw_step(
        params, param_gradients, m_states, v_states, step, learning_rate
    )

    return updated_params, updated_m, updated_v, loss_value
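
Both frameworks differentiate with respect to the entire parameter list: Nabla via argnums=list(range(len(params))) over the unpacked positional arguments, JAX by treating the Python list as a pytree, so value_and_grad returns gradients with the same list structure as the input. A minimal JAX sketch of that behaviour (illustrative only):

# value_and_grad on a list pytree returns a list of gradients
toy = [jnp.array(2.0), jnp.array(3.0)]
val, grads = value_and_grad(lambda p: p[0] * p[1])(toy)
print(val, grads)  # roughly 6.0 and [3.0, 2.0]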

5. Experiment Runner Functions#

[7]:
def run_nabla_experiment(train_step_func, jit_mode: str):
    """Run training experiment with Nabla."""
    print(f"\n{'=' * 50}\nStarting Nabla Training with: {jit_mode}\n{'=' * 50}")
    print(f"Architecture: {LAYERS}")
    print(f"Initial Learning Rate: {LEARNING_RATE}")

    # Initialize model and optimizer
    params = nb_initialize_for_complex_function(LAYERS, seed=GLOBAL_SEED)
    m_states, v_states = nb_init_adamw_state(params)

    # Initial evaluation
    x_init, targets_init = nb_create_sin_dataset(BATCH_SIZE)
    predictions_init = nb_mlp_forward(x_init, params)
    initial_loss = (
        nb_mean_squared_error(predictions_init, targets_init).to_numpy().item()
    )
    print(f"Initial Loss: {initial_loss:.6f}")

    # Training loop
    avg_loss, total_time, compile_time = 0.0, 0.0, 0.0
    print("Starting training loop...")

    for epoch in range(1, NUM_EPOCHS + 1):
        epoch_start = time.time()
        current_lr = LEARNING_RATE * (0.95 ** (epoch // 1000))
        x, targets = nb_create_sin_dataset(BATCH_SIZE)

        if "JIT" in jit_mode and epoch == 1:
            compile_start = time.time()
            updated_params, updated_m, updated_v, loss_value = train_step_func(
                x, targets, params, m_states, v_states, epoch, current_lr
            )
            compile_time = time.time() - compile_start
        else:
            updated_params, updated_m, updated_v, loss_value = train_step_func(
                x, targets, params, m_states, v_states, epoch, current_lr
            )

        params, m_states, v_states = updated_params, updated_m, updated_v
        avg_loss += loss_value.to_numpy().item()
        total_time += time.time() - epoch_start

        if epoch % PRINT_INTERVAL == 0:
            print(f"Epoch {epoch:4d} | Avg Loss: {avg_loss / PRINT_INTERVAL:.6f}")
            avg_loss = 0.0

    # Final evaluation
    x_test = nb.Array.from_numpy(
        np.linspace(0, 1, 1000).reshape(-1, 1).astype(np.float32)
    )
    targets_test = nb.Array.from_numpy(
        (np.sin(SIN_PERIODS * 2.0 * np.pi * x_test.to_numpy()) / 2.0 + 0.5).astype(
            np.float32
        )
    )
    predictions_test = nb_mlp_forward(x_test, params)
    test_loss = nb_mean_squared_error(predictions_test, targets_test)
    final_loss = test_loss.to_numpy().item()
    correlation = np.corrcoef(
        predictions_test.to_numpy().flatten(), targets_test.to_numpy().flatten()
    )[0, 1]

    # Return results
    avg_epoch_time = (
        (total_time - compile_time) / (NUM_EPOCHS - 1)
        if "JIT" in jit_mode
        else total_time / NUM_EPOCHS
    )
    return {
        "framework": "Nabla",
        "mode": jit_mode,
        "total_time": total_time,
        "compile_time": compile_time,
        "avg_epoch_time": avg_epoch_time,
        "final_loss": final_loss,
        "correlation": correlation,
        "predictions": predictions_test.to_numpy(),
        "targets": targets_test.to_numpy(),
    }


def run_jax_experiment(train_step_func, jit_mode: str):
    """Run training experiment with JAX."""
    print(f"\n{'=' * 50}\nStarting JAX Training with: {jit_mode}\n{'=' * 50}")
    print(f"Architecture: {LAYERS}")
    print(f"Initial Learning Rate: {LEARNING_RATE}")

    # Initialize model and optimizer
    params = jax_initialize_for_complex_function(LAYERS, seed=GLOBAL_SEED)
    m_states, v_states = jax_init_adamw_state(params)

    # Initial evaluation
    x_init, targets_init = jax_create_sin_dataset(BATCH_SIZE)
    predictions_init = jax_mlp_forward(x_init, params)
    initial_loss = jax_mean_squared_error(predictions_init, targets_init).item()
    print(f"Initial Loss: {initial_loss:.6f}")

    # Training loop
    avg_loss, total_time, compile_time = 0.0, 0.0, 0.0
    print("Starting training loop...")

    for epoch in range(1, NUM_EPOCHS + 1):
        epoch_start = time.time()
        current_lr = LEARNING_RATE * (0.95 ** (epoch // 1000))
        x, targets = jax_create_sin_dataset(BATCH_SIZE)

        if "JIT" in jit_mode and epoch == 1:
            compile_start = time.time()
            updated_params, updated_m, updated_v, loss_value = train_step_func(
                x, targets, params, m_states, v_states, epoch, current_lr
            )
            compile_time = time.time() - compile_start
        else:
            updated_params, updated_m, updated_v, loss_value = train_step_func(
                x, targets, params, m_states, v_states, epoch, current_lr
            )

        params, m_states, v_states = updated_params, updated_m, updated_v
        avg_loss += loss_value.item()
        total_time += time.time() - epoch_start

        if epoch % PRINT_INTERVAL == 0:
            print(f"Epoch {epoch:4d} | Avg Loss: {avg_loss / PRINT_INTERVAL:.6f}")
            avg_loss = 0.0

    # Final evaluation
    x_test = jnp.linspace(0, 1, 1000, dtype=jnp.float32).reshape(-1, 1)
    targets_test = (jnp.sin(SIN_PERIODS * 2.0 * jnp.pi * x_test) / 2.0 + 0.5).astype(
        jnp.float32
    )
    predictions_test = jax_mlp_forward(x_test, params)
    test_loss = jax_mean_squared_error(predictions_test, targets_test)
    final_loss = test_loss.item()
    correlation = np.corrcoef(predictions_test.reshape(-1), targets_test.reshape(-1))[
        0, 1
    ]

    # Return results
    avg_epoch_time = (
        (total_time - compile_time) / (NUM_EPOCHS - 1)
        if "JIT" in jit_mode
        else total_time / NUM_EPOCHS
    )
    return {
        "framework": "JAX",
        "mode": jit_mode,
        "total_time": total_time,
        "compile_time": compile_time,
        "avg_epoch_time": avg_epoch_time,
        "final_loss": final_loss,
        "correlation": correlation,
        "predictions": np.array(predictions_test),
        "targets": np.array(targets_test),
    }
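
One methodological note on the timings: JAX dispatches jitted work asynchronously, so wall-clock measurements are only meaningful once results are materialized. Both loops read the loss back into Python every epoch (.item() for JAX, .to_numpy().item() for Nabla), which forces that synchronization. If the loss were instead accumulated on-device, the JAX loop would need an explicit barrier; a hypothetical variant of the inner step is sketched below.

# Hypothetical timing variant (inside the JAX training loop): synchronize explicitly
# rather than relying on the .item() read-back to block.
params, m_states, v_states, loss_value = jax_train_step_jit(
    x, targets, params, m_states, v_states, epoch, current_lr
)
loss_value.block_until_ready()  # wait until the dispatched computation has finished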

6. Run Experiments and Compare Results#

[8]:
# Run experiments
results = []

# Run Nabla experiments
print("\nRunning Nabla Eager Mode Experiment...")
results.append(run_nabla_experiment(nb_train_step_no_jit, "Eager Mode (No JIT)"))

print("\nRunning Nabla Dynamic JIT Experiment...")
results.append(run_nabla_experiment(nb_train_step_djit, "Dynamic JIT (nb.djit)"))

print("\nRunning Nabla Static JIT Experiment...")
results.append(run_nabla_experiment(nb_train_step_jit, "Static JIT (nb.jit)"))

# Run JAX experiments
print("\nRunning JAX Eager Mode Experiment...")
results.append(run_jax_experiment(jax_train_step_no_jit, "Eager Mode (No JIT)"))

print("\nRunning JAX JIT Experiment...")
results.append(run_jax_experiment(jax_train_step_jit, "JIT"))

# Plot comparison of predictions
plt.figure(figsize=(15, 12))
plt.suptitle("Comparison of Framework Performance", fontsize=16)

for i, res in enumerate(results, 1):
    plt.subplot(3, 2, i)
    plt.plot(res["targets"], label="True Function", color="blue")
    plt.plot(res["predictions"], label="Predictions", color="red", linestyle="--")
    plt.title(f"{res['framework']} - {res['mode']}")
    plt.xlabel("Input")
    plt.ylabel("Output")
    plt.legend()
    plt.grid(True)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

# Display results comparison table
print("\n" + "=" * 120)
print("FRAMEWORK PERFORMANCE COMPARISON".center(120))
print("=" * 120)
print(
    f"{'Framework':<15} | {'Mode':<25} | {'Total Time (s)':<15} | {'Avg Epoch Time (ms)':<20} | {'Final Loss':<12} | {'Correlation':<12}"
)
print("-" * 120)

for res in results:
    print(
        f"{res['framework']:<15} | {res['mode']:<25} | {res['total_time']:<15.4f} | {res['avg_epoch_time'] * 1000:<20.4f} | {res['final_loss']:<12.6f} | {res['correlation']:<12.4f}"
    )

# Find best performance
best_loss = min(results, key=lambda x: x["final_loss"])
best_corr = max(results, key=lambda x: x["correlation"])
fastest = min(results, key=lambda x: x["avg_epoch_time"])

print("\n" + "=" * 120)
print("SUMMARY".center(120))
print("=" * 120)
print(
    f"Lowest final loss: {best_loss['framework']} - {best_loss['mode']} (Loss: {best_loss['final_loss']:.6f})"
)
print(
    f"Highest correlation: {best_corr['framework']} - {best_corr['mode']} (Correlation: {best_corr['correlation']:.4f})"
)
print(
    f"Fastest average epoch time: {fastest['framework']} - {fastest['mode']} ({fastest['avg_epoch_time'] * 1000:.4f} ms)"
)

Running Nabla Eager Mode Experiment...

==================================================
Starting Nabla Training with: Eager Mode (No JIT)
==================================================
Architecture: [1, 64, 128, 256, 128, 64, 1]
Initial Learning Rate: 0.001
Initial Loss: 0.191547
Starting training loop...
Epoch 1000 | Avg Loss: 0.112580
Epoch 2000 | Avg Loss: 0.087121
Epoch 3000 | Avg Loss: 0.075036
Epoch 4000 | Avg Loss: 0.067084
Epoch 5000 | Avg Loss: 0.053578

Running Nabla Dynamic JIT Experiment...

==================================================
Starting Nabla Training with: Dynamic JIT (nb.djit)
==================================================
Architecture: [1, 64, 128, 256, 128, 64, 1]
Initial Learning Rate: 0.001
Initial Loss: 0.191547
Starting training loop...
Epoch 1000 | Avg Loss: 0.112085
Epoch 2000 | Avg Loss: 0.087083
Epoch 3000 | Avg Loss: 0.076075
Epoch 4000 | Avg Loss: 0.061879
Epoch 5000 | Avg Loss: 0.045098

Running Nabla Static JIT Experiment...

==================================================
Starting Nabla Training with: Static JIT (nb.jit)
==================================================
Architecture: [1, 64, 128, 256, 128, 64, 1]
Initial Learning Rate: 0.001
Initial Loss: 0.191547
Starting training loop...
Epoch 1000 | Avg Loss: 0.111682
Epoch 2000 | Avg Loss: 0.088914
Epoch 3000 | Avg Loss: 0.073619
Epoch 4000 | Avg Loss: 0.060372
Epoch 5000 | Avg Loss: 0.044641

Running JAX Eager Mode Experiment...

==================================================
Starting JAX Training with: Eager Mode (No JIT)
==================================================
Architecture: [1, 64, 128, 256, 128, 64, 1]
Initial Learning Rate: 0.001
Initial Loss: 0.191547
Starting training loop...
Epoch 1000 | Avg Loss: 0.112762
Epoch 2000 | Avg Loss: 0.086168
Epoch 3000 | Avg Loss: 0.072306
Epoch 4000 | Avg Loss: 0.057749
Epoch 5000 | Avg Loss: 0.049398

Running JAX JIT Experiment...

==================================================
Starting JAX Training with: JIT
==================================================
Architecture: [1, 64, 128, 256, 128, 64, 1]
Initial Learning Rate: 0.001
Initial Loss: 0.191547
Starting training loop...
Epoch 1000 | Avg Loss: 0.112052
Epoch 2000 | Avg Loss: 0.090434
Epoch 3000 | Avg Loss: 0.075858
Epoch 4000 | Avg Loss: 0.060189
Epoch 5000 | Avg Loss: 0.048559
[Figure: "Comparison of Framework Performance" — true function vs. MLP predictions for each framework and execution mode]

========================================================================================================================
                                            FRAMEWORK PERFORMANCE COMPARISON
========================================================================================================================
Framework       | Mode                      | Total Time (s)  | Avg Epoch Time (ms)  | Final Loss   | Correlation
------------------------------------------------------------------------------------------------------------------------
Nabla           | Eager Mode (No JIT)       | 26.0870         | 5.2168               | 0.043524     | 0.8076
Nabla           | Dynamic JIT (nb.djit)     | 24.7653         | 4.9234               | 0.033746     | 0.8572
Nabla           | Static JIT (nb.jit)       | 3.2174          | 0.6185               | 0.029933     | 0.8727
JAX             | Eager Mode (No JIT)       | 41.2386         | 7.8462               | 0.042708     | 0.8132
JAX             | JIT                       | 3.6152          | 0.6939               | 0.043040     | 0.8100

========================================================================================================================
                                                        SUMMARY
========================================================================================================================
Lowest final loss: Nabla - Static JIT (nb.jit) (Loss: 0.029933)
Highest correlation: Nabla - Static JIT (nb.jit) (Correlation: 0.8727)
Fastest average epoch time: Nabla - Static JIT (nb.jit) (0.6185 ms)
