Source code for nabla.utils.testing

# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Testing utilities for Nabla arrays."""

from typing import TYPE_CHECKING, Union

import numpy as np

# Import Array directly for runtime isinstance checks
from ..core.array import Array

if TYPE_CHECKING:
    pass  # Array already imported above


def allclose(
    a: Union[Array, np.ndarray, float, int],
    b: Union[Array, np.ndarray, float, int],
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """Return True if two arrays are element-wise equal within a tolerance.

    Nabla ``Array`` inputs are converted with ``.to_numpy()`` before the
    comparison. Every other input (NumPy arrays, Python scalars, ...) is
    passed to :func:`numpy.allclose` unchanged. This makes it convenient to
    compare Nabla arrays with each other or with NumPy arrays/scalars.

    The comparison itself is delegated to ``numpy.allclose``, which checks
    ``abs(a - b) <= atol + rtol * abs(b)`` element-wise. Because ``rtol`` is
    scaled by ``abs(b)`` only, the test is not symmetric in ``a`` and ``b``.

    Args:
        a: Input array or scalar.
        b: Input array or scalar (the "reference" value for ``rtol``).
        rtol: Relative tolerance parameter.
        atol: Absolute tolerance parameter.
        equal_nan: Whether to compare NaNs as equal.

    Returns:
        bool: True if the arrays are equal within the given tolerance.

    Examples:
        >>> import numpy as np
        >>> import nabla as nb
        >>> a = nb.array([1.0, 2.0, 3.0])
        >>> b = nb.array([1.0, 2.0, 3.000001])
        >>> nb.allclose(a, b)
        True
        >>> nb.allclose(a, np.array([1.0, 2.0, 3.0]))
        True
    """
    # numpy.allclose does not know about Nabla Arrays, so convert them
    # first; anything else is forwarded to numpy as-is.
    if isinstance(a, Array):
        a = a.to_numpy()
    if isinstance(b, Array):
        b = b.to_numpy()

    # Coerce explicitly so the annotated ``-> bool`` contract always holds.
    return bool(np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))