
_concat

concat

Concatenate a list of tensors along a given axis.

This method can be used with NumPy arrays:

import fastestimator as fe
import numpy as np

n = [np.array([[0, 1]]), np.array([[2, 3]]), np.array([[4, 5]])]
b = fe.backend.concat(n, axis=0)  # [[0, 1], [2, 3], [4, 5]]
b = fe.backend.concat(n, axis=1)  # [[0, 1, 2, 3, 4, 5]]
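According to the source listing further down this page, the NumPy branch of `concat` simply calls `np.concatenate(tensors, axis=axis)`. The sketch below reproduces the two calls above with plain NumPy, without importing FastEstimator, so the values and the resulting shapes can be checked directly. It assumes only that NumPy is installed.

```python
import numpy as np

# Three arrays of shape (1, 2), as in the example above
n = [np.array([[0, 1]]), np.array([[2, 3]]), np.array([[4, 5]])]

# axis=0 stacks the rows: result has shape (3, 2)
rows = np.concatenate(n, axis=0)
# axis=1 joins along the columns: result has shape (1, 6)
cols = np.concatenate(n, axis=1)

print(rows.shape, rows.tolist())  # (3, 2) [[0, 1], [2, 3], [4, 5]]
print(cols.shape, cols.tolist())  # (1, 6) [[0, 1, 2, 3, 4, 5]]
```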

This method can be used with TensorFlow tensors:

import fastestimator as fe
import tensorflow as tf

t = [tf.constant([[0, 1]]), tf.constant([[2, 3]]), tf.constant([[4, 5]])]
b = fe.backend.concat(t, axis=0)  # [[0, 1], [2, 3], [4, 5]]
b = fe.backend.concat(t, axis=1)  # [[0, 1, 2, 3, 4, 5]]

This method can be used with PyTorch tensors:

import fastestimator as fe
import torch

p = [torch.tensor([[0, 1]]), torch.tensor([[2, 3]]), torch.tensor([[4, 5]])]
b = fe.backend.concat(p, axis=0)  # [[0, 1], [2, 3], [4, 5]]
b = fe.backend.concat(p, axis=1)  # [[0, 1, 2, 3, 4, 5]]

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensors` | `List[Tensor]` | A list of tensors to be concatenated. | required |
| `axis` | `int` | The axis along which to concatenate the input. | `0` |
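Concatenation only works when the inputs agree on every dimension except `axis`; along `axis` the sizes add up. The sketch below illustrates that rule with NumPy, which is what `concat` delegates to for `np.ndarray` inputs. The helper `expected_concat_shape` is hypothetical and is not part of FastEstimator; it only predicts the output shape so it can be compared with what `np.concatenate` actually returns.

```python
import numpy as np


def expected_concat_shape(shapes, axis):
    """Hypothetical helper: predict the output shape of a concatenation.

    All shapes must have the same rank and agree on every dimension
    except `axis`, whose sizes are summed.
    """
    rank = len(shapes[0])
    axis = axis % rank  # mirror NumPy's handling of negative axes such as -1
    for s in shapes:
        if len(s) != rank or any(s[d] != shapes[0][d] for d in range(rank) if d != axis):
            raise ValueError(f"incompatible shapes for axis {axis}: {shapes}")
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)
    return tuple(out)


a = np.zeros((2, 3))
b = np.ones((4, 3))
c = np.ones((2, 5))

# Along axis 0 the row counts add up: (2, 3) + (4, 3) -> (6, 3)
print(expected_concat_shape([a.shape, b.shape], axis=0), np.concatenate([a, b], axis=0).shape)
# Along axis 1 the column counts add up: (2, 3) + (2, 5) -> (2, 8)
print(expected_concat_shape([a.shape, c.shape], axis=1), np.concatenate([a, c], axis=1).shape)
```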

Returns:

| Type | Description |
|---|---|
| `Optional[Tensor]` | A single tensor of the same type as the inputs, containing the `tensors` joined along `axis`, or `None` if the list of `tensors` was empty. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the first element of `tensors` is not a TensorFlow tensor, a PyTorch tensor, or a NumPy array. |
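The return and error behavior follows directly from the dispatch in the source listing below: an empty list short-circuits to `None`, and otherwise only the type of the first element decides which backend is called. The function `concat_numpy_only` below is a hypothetical, NumPy-only mirror of that control flow; it drops the TensorFlow and PyTorch branches so that it runs with only NumPy installed, and it is not the library's implementation.

```python
from typing import List, Optional

import numpy as np


def concat_numpy_only(tensors: List[np.ndarray], axis: int = 0) -> Optional[np.ndarray]:
    """Hypothetical NumPy-only mirror of the dispatch in fe.backend.concat."""
    if len(tensors) == 0:
        return None  # empty input yields None rather than an exception
    # Only the first element's type is inspected to choose a backend
    if isinstance(tensors[0], np.ndarray):
        return np.concatenate(tensors, axis=axis)
    raise ValueError("Unrecognized tensor type {}".format(type(tensors[0])))


print(concat_numpy_only([]))  # None
print(concat_numpy_only([np.array([1, 2]), np.array([3])]).tolist())  # [1, 2, 3]
try:
    concat_numpy_only([[1, 2], [3]])  # plain Python lists are rejected
except ValueError as err:
    print(err)  # Unrecognized tensor type <class 'list'>
```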

Source code in fastestimator/fastestimator/backend/_concat.py
def concat(tensors: List[Tensor], axis: int = 0) -> Optional[Tensor]:
    """Concatenate a list of `tensors` along a given `axis`.

    This method can be used with Numpy data:
    ```python
    n = [np.array([[0, 1]]), np.array([[2, 3]]), np.array([[4, 5]])]
    b = fe.backend.concat(n, axis=0)  # [[0, 1], [2, 3], [4, 5]]
    b = fe.backend.concat(n, axis=1)  # [[0, 1, 2, 3, 4, 5]]
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = [tf.constant([[0, 1]]), tf.constant([[2, 3]]), tf.constant([[4, 5]])]
    b = fe.backend.concat(t, axis=0)  # [[0, 1], [2, 3], [4, 5]]
    b = fe.backend.concat(t, axis=1)  # [[0, 1, 2, 3, 4, 5]]
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = [torch.tensor([[0, 1]]), torch.tensor([[2, 3]]), torch.tensor([[4, 5]])]
    b = fe.backend.concat(p, axis=0)  # [[0, 1], [2, 3], [4, 5]]
    b = fe.backend.concat(p, axis=1)  # [[0, 1, 2, 3, 4, 5]]
    ```

    Args:
        tensors: A list of tensors to be concatenated.
        axis: The axis along which to concatenate the input.

    Returns:
        A concatenated representation of the `tensors`, or None if the list of `tensors` was empty.

    Raises:
        ValueError: If `tensors` is an unacceptable data type.
    """
    if len(tensors) == 0:
        return None
    if tf.is_tensor(tensors[0]):
        return tf.concat(tensors, axis=axis)
    elif isinstance(tensors[0], torch.Tensor):
        return torch.cat(tensors, dim=axis)
    elif isinstance(tensors[0], np.ndarray):
        return np.concatenate(tensors, axis=axis)
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensors[0])))