Source code for nabla.utils.max_interop
# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Shorthands for utitly functions from MAX"""
from max.driver import Device
from max.graph import DeviceRef
[docs]
def accelerator(device_id: int = 0) -> Device:
    """
    Build a MAX ``Accelerator`` device handle for a single GPU.

    Args:
        device_id: GPU ID handed to ``Accelerator`` as ``id`` (default is 0).

    Returns:
        The ``Accelerator`` instance constructed with ``id=device_id``.
    """
    # The driver class is imported at call time rather than at module import.
    from max.driver import Accelerator as _Accelerator

    gpu = _Accelerator(id=device_id)
    return gpu
[docs]
def cpu() -> Device:
    """
    Build a MAX ``CPU`` device handle.

    Returns:
        A freshly constructed ``CPU`` instance.
    """
    # The driver class is imported at call time rather than at module import.
    from max.driver import CPU as _CPU

    host = _CPU()
    return host
[docs]
def device(device_name: str) -> Device:
    """
    Get a device instance based on the provided device name.

    Accepted names are ``"cpu"``, ``"gpu"`` (shorthand for GPU 0) and
    ``"gpu:<id>"`` where ``<id>`` is an integer, e.g. ``"gpu:1"``.

    Args:
        device_name: Name of the device (e.g., "cpu", "gpu", "gpu:1").

    Returns:
        The result of ``cpu()`` for ``"cpu"``, or of
        ``accelerator(device_id=<id>)`` for the GPU forms.

    Raises:
        ValueError: If ``device_name`` is not one of the accepted forms,
            including a missing or non-integer GPU id.
    """
    unsupported = (
        f"Unsupported device: {device_name}. Use 'cpu' or 'gpu:<id>' format."
    )

    # Split once on the first ":" so "gpu:1" -> ("gpu", ":", "1") and a bare
    # name such as "cpu" -> ("cpu", "", "").
    kind, separator, index_text = device_name.partition(":")

    if kind == "cpu" and not separator:
        return cpu()

    # Require an exact "gpu" prefix; a plain startswith() check would also
    # accept unrelated names such as "gpux".
    if kind != "gpu":
        raise ValueError(unsupported)

    if not separator:
        # Bare "gpu" selects the first accelerator.
        return accelerator(device_id=0)

    try:
        gpu_id = int(index_text)
    except ValueError:
        # Report malformed ids ("gpu:", "gpu:abc") with the same message as
        # any other unsupported name instead of a raw int() error.
        raise ValueError(unsupported) from None
    return accelerator(device_id=gpu_id)
[docs]
def device_ref(device: Device) -> DeviceRef:
    """
    Convert a driver ``Device`` into its graph-level ``DeviceRef``.

    Args:
        device: The Device instance to build a reference for.

    Returns:
        The value produced by ``DeviceRef.from_device(device)``.
    """
    reference = DeviceRef.from_device(device)
    return reference
[docs]
def accelerator_count() -> int:
    """
    Report how many accelerators (GPUs) the MAX driver sees.

    Returns:
        The value returned by ``max.driver.accelerator_count()``.
    """
    # Aliased so the driver function is not confused with this wrapper,
    # which shares its name.
    from max.driver import accelerator_count as _driver_accelerator_count

    available = _driver_accelerator_count()
    return available