cond
Signature
nabla.cond(condition: 'Array', true_fn: 'Callable', false_fn: 'Callable', *args, **kwargs) -> 'Array'
Description
Conditionally executes one of two functions. If condition is True, true_fn is called; otherwise, false_fn is called. This control-flow primitive enables conditional execution within a computational graph. nabla.where needs both branch values computed before it selects one. cond, by contrast, executes only the selected function.
Parameters
condition (Array): A scalar boolean array that determines which function to execute.
true_fn (Callable): The function to be called if condition is True.
false_fn (Callable): The function to be called if condition is False.
*args: Positional arguments to be passed to the selected function.
**kwargs: Keyword arguments to be passed to the selected function.
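The parameter list says that extra positional and keyword arguments go to whichever function is selected. Because either branch may receive them, both functions should accept the same arguments. The pure-Python sketch below models that contract. cond_sketch is hypothetical, not Nabla's implementation. A plain bool stands in for the scalar boolean array. Raising TypeError for a non-scalar condition is an assumption made for illustration, not documented Nabla behaviour.

```python
from typing import Any, Callable


def cond_sketch(condition: Any, true_fn: Callable, false_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    # Hypothetical model of the documented contract, not Nabla's code.
    # A plain Python bool stands in for a scalar boolean array; rejecting
    # anything else is an assumed validation step.
    if not isinstance(condition, bool):
        raise TypeError("condition must be a scalar boolean")
    selected = true_fn if condition else false_fn
    # *args and **kwargs are passed through unchanged to the chosen function.
    return selected(*args, **kwargs)


# Both branches share the same signature, so either can receive `amount`.
def scale(x: float, amount: float = 2.0) -> float:
    return x * amount


def shift(x: float, amount: float = 10.0) -> float:
    return x + amount


a = cond_sketch(True, scale, shift, 5, amount=3.0)   # scale(5, amount=3.0)
b = cond_sketch(False, scale, shift, 5, amount=3.0)  # shift(5, amount=3.0)

try:
    cond_sketch([True, False], scale, shift, 5)  # not a scalar condition
except TypeError as exc:
    err = str(exc)
```

If the two branches took differently named keyword arguments, the call would fail only when the branch lacking that keyword was selected. Matching signatures avoids this.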
Returns
Array: The result of calling either true_fn or false_fn.
Examples
>>> import nabla as nb
>>> def f(x):
...     return x * 2
...
>>> def g(x):
...     return x + 10
...
>>> x = nb.array(5)
>>> # Executes g(x) because the condition is False
>>> nb.cond(nb.array(False), f, g, x)
Array([15], dtype=int32)