
_expand_dims

expand_dims

Create a new dimension in tensor along a given axis.

This method can be used with NumPy data:

import fastestimator as fe
import numpy as np
n = np.array([2,7,5])
b = fe.backend.expand_dims(n, axis=0)  # [[2, 7, 5]]
b = fe.backend.expand_dims(n, axis=1)  # [[2], [7], [5]]
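The NumPy branch of `expand_dims` delegates to `np.expand_dims`, so the NumPy example can be checked without FastEstimator installed. The sketch below calls NumPy directly. It reproduces only the NumPy branch, not the TensorFlow or PyTorch ones, and confirms the resulting shapes and values.

```python
import numpy as np

n = np.array([2, 7, 5])           # shape (3,): one dimension

# NumPy-only stand-in for fe.backend.expand_dims on an ndarray
row = np.expand_dims(n, axis=0)   # new axis in front -> shape (1, 3)
col = np.expand_dims(n, axis=1)   # new axis after the existing one -> shape (3, 1)

print(row.shape, row.tolist())    # (1, 3) [[2, 7, 5]]
print(col.shape, col.tolist())    # (3, 1) [[2], [7], [5]]
```

The element order is unchanged. Only the nesting (the shape) differs between the two calls.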

This method can be used with TensorFlow tensors:

import fastestimator as fe
import tensorflow as tf
t = tf.constant([2,7,5])
b = fe.backend.expand_dims(t, axis=0)  # [[2, 7, 5]]
b = fe.backend.expand_dims(t, axis=1)  # [[2], [7], [5]]

This method can be used with PyTorch tensors:

import fastestimator as fe
import torch
p = torch.tensor([2,7,5])
b = fe.backend.expand_dims(p, axis=0)  # [[2, 7, 5]]
b = fe.backend.expand_dims(p, axis=1)  # [[2], [7], [5]]

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tensor | Tensor | The input tensor, having n dimensions. | required |
| axis | int | The position at which the new axis is inserted. Must be in the range [-n-1, n]. | 1 |
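The output has n + 1 dimensions, so the valid positions run from -n-1 through n, and negative values count from the end of the output shape. The sketch below uses NumPy directly as a stand-in for the NumPy branch. It enumerates every valid `axis` for a 1-D input and shows that values one step outside the range are rejected. NumPy raises `AxisError`, which subclasses `ValueError`. The TensorFlow and PyTorch branches may raise different exception types.

```python
import numpy as np

n = np.array([2, 7, 5])
ndim = n.ndim  # 1 for this input

# Valid axes run from -ndim-1 to ndim inclusive; negative values count
# from the end of the (ndim + 1)-dimensional output.
shapes = {axis: np.expand_dims(n, axis=axis).shape for axis in range(-ndim - 1, ndim + 1)}
print(shapes)  # {-2: (1, 3), -1: (3, 1), 0: (1, 3), 1: (3, 1)}

# One step past either end of the range is rejected.
rejected = []
for bad_axis in (-ndim - 2, ndim + 1):
    try:
        np.expand_dims(n, axis=bad_axis)
    except ValueError:  # NumPy's AxisError subclasses ValueError
        rejected.append(bad_axis)
print(rejected)  # [-3, 2]
```

For a 1-D input, the default `axis=1` therefore yields a column of shape (3, 1).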

Returns:

| Type | Description |
|------|-------------|
| Tensor | The input with a new axis of size 1 inserted at position `axis`, so it has n + 1 dimensions. The result has the same type as `tensor`. |
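The returned tensor holds the same elements as the input, in the same order, with a size-1 dimension inserted at position `axis`. A negative `axis` is first normalized against n + 1. The hypothetical helper `expanded_shape` below spells out that shape rule and checks it against `np.expand_dims` on a 2-D array. It illustrates the shape arithmetic and is not a FastEstimator function.

```python
import numpy as np


def expanded_shape(shape: tuple, axis: int) -> tuple:
    """Hypothetical helper: the shape expand_dims should produce."""
    out_ndim = len(shape) + 1
    if not -out_ndim <= axis < out_ndim:  # same as axis in [-n-1, n]
        raise ValueError(f"axis {axis} out of range for {out_ndim}-D output")
    if axis < 0:
        axis += out_ndim  # negative axes count from the end of the output
    return shape[:axis] + (1,) + shape[axis:]


x = np.arange(6).reshape(2, 3)  # n = 2 dimensions

for axis in range(-3, 3):
    y = np.expand_dims(x, axis=axis)
    assert y.shape == expanded_shape(x.shape, axis)
    assert np.array_equal(y.ravel(), x.ravel())  # same elements, same order
    print(axis, y.shape)
```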

Raises:

| Type | Description |
|------|-------------|
| ValueError | If `tensor` is not a TensorFlow tensor, a PyTorch tensor, or a NumPy array. |
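The error comes from a chain of type checks, one per supported backend. Anything that matches none of them falls through to the final `else`. Inputs such as a plain Python list are therefore rejected rather than converted. The sketch below keeps only the NumPy branch of that pattern. `expand_dims_numpy_only` is a hypothetical, illustrative name, not a FastEstimator function.

```python
from typing import Any

import numpy as np


def expand_dims_numpy_only(tensor: Any, axis: int = 1) -> np.ndarray:
    """Hypothetical sketch of the type dispatch, keeping only the NumPy branch."""
    if isinstance(tensor, np.ndarray):
        return np.expand_dims(tensor, axis=axis)
    # Mirrors the final `else` branch: unknown types are rejected, not coerced.
    raise ValueError("Unrecognized tensor type {}".format(type(tensor)))


print(expand_dims_numpy_only(np.array([2, 7, 5])).shape)  # (3, 1) with the default axis=1

try:
    expand_dims_numpy_only([2, 7, 5])
except ValueError as err:
    print(err)  # Unrecognized tensor type <class 'list'>
```

A list can be passed after wrapping it with `np.array(...)`, `tf.constant(...)`, or `torch.tensor(...)`.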

Source code in fastestimator/fastestimator/backend/_expand_dims.py
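from typing import TypeVar

import numpy as np
import tensorflow as tf
import torch

# Type alias so this listing runs standalone: any of the three supported tensor types.
Tensor = TypeVar('Tensor', tf.Tensor, torch.Tensor, np.ndarray)


def expand_dims(tensor: Tensor, axis: int = 1) -> Tensor: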
    """Create a new dimension in `tensor` along a given `axis`.

    This method can be used with NumPy data:
    ```python
    n = np.array([2,7,5])
    b = fe.backend.expand_dims(n, axis=0)  # [[2, 7, 5]]
    b = fe.backend.expand_dims(n, axis=1)  # [[2], [7], [5]]
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.constant([2,7,5])
    b = fe.backend.expand_dims(t, axis=0)  # [[2, 7, 5]]
    b = fe.backend.expand_dims(t, axis=1)  # [[2], [7], [5]]
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.tensor([2,7,5])
    b = fe.backend.expand_dims(p, axis=0)  # [[2, 7, 5]]
    b = fe.backend.expand_dims(p, axis=1)  # [[2], [7], [5]]
    ```

    Args:
        tensor: The input tensor, having n dimensions.
        axis: The position at which the new axis is inserted. Must be in the range [-n-1, n].

    Returns:
        The input with a new axis of size 1 inserted at position `axis`, so it has n + 1 dimensions. The result has
        the same type as `tensor`.

    Raises:
        ValueError: If `tensor` is not a TensorFlow tensor, a PyTorch tensor, or a NumPy array.
    """
    if tf.is_tensor(tensor):
        return tf.expand_dims(tensor, axis=axis)
    elif isinstance(tensor, torch.Tensor):
        return torch.unsqueeze(tensor, dim=axis)
    elif isinstance(tensor, np.ndarray):
        return np.expand_dims(tensor, axis=axis)
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
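None of the three branches changes its input in place. Each one returns a new tensor object. In NumPy, `np.expand_dims` returns a view of the original data, so the result and the input share memory. PyTorch's `torch.unsqueeze` likewise returns a tensor that shares storage with its input. TensorFlow tensors are immutable. The NumPy-only sketch below shows the practical consequence: writing through the expanded array is visible in the original.

```python
import numpy as np

n = np.array([2, 7, 5])
b = np.expand_dims(n, axis=0)   # NumPy stand-in for the ndarray branch

print(n.shape, b.shape)         # (3,) (1, 3): the input's shape is untouched
print(np.shares_memory(n, b))   # True: b is a view onto n's buffer

b[0, 1] = 70                    # write through the view...
print(n.tolist())               # [2, 70, 5]  ...and the original sees the change
```

If an independent copy is needed, call `.copy()` on the NumPy result, or `.clone()` on the PyTorch result.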