# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Thread-safe execution context for model caching."""
import threading
from collections.abc import Callable
from typing import Optional
from max.engine.api import Model
class ThreadSafeExecutionContext:
    """Thread-safe wrapper around the global execution context dictionary."""
    def __init__(self) -> None:
        self._cache: dict[int, Model] = {}
        self._lock = threading.RLock()  # RLock allows recursive (re-entrant) locking
    def get(self, key: int) -> Optional[Model]:
        """Get a model from the cache. Returns None if not found."""
        with self._lock:
            return self._cache.get(key)
    def set(self, key: int, model: Model) -> None:
        """Set a model in the cache."""
        with self._lock:
            self._cache[key] = model
    def get_or_create(self, key: int, factory: Callable[[], Model]) -> Model:
        """Get a model from the cache, creating it with ``factory`` if absent.

        Thread-safe: the lock is held across both the lookup and the creation,
        so at most one thread runs ``factory`` for a given key.
        """
        with self._lock:
            if key in self._cache:
                return self._cache[key]
            # Not cached: build the model while still holding the lock.
            # Creation may be slow, but releasing the lock here would let
            # several threads construct the same model concurrently.
            model = factory()
            self._cache[key] = model
            return model
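
    # Usage sketch (names hypothetical, not part of this module): callers
    # typically pass a cheap key plus a closure that builds the model only
    # when the key is missing, e.g.
    #
    #   model = global_execution_context.get_or_create(
    #       graph_key, lambda: build_and_load_model(graph)
    #   )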
    def contains(self, key: int) -> bool:
        """Check if a key exists in the cache."""
        with self._lock:
            return key in self._cache
    def clear(self) -> None:
        """Clear the entire cache. Useful for testing or memory management."""
        with self._lock:
            self._cache.clear()
    def size(self) -> int:
        """Get the current size of the cache."""
        with self._lock:
            return len(self._cache)

# Global model cache with thread safety
global_execution_context = ThreadSafeExecutionContext()
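
# A minimal concurrency sketch (illustrative, not part of the module API).
# It demonstrates the guarantee documented on ``get_or_create``: even under
# contention, the factory runs exactly once per key. A plain ``object()``
# stands in for a real ``max.engine.api.Model`` here.
if __name__ == "__main__":
    from concurrent.futures import ThreadPoolExecutor

    call_count = 0

    def slow_factory():
        global call_count
        call_count += 1
        return object()  # placeholder for a compiled Model

    with ThreadPoolExecutor(max_workers=8) as pool:
        futures = [
            pool.submit(global_execution_context.get_or_create, 42, slow_factory)
            for _ in range(8)
        ]
        models = {f.result() for f in futures}

    # All eight callers received the same cached object, and the factory
    # ran exactly once because creation happens under the lock.
    assert call_count == 1
    assert len(models) == 1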