
_get_image_dims

get_image_dims

Get the channels, height, and width of a tensor. NumPy arrays and TensorFlow tensors are assumed to be channels-last, while PyTorch tensors are assumed to be channels-first.

This method can be used with NumPy data:

n = np.random.random((2, 12, 12, 3))
b = fe.backend.get_image_dims(n)  # (3, 12, 12)

This method can be used with TensorFlow tensors:

t = tf.random.uniform((2, 12, 12, 3))
b = fe.backend.get_image_dims(t)  # (3, 12, 12)

This method can be used with PyTorch tensors:

p = torch.rand((2, 3, 12, 12))
b = fe.backend.get_image_dims(p)  # (3, 12, 12)
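
Although the examples above all use batched 4-D inputs, the dimension check in the source code below also accepts unbatched 3-D inputs. A minimal sketch, assuming the same channels-last (NumPy/TensorFlow) and channels-first (PyTorch) conventions as above:

```python
import numpy as np
import torch
import fastestimator as fe

# Unbatched channels-last NumPy image (H, W, C)
n = np.random.random((12, 12, 3))
print(fe.backend.get_image_dims(n))  # (3, 12, 12)

# Unbatched channels-first PyTorch image (C, H, W)
p = torch.rand((3, 12, 12))
print(fe.backend.get_image_dims(p))  # (3, 12, 12)
```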

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `tensor` | `Tensor` | The input tensor. | required |

Returns:

| Type | Description |
|------|-------------|
| `Tuple[int, int, int]` | Channels, height, and width of the `tensor`. |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If `tensor` is an unacceptable data type. |
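
The `ValueError` corresponds to the final `else` branch in the source code below: it fires only for inputs that pass the dimension check but are not NumPy, TensorFlow, or PyTorch tensors. A minimal sketch that triggers it using a hypothetical `FakeImage` class (inputs with the wrong number of dimensions instead fail the assertion on the first line of the function):

```python
import fastestimator as fe

class FakeImage:
    """Hypothetical stand-in: has a 4-D .shape but is not a NumPy/TF/PyTorch tensor."""
    shape = (2, 12, 12, 3)

try:
    fe.backend.get_image_dims(FakeImage())
except ValueError as err:
    print(err)  # Unrecognized tensor type <class '__main__.FakeImage'>
```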

Source code in fastestimator/fastestimator/backend/_get_image_dims.py
def get_image_dims(tensor: Tensor) -> Tuple[int, int, int]:
    """Get the `tensor` channels, height, and width.

    This method can be used with Numpy data:
    ```python
    n = np.random.random((2, 12, 12, 3))
    b = fe.backend.get_image_dims(n)  # (3, 12, 12)
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.random.uniform((2, 12, 12, 3))
    b = fe.backend.get_image_dims(t)  # (3, 12, 12)
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.rand((2, 3, 12, 12))
    b = fe.backend.get_image_dims(p)  # (3, 12, 12)
    ```

    Args:
        tensor: The input tensor.

    Returns:
        Channels, height and width of the `tensor`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    assert len(tensor.shape) == 3 or len(tensor.shape) == 4, \
        f"Number of dimensions of input must be either 3 or 4, but found {len(tensor.shape)} (shape: {tensor.shape})"
    if tf.is_tensor(tensor):
        shape = tf.shape(tensor)
        channels, height, width = shape[-1], shape[-3], shape[-2]
        if hasattr(channels, 'numpy'):
            # Running in eager mode, so can convert to integer
            channels, height, width = channels.numpy().item(), height.numpy().item(), width.numpy().item()
        return channels, height, width
    elif isinstance(tensor, np.ndarray):
        return tensor.shape[-1], tensor.shape[-3], tensor.shape[-2]
    elif isinstance(tensor, torch.Tensor):
        return tensor.shape[-3], tensor.shape[-2], tensor.shape[-1]
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
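
The `hasattr(channels, 'numpy')` check above only converts the TensorFlow results to Python ints when running eagerly. A minimal sketch (assuming standard TF 2.x eager/graph semantics) of calling the function from inside a `tf.function`, where the values are computed as scalar tensors rather than plain Python ints:

```python
import tensorflow as tf
import fastestimator as fe

@tf.function
def graph_mode_dims(x):
    # Inside a tf.function the shape components have no .numpy(),
    # so get_image_dims returns scalar tensors instead of Python ints
    return fe.backend.get_image_dims(x)

t = tf.random.uniform((2, 12, 12, 3))
channels, height, width = graph_mode_dims(t)
print(channels, height, width)  # scalar int32 tensors with values 3, 12, 12
```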