
_argmax

argmax

Compute the index of the maximum value along a given axis of a tensor.

This method can be used with NumPy data:

```python
import fastestimator as fe
import numpy as np

n = np.array([[2,7,5],[9,1,3],[4,8,2]])
b = fe.backend.argmax(n, axis=0)  # [1, 2, 0]
b = fe.backend.argmax(n, axis=1)  # [1, 0, 1]
```

This method can be used with TensorFlow tensors:

```python
import fastestimator as fe
import tensorflow as tf

t = tf.constant([[2,7,5],[9,1,3],[4,8,2]])
b = fe.backend.argmax(t, axis=0)  # [1, 2, 0]
b = fe.backend.argmax(t, axis=1)  # [1, 0, 1]
```

This method can be used with PyTorch tensors:

```python
import fastestimator as fe
import torch

p = torch.tensor([[2,7,5],[9,1,3],[4,8,2]])
b = fe.backend.argmax(p, axis=0)  # [1, 2, 0]
b = fe.backend.argmax(p, axis=1)  # [1, 0, 1]
```

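Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `tensor` | `Tensor` | The input value: a NumPy array, TensorFlow tensor, or PyTorch tensor. | required |
| `axis` | `int` | The axis along which to compute the index. | `0` |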
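The `axis` parameter selects which dimension is reduced. With `axis=0` the index is taken down each column, and with `axis=1` it is taken across each row. The sketch below shows that behavior in pure Python for a rectangular 2-D nested list. It assumes a 2-D input only. `argmax_2d` is a hypothetical helper written for illustration and is not part of FastEstimator. It breaks ties by returning the first maximal index, as `np.argmax` does.

```python
from typing import List, Sequence


def argmax_2d(matrix: Sequence[Sequence[float]], axis: int = 0) -> List[int]:
    """Hypothetical helper: argmax of a rectangular 2-D nested list along `axis`.

    Ties resolve to the first maximal index.
    """
    if axis < 0:
        axis += 2  # allow -1 / -2, as NumPy does for a 2-D input
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1 for a 2-D input, got {}".format(axis))
    # axis=0 reduces over rows, so scan each column; axis=1 scans each row
    lines = list(zip(*matrix)) if axis == 0 else [list(row) for row in matrix]
    # max() returns the first index whose value is maximal
    return [max(range(len(line)), key=line.__getitem__) for line in lines]


n = [[2, 7, 5], [9, 1, 3], [4, 8, 2]]
print(argmax_2d(n, axis=0))  # [1, 2, 0]
print(argmax_2d(n, axis=1))  # [1, 0, 1]
```

Transposing with `zip(*matrix)` turns the column scan into a row scan, so one reduction expression serves both axes.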

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | The indices corresponding to the maximum values within `tensor` along `axis`. |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If `tensor` is an unacceptable data type. |

Source code in fastestimator/fastestimator/backend/_argmax.py
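from typing import TypeVar

import numpy as np
import tensorflow as tf
import torch

Tensor = TypeVar('Tensor', tf.Tensor, torch.Tensor, np.ndarray)


def argmax(tensor: Tensor, axis: int = 0) -> Tensor: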
    """Compute the index of the maximum value along a given axis of a tensor.

    This method can be used with NumPy data:
    ```python
    n = np.array([[2,7,5],[9,1,3],[4,8,2]])
    b = fe.backend.argmax(n, axis=0)  # [1, 2, 0]
    b = fe.backend.argmax(n, axis=1)  # [1, 0, 1]
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.constant([[2,7,5],[9,1,3],[4,8,2]])
    b = fe.backend.argmax(t, axis=0)  # [1, 2, 0]
    b = fe.backend.argmax(t, axis=1)  # [1, 0, 1]
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.tensor([[2,7,5],[9,1,3],[4,8,2]])
    b = fe.backend.argmax(p, axis=0)  # [1, 2, 0]
    b = fe.backend.argmax(p, axis=1)  # [1, 0, 1]
    ```

    Args:
        tensor: The input value.
        axis: Which axis to compute the index along.

    Returns:
        The indices corresponding to the maximum values within `tensor` along `axis`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    if tf.is_tensor(tensor):
        return tf.argmax(tensor, axis=axis)
    elif isinstance(tensor, torch.Tensor):
        return tensor.max(dim=axis, keepdim=False)[1]
    elif isinstance(tensor, np.ndarray):
        return np.argmax(tensor, axis=axis)
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
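The PyTorch branch calls `tensor.max(dim=axis, keepdim=False)`, which returns a `(values, indices)` pair. The trailing `[1]` therefore selects the indices. The pure-Python sketch below mirrors that pairing for a row-wise reduction. `MaxResult` and `max_along_rows` are hypothetical names used only for illustration. They do not reproduce PyTorch's API, and ties resolve to the first maximal index.

```python
from typing import List, NamedTuple, Sequence


class MaxResult(NamedTuple):
    """Hypothetical stand-in for the (values, indices) pair a max-with-dim reduction returns."""
    values: List[float]
    indices: List[int]


def max_along_rows(matrix: Sequence[Sequence[float]]) -> MaxResult:
    """Reduce each row to its maximum value and the first index where it occurs."""
    values: List[float] = []
    indices: List[int] = []
    for row in matrix:
        best = max(range(len(row)), key=row.__getitem__)
        values.append(row[best])
        indices.append(best)
    return MaxResult(values, indices)


p = [[2, 7, 5], [9, 1, 3], [4, 8, 2]]
result = max_along_rows(p)
print(result.values)  # [7, 9, 8]
print(result[1])      # [1, 0, 1]: position 1 holds the indices, like the `[1]` in the torch branch
```

Because the pair is ordered `(values, indices)`, indexing with `[1]` discards the maximum values and keeps only their positions, which is what `argmax` is meant to return.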