Class RegisterGradient
A decorator for registering the gradient function for an op type.
Aliases:
- Class tf.compat.v1.RegisterGradient
- Class tf.compat.v2.RegisterGradient
This decorator is only used when defining a new op type. For an op with m inputs and n outputs, the gradient function is a function that takes the original Operation and n Tensor objects (representing the gradients with respect to each output of the op), and returns m Tensor objects (representing the partial gradients with respect to each input of the op).
For example, assuming that operations of type "Sub" take two inputs x and y and return a single output x - y, the following gradient function would be registered:
import tensorflow as tf

@tf.RegisterGradient("Sub")
def _sub_grad(unused_op, grad):
  # d(x - y)/dx = 1 and d(x - y)/dy = -1, so return one gradient per input.
  return grad, tf.negative(grad)
The decorator argument op_type is the string type of an operation. This corresponds to the OpDef.name field for the proto that defines the operation.
__init__
__init__(op_type)
Creates a new decorator with op_type as the Operation type.
Args:
op_type: The string type of an operation. This corresponds to the OpDef.name field for the proto that defines the operation.
Raises:
TypeError: If op_type is not a string.
Methods
__call__
__call__(f)
Registers the function f as the gradient function for op_type.