Create a switch/case operation, i.e. an integer-indexed conditional.
Aliases:
tf.compat.v1.switch_case
tf.compat.v2.switch_case
tf.switch_case(
branch_index,
branch_fns,
default=None,
name='switch_case'
)
tf.caseSee also .
tf.caseThis op can be substantially more efficient than when exactly one branch will be selected. tf.switch_case is more like a C++ switch/case statement than , which is more like an if/elif/elif/else chain.
The branch_fns
parameter is either a list of (int
, callable) pairs, or simply a list of callables (in which case the index is implicitly the key). The branch_index
Tensor
is used to select an element in branch_fns
with matching int
key, falling back to default
if none match, or max(keys) if no default
is provided. The keys must form a contiguous set from 0
to len(branch_fns
) - 1.
tf.switch_case supports nested structures as implemented in tf.nest. All callables must return the same (possibly nested) value structure of lists, tuples, and/or named tuples.
Example:
Pseudocode:
switch (branch_index) { // c-style switch
case 0: return 17;
case 1: return 31;
default: return -1;
}
or
branches = {0: lambda: 17, 1: lambda: 31}
branches.get(branch_index, lambda: -1)()
Expressions:
def f1(): return tf.constant(17)
def f2(): return tf.constant(31)
def f3(): return tf.constant(-1)
r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
# Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
Args:
branch_index
: An int Tensor specifying which ofbranch_fns
should be executed.branch_fns
: Alist
of (int, callable) pairs, or simply alist
of callables (in which case the index serves as the key). Each callable must return a matching structure of tensors.default
: Optional callable that returns a structure of tensors.name
: Aname
for this operation (optional).
Returns:
The tensors returned by the callable identified by branch_index
, or those returned by default
if no key matches and default
was provided, or those returned by the max-keyed branch_fn
if no default
is provided.
Raises:
TypeError
: Ifbranch_fns
is not a list/dictionary.TypeError
: Ifbranch_fns
is a list but does not contain 2-tuples or callables.TypeError
: Iffns[i]
is not callable for any i, ordefault
is not callable.
V2 Compatibility
branch_fns
could be a dictionary in v1. However, tf.Tensor and tf.Variable are no longer hashable in v2, so cannot be used as a key for a dictionary. Please use a list or a tuple instead.