Source code for nabla.core.execution_context

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Thread-safe execution context for model caching."""

import threading
from collections.abc import Callable
from typing import Optional

from max.engine.api import Model


[docs] class ThreadSafeExecutionContext: """Thread-safe wrapper around the global execution context dictionary."""
[docs] def __init__(self) -> None: self._cache: dict[int, Model] = {} self._lock = threading.RLock() # Using RLock to allow recursive locking
[docs] def get(self, key: int) -> Optional[Model]: """Get a model from the cache. Returns None if not found.""" with self._lock: return self._cache.get(key)
[docs] def set(self, key: int, model: Model) -> None: """Set a model in the cache.""" with self._lock: self._cache[key] = model
[docs] def get_or_create(self, key: int, factory: Callable[[], Model]) -> Model: """ Get a model from cache, or create it using the factory function if not found. This is thread-safe and ensures only one thread creates the model for a given key. """ # First, try to get without holding lock for long with self._lock: if key in self._cache: return self._cache[key] # Model not found, need to create it # Use double-checked locking pattern with self._lock: # Check again in case another thread created it while we were waiting if key in self._cache: return self._cache[key] # Create the model (this might take time, but we need to hold the lock # to prevent multiple threads from creating the same model) model = factory() self._cache[key] = model return model
[docs] def contains(self, key: int) -> bool: """Check if a key exists in the cache.""" with self._lock: return key in self._cache
[docs] def clear(self) -> None: """Clear the entire cache. Useful for testing or memory management.""" with self._lock: self._cache.clear()
[docs] def size(self) -> int: """Get the current size of the cache.""" with self._lock: return len(self._cache)
# Global model cache with thread safety global_execution_context = ThreadSafeExecutionContext()