_squeeze

squeeze

Remove an axis from a tensor if that axis has length 1.

This method can be used with NumPy arrays:

```python
import fastestimator as fe
import numpy as np

n = np.array([[[[1],[2]]],[[[3],[4]]],[[[5],[6]]]])  # shape == (3, 1, 2, 1)
b = fe.backend.squeeze(n)  # [[1, 2], [3, 4], [5, 6]]
b = fe.backend.squeeze(n, axis=1)  # [[[1], [2]], [[3], [4]], [[5], [6]]]
b = fe.backend.squeeze(n, axis=3)  # [[[1, 2]], [[3, 4]], [[5, 6]]]
```

This method can be used with TensorFlow tensors:

```python
import fastestimator as fe
import tensorflow as tf

t = tf.constant([[[[1],[2]]],[[[3],[4]]],[[[5],[6]]]])  # shape == (3, 1, 2, 1)
b = fe.backend.squeeze(t)  # [[1, 2], [3, 4], [5, 6]]
b = fe.backend.squeeze(t, axis=1)  # [[[1], [2]], [[3], [4]], [[5], [6]]]
b = fe.backend.squeeze(t, axis=3)  # [[[1, 2]], [[3, 4]], [[5, 6]]]
```

This method can be used with PyTorch tensors:

```python
import fastestimator as fe
import torch

p = torch.tensor([[[[1],[2]]],[[[3],[4]]],[[[5],[6]]]])  # shape == (3, 1, 2, 1)
b = fe.backend.squeeze(p)  # [[1, 2], [3, 4], [5, 6]]
b = fe.backend.squeeze(p, axis=1)  # [[[1], [2]], [[3], [4]], [[5], [6]]]
b = fe.backend.squeeze(p, axis=3)  # [[[1, 2]], [[3, 4]], [[5, 6]]]
```

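Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | The input value (a NumPy array, TensorFlow tensor, or PyTorch tensor). | required |
| `axis` | `Optional[int]` | The axis to squeeze, which must have length 1. Pass `None` to squeeze all axes of length 1. NumPy and TensorFlow raise an error if the chosen axis does not have length 1, whereas PyTorch returns the tensor unchanged. | `None` |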
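Because the NumPy branch of `fe.backend.squeeze` passes `axis` straight through to `np.squeeze`, the axis rules can be explored with NumPy alone. The sketch below assumes only that NumPy is installed and calls `np.squeeze` directly instead of `fe.backend.squeeze`. Negative axis indices count from the end of the shape; `tf.squeeze` and `torch.squeeze` accept them as well.

```python
import numpy as np

# Same (3, 1, 2, 1) array used in the examples above.
n = np.array([[[[1], [2]]], [[[3], [4]]], [[[5], [6]]]])

all_squeezed = np.squeeze(n)       # axis=None: drop every length-1 axis
axis1 = np.squeeze(n, axis=1)      # drop only axis 1
last = np.squeeze(n, axis=-1)      # negative index: drop the last axis

print(all_squeezed.shape)  # (3, 2)
print(axis1.shape)         # (3, 2, 1)
print(last.shape)          # (3, 1, 2)

# Axis 0 has length 3, so NumPy refuses to squeeze it.
try:
    np.squeeze(n, axis=0)
    squeezed_bad_axis = True
except ValueError as err:
    squeezed_bad_axis = False
    print("ValueError:", err)
```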

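Returns:

| Type | Description |
|---|---|
| `Tensor` | The input with the selected length-1 axis (or, if `axis` is `None`, every length-1 axis) removed. |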
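For NumPy input the result is not a copy: `np.squeeze` returns either the input itself or a view into it, and `torch.squeeze` similarly shares storage with its input. The NumPy-only sketch below calls `np.squeeze` directly, as the NumPy branch of `fe.backend.squeeze` does, and shows the practical consequence: writing to the squeezed result also modifies the original array.

```python
import numpy as np

n = np.zeros((2, 1, 3))
s = np.squeeze(n)  # shape (2, 3)

# np.squeeze returns a view, so both arrays share the same buffer.
print(np.shares_memory(n, s))  # True

# Writing through the squeezed view changes the original array too.
s[0, 0] = 7.0
print(n[0, 0, 0])  # 7.0
```

Call `.copy()` on the result if you need an independent array.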

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `tensor` is not a NumPy array, TensorFlow tensor, or PyTorch tensor. |
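The dispatch in the source code on this page accepts only TensorFlow tensors, `torch.Tensor`, and `np.ndarray`, so inputs that NumPy itself could handle, such as a nested Python list, still raise `ValueError`. The sketch below reproduces just the NumPy branch and the fallback error in a hypothetical `squeeze_sketch` helper (not part of FastEstimator) so that it runs without TensorFlow or PyTorch installed.

```python
from typing import Optional

import numpy as np


def squeeze_sketch(tensor, axis: Optional[int] = None):
    """Hypothetical NumPy-only stand-in for the dispatch in fe.backend.squeeze."""
    if isinstance(tensor, np.ndarray):
        return np.squeeze(tensor, axis=axis)
    # Any other type is rejected, even if np.squeeze could accept it.
    raise ValueError("Unrecognized tensor type {}".format(type(tensor)))


data = [[[1], [2]]]  # plain Python list; shape (1, 2, 1) once converted

try:
    squeeze_sketch(data)
    error_message = None
except ValueError as err:
    error_message = str(err)
    print(error_message)  # Unrecognized tensor type <class 'list'>

# Converting to an ndarray first avoids the error.
result = squeeze_sketch(np.asarray(data))
print(result.tolist())  # [1, 2]
```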

Source code in fastestimator/fastestimator/backend/_squeeze.py
def squeeze(tensor: Tensor, axis: Optional[int] = None) -> Tensor:
    """Remove an `axis` from a `tensor` if that axis has length 1.

    This method can be used with Numpy data:
    ```python
    n = np.array([[[[1],[2]]],[[[3],[4]]],[[[5],[6]]]])  # shape == (3, 1, 2, 1)
    b = fe.backend.squeeze(n)  # [[1, 2], [3, 4], [5, 6]]
    b = fe.backend.squeeze(n, axis=1)  # [[[1], [2]], [[3], [4]], [[5], [6]]]
    b = fe.backend.squeeze(n, axis=3)  # [[[1, 2]], [[3, 4]], [[5, 6]]]
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.constant([[[[1],[2]]],[[[3],[4]]],[[[5],[6]]]])  # shape == (3, 1, 2, 1)
    b = fe.backend.squeeze(t)  # [[1, 2], [3, 4], [5, 6]]
    b = fe.backend.squeeze(t, axis=1)  # [[[1], [2]], [[3], [4]], [[5], [6]]]
    b = fe.backend.squeeze(t, axis=3)  # [[[1, 2]], [[3, 4]], [[5, 6]]]
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.tensor([[[[1],[2]]],[[[3],[4]]],[[[5],[6]]]])  # shape == (3, 1, 2, 1)
    b = fe.backend.squeeze(p)  # [[1, 2], [3, 4], [5, 6]]
    b = fe.backend.squeeze(p, axis=1)  # [[[1], [2]], [[3], [4]], [[5], [6]]]
    b = fe.backend.squeeze(p, axis=3)  # [[[1, 2]], [[3, 4]], [[5, 6]]]
    ```

    Args:
        tensor: The input value.
        axis: Which axis to squeeze along, which must have length==1 (or pass None to squeeze all length 1 axes).

    Returns:
        The reshaped `tensor`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    if tf.is_tensor(tensor):
        return tf.squeeze(tensor, axis=axis)
    elif isinstance(tensor, torch.Tensor):
        if axis is None:
            return torch.squeeze(tensor)
        else:
            return torch.squeeze(tensor, dim=axis)
    elif isinstance(tensor, np.ndarray):
        return np.squeeze(tensor, axis=axis)
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))