_convert_tensor_precision

convert_tensor_precision

Adjust the input data precision based on the environment precision.

Parameters:

Name    Type    Description       Default
tensor  Tensor  The input value.  required

Returns:

Type    Description
Tensor  The precision-adjusted data (16 bit for mixed precision, 32 bit otherwise).

Source code in fastestimator/fastestimator/backend/_convert_tensor_precision.py
def convert_tensor_precision(tensor: Tensor) -> Tensor:
    """
        Adjust the input data precision based on the environment precision.

        Args:
            tensor: The input value.

        Returns:
            The precision-adjusted data (16 bit for mixed precision, 32 bit otherwise).

    """
    precision = 'float32'

    if mixed_precision.global_policy().compute_dtype == 'float16':
        precision = 'float16'

    return cast(tensor, precision)
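
The selection rule in the listing above can be sketched on its own, without TensorFlow installed. This is only an illustration: `select_precision` is a hypothetical helper, not part of FastEstimator, and it takes the compute dtype as a plain string in place of `mixed_precision.global_policy().compute_dtype`.

```python
def select_precision(compute_dtype: str) -> str:
    """Hypothetical helper mirroring the branch in convert_tensor_precision."""
    # Only a float16 compute dtype switches the target precision to 16 bit.
    if compute_dtype == 'float16':
        return 'float16'
    # Every other compute dtype, including 'bfloat16', falls back to 32 bit.
    return 'float32'


# Compute dtypes as reported by the Keras policies 'mixed_float16',
# 'float32', and 'mixed_bfloat16', respectively.
for dtype in ('float16', 'float32', 'bfloat16'):
    print(dtype, '->', select_precision(dtype))
```

Because the check is an exact comparison against `'float16'`, a bfloat16 mixed-precision policy would take the 32-bit branch under this reading of the source.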