Create a case operation.
Aliases:
tf.compat.v2.case
tf.case(
    pred_fn_pairs,
    default=None,
    exclusive=False,
    strict=False,
    name='case'
)
See also tf.switch_case.
The pred_fn_pairs parameter is a list of N pairs. Each pair contains a boolean scalar tensor and a Python callable that creates the tensors to be returned if the boolean evaluates to True. default is a callable generating a list of tensors. All the callables in pred_fn_pairs as well as default (if provided) should return the same number and types of tensors.
If exclusive==True, all predicates are evaluated, and an exception is thrown if more than one of the predicates evaluates to True. If exclusive==False, execution stops at the first predicate which evaluates to True, and the tensors generated by the corresponding function are returned immediately. If none of the predicates evaluate to True, this operation returns the tensors generated by default.
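The selection rule described above can be sketched in plain Python. This is only a model of the documented behavior, not TensorFlow's implementation: `case_model` is a hypothetical name, predicates are ordinary Python booleans rather than tensors, `default` is made mandatory to keep the sketch small, and `ValueError` stands in for whatever exception the real operation raises.

```python
def case_model(pred_fn_pairs, default, exclusive=False):
    """Plain-Python model of the documented selection rule (hypothetical helper)."""
    if exclusive:
        # Every predicate is inspected; more than one True is an error.
        true_count = sum(1 for pred, _ in pred_fn_pairs if pred)
        if true_count > 1:
            raise ValueError("Only one predicate may evaluate to True")
    # Otherwise the first True predicate wins and later pairs are ignored.
    for pred, fn in pred_fn_pairs:
        if pred:
            return fn()
    # No predicate was True, so fall back to the default callable.
    return default()


pairs = [(True, lambda: 17), (True, lambda: 23)]
first_match = case_model(pairs, default=lambda: -1)
fallback = case_model([(False, lambda: 17)], default=lambda: 23)
```

With `exclusive=False` the two overlapping predicates in `pairs` are tolerated and the first one is selected; passing `exclusive=True` with the same `pairs` makes the model raise instead.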
tf.case supports nested structures as implemented in tf.nest. All of the callables must return the same (possibly nested) value structure of lists, tuples, and/or named tuples. Singleton lists and tuples form the only exceptions to this: when returned by a callable, they are implicitly unpacked to single values. This behavior is disabled by passing strict=True.
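The singleton rule can be pictured with a small plain-Python helper. `unpack_if_singleton` is a hypothetical name used only for illustration; it models the behavior described in the paragraph above and is not a TensorFlow API.

```python
def unpack_if_singleton(value, strict=False):
    """Model of the singleton rule (hypothetical helper, not a TensorFlow API)."""
    # In non-strict mode a one-element list or tuple collapses to its element.
    if not strict and isinstance(value, (list, tuple)) and len(value) == 1:
        return value[0]
    # In strict mode, or for any other structure, the value is returned as-is.
    return value


loose = unpack_if_singleton([17])               # collapses to the single value
kept = unpack_if_singleton([17], strict=True)   # stays a one-element list
pair = unpack_if_singleton((17, 23))            # longer structures are untouched
```

Under this model, callables that return `[t]` and callables that return `t` are interchangeable unless `strict=True` is passed.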
Example 1:
Pseudocode:
if (x < y) return 17;
else return 23;
Expressions:
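import tensorflow as tf
x, y = tf.constant(1), tf.constant(2)  # example inputs, chosen only for illustration
f1 = lambda: tf.constant(17)
f2 = lambda: tf.constant(23)
r = tf.case([(tf.less(x, y), f1)], default=f2)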
Example 2:
Pseudocode:
if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
if (x < y) return 17;
else if (x > z) return 23;
else return -1;
Expressions:
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
            default=f3, exclusive=True)
Args:
pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which returns a list of tensors.
default: Optional callable that returns a list of tensors.
exclusive: True iff at most one predicate is allowed to evaluate to True.
strict: A boolean that enables/disables 'strict' mode; see above.
name: A name for this operation (optional).
Returns:
The tensors returned by the first pair whose predicate evaluated to True, or those returned by default if none does.
Raises:
TypeError: If pred_fn_pairs is not a list/tuple.
TypeError: If pred_fn_pairs is a list but does not contain 2-tuples.
TypeError: If fns[i] is not callable for any i, or default is not callable.
V2 Compatibility
pred_fn_pairs could be a dictionary in v1. However, tf.Tensor and tf.Variable are no longer hashable in v2, so they cannot be used as dictionary keys. Please use a list or a tuple instead.