
_check_nan

check_nan

Checks whether the input contains NaN or infinite values.

This method can be used with NumPy data:

n = np.array([[[1.0, 2.0], [3.0, np.nan]], [[5.0, 6.0], [7.0, 8.0]]])
b = fe.backend.check_nan(n)  # True

This method can be used with TensorFlow tensors:

t = tf.constant([[[1.0, 2.0], [3.0, 4.0]], [[np.nan, 6.0], [7.0, 8.0]]])
b = fe.backend.check_nan(t)  # True

This method can be used with PyTorch tensors:

p = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [np.nan, 8.0]]])
b = fe.backend.check_nan(p)  # True
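The examples above only show inputs that contain NaN. As a minimal sketch, assuming only NumPy is installed, the hypothetical helper `check_nan_numpy` below (not part of FastEstimator) repeats the expression used by the NumPy branch of `check_nan` to show the other outcomes: infinities are also reported, and fully finite input, including a plain Python integer, yields `False`.

```python
import numpy as np


def check_nan_numpy(val) -> bool:
    """Hypothetical stand-in for the NumPy branch of check_nan (not a FastEstimator API)."""
    # Same test as the final `else` branch: any NaN, or failing that, any +/-inf
    return bool(np.isnan(val).any() or np.isinf(val).any())


print(check_nan_numpy(np.array([[1.0, 2.0], [3.0, np.nan]])))  # True  (contains NaN)
print(check_nan_numpy(np.array([1.0, -np.inf])))                # True  (infinity also counts)
print(check_nan_numpy(np.array([[1.0, 2.0], [3.0, 4.0]])))      # False (all finite)
print(check_nan_numpy(7))                                       # False (plain Python int)
```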

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `val` | `Union[int, float, np.ndarray, tf.Tensor, torch.Tensor]` | The input value. | required |

Returns:

| Type | Description |
|------|-------------|
| `bool` | True iff `val` contains NaN or infinite values. |
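Although the annotated return type is `bool`, the value comes straight from the underlying framework: the NumPy branch yields a `numpy.bool_`, and the TensorFlow and PyTorch branches yield scalar boolean tensors. All of these behave correctly in an `if` statement in eager code, but a caller that needs a genuine Python `bool` (for example, to store or serialize it) can wrap the result in `bool(...)`. A small NumPy-only illustration, assuming nothing beyond NumPy:

```python
import numpy as np

x = np.array([1.0, 2.0, np.nan])

# Same expression as the NumPy branch of check_nan
result = np.isnan(x).any() or np.isinf(x).any()

print(isinstance(result, np.bool_))  # True: a NumPy scalar, not a Python bool
print(isinstance(result, bool))      # False
flag = bool(result)                  # explicit conversion to a Python bool
print(flag)                          # True
```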

Source code in fastestimator/fastestimator/backend/_check_nan.py
def check_nan(val: Union[int, float, np.ndarray, tf.Tensor, torch.Tensor]) -> bool:
    """Checks whether the input contains NaN or infinite values.

    This method can be used with NumPy data:
    ```python
    n = np.array([[[1.0, 2.0], [3.0, np.nan]], [[5.0, 6.0], [7.0, 8.0]]])
    b = fe.backend.check_nan(n)  # True
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.constant([[[1.0, 2.0], [3.0, 4.0]], [[np.nan, 6.0], [7.0, 8.0]]])
    b = fe.backend.check_nan(t)  # True
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [np.nan, 8.0]]])
    b = fe.backend.check_nan(p)  # True
    ```

    Args:
        val: The input value.

    Returns:
        True iff `val` contains NaN or infinite values.
    """
    if tf.is_tensor(val):
        return tf.reduce_any(tf.math.is_nan(val)) or tf.reduce_any(tf.math.is_inf(val))
    elif isinstance(val, torch.Tensor):
        return torch.isnan(val).any() or torch.isinf(val).any()
    else:
        return np.isnan(val).any() or np.isinf(val).any()
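Because NaN and positive or negative infinity are exactly the values that are not finite, the two-pass test in each branch is equivalent to a single "not all finite" check. The sketch below, assuming NumPy and using the hypothetical helper names `has_non_finite` and `nan_or_inf`, compares the two forms on a few arrays. `np.isfinite`, `tf.math.is_finite`, and `torch.isfinite` all exist, so a similar rewrite would be possible in every branch; the trade-off is that the two-pass form can skip the infinity scan when a NaN is already found, because Python's `or` short-circuits.

```python
import numpy as np


def has_non_finite(val) -> bool:
    """Hypothetical one-pass alternative: NaN and +/-inf are exactly the non-finite values."""
    return not bool(np.isfinite(val).all())


def nan_or_inf(val) -> bool:
    """Two-pass form, as in the NumPy branch of check_nan."""
    return bool(np.isnan(val).any() or np.isinf(val).any())


samples = [
    np.array([1.0, 2.0, 3.0]),
    np.array([1.0, np.nan]),
    np.array([[np.inf, 0.0], [1.0, 2.0]]),
    np.array([-np.inf]),
    np.array([1, 2, 3]),  # integer arrays can never hold NaN or inf
]

# Both forms agree on every sample
for s in samples:
    assert has_non_finite(s) == nan_or_inf(s)
```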