_watch

watch

Monitor the given tensor for later gradient computations.

This method can be used with TensorFlow tensors:

```python
x = tf.ones((3,28,28,1))
with tf.GradientTape(persistent=True) as tape:
    x = fe.backend.watch(x, tape=tape)
```
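
Once watched, gradients with respect to `x` can be retrieved from the tape as usual. A minimal sketch in plain TensorFlow (the squared-sum value `y` is illustrative, not part of the API):

```python
import tensorflow as tf
import fastestimator as fe

x = tf.ones((3, 28, 28, 1))
with tf.GradientTape(persistent=True) as tape:
    x = fe.backend.watch(x, tape=tape)
    y = tf.reduce_sum(x * x)  # any downstream computation involving x
grads = tape.gradient(y, x)  # d(sum(x^2))/dx == 2*x
```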

This method can be used with PyTorch tensors:

```python
x = torch.ones((3,1,28,28))  # x.requires_grad == False
x = fe.backend.watch(x)  # x.requires_grad == True
```
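
Because the PyTorch branch returns a new leaf tensor, all downstream computation must use the returned value. A minimal sketch (the squared-sum value `y` is illustrative):

```python
import torch
import fastestimator as fe

x = torch.ones((3, 1, 28, 28))
x = fe.backend.watch(x)   # use the returned tensor from here on
y = (x * x).sum()         # any downstream computation involving x
y.backward()
print(x.grad)             # d(sum(x^2))/dx == 2*x
```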

Parameters:

| Name | Type | Description | Default |
| ------ | ------ | ----------- | ------- |
| `tensor` | `Tensor` | The tensor to be monitored. | *required* |
| `tape` | `Optional[GradientTape]` | A TensorFlow GradientTape which will be used to record gradients (iff using TensorFlow for the backend). | `None` |

Returns:

| Type | Description |
| ------ | ----------- |
| `Tensor` | The `tensor` or a copy of the `tensor` which is being tracked for gradient computations. This value is only needed if using PyTorch as the backend. |

Raises:

| Type | Description |
| ------ | ----------- |
| `ValueError` | If `tensor` is an unacceptable data type. |
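
For example, a NumPy array is neither a TensorFlow nor a PyTorch tensor, so (per the source below) it falls through to the `ValueError`; a minimal sketch:

```python
import numpy as np
import fastestimator as fe

try:
    fe.backend.watch(np.ones((3,)))
except ValueError as err:
    print(err)  # Unrecognized tensor type <class 'numpy.ndarray'>
```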

Source code in fastestimator/fastestimator/backend/_watch.py
```python
def watch(tensor: Tensor, tape: Optional[tf.GradientTape] = None) -> Tensor:
    """Monitor the given `tensor` for later gradient computations.

    This method can be used with TensorFlow tensors:
    ```python
    x = tf.ones((3,28,28,1))
    with tf.GradientTape(persistent=True) as tape:
        x = fe.backend.watch(x, tape=tape)
    ```

    This method can be used with PyTorch tensors:
    ```python
    x = torch.ones((3,1,28,28))  # x.requires_grad == False
    x = fe.backend.watch(x)  # x.requires_grad == True
    ```

    Args:
        tensor: The tensor to be monitored.
        tape: A TensorFlow GradientTape which will be used to record gradients (iff using TensorFlow for the backend).

    Returns:
        The `tensor` or a copy of the `tensor` which is being tracked for gradient computations. This value is only
        needed if using PyTorch as the backend.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    if tf.is_tensor(tensor):
        tape.watch(tensor)
        return tensor
    elif isinstance(tensor, torch.Tensor):
        if tensor.requires_grad:
            return tensor
        # It is tempting to just do tensor.requires_grad = True here, but that will lead to trouble
        return tensor.detach().requires_grad_(True)
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
```
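
The in-source comment about "trouble" is terse; the underlying PyTorch behavior is that `requires_grad` can only be toggled on leaf tensors, so assigning it directly would crash whenever `watch` receives a tensor produced by earlier autograd-tracked operations. `detach()` sidesteps this by creating a fresh leaf. A minimal sketch of that failure mode (illustrative, not from the FastEstimator codebase):

```python
import torch

a = torch.ones(3, requires_grad=True)
b = a * 2  # b is a non-leaf tensor: it was produced by an autograd-tracked op

try:
    b.requires_grad = True  # RuntimeError on non-leaf tensors
except RuntimeError as err:
    print(err)  # "you can only change requires_grad flags of leaf variables"

c = b.detach().requires_grad_(True)  # detach() returns a new leaf, so this always works
print(c.is_leaf, c.requires_grad)    # True True
```

This is also why the Returns section notes that the returned value is needed when running PyTorch: the fresh leaf is a different Python object from the input, so gradients only flow to computations built from the returned tensor.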