
matmul

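Perform matrix multiplication on a and b. The call is passed to TensorFlow, PyTorch, or NumPy depending on the input type, so both inputs must come from the same framework.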

This method can be used with NumPy data:

import numpy as np
import fastestimator as fe

a = np.array([[0,1,2],[3,4,5]])
b = np.array([[1],[2],[3]])
c = fe.backend.matmul(a, b)  # [[8], [26]]

This method can be used with TensorFlow tensors:

import tensorflow as tf
import fastestimator as fe

a = tf.constant([[0,1,2],[3,4,5]])
b = tf.constant([[1],[2],[3]])
c = fe.backend.matmul(a, b)  # [[8], [26]]

This method can be used with PyTorch tensors:

import torch
import fastestimator as fe

a = torch.tensor([[0,1,2],[3,4,5]])
b = torch.tensor([[1],[2],[3]])
c = fe.backend.matmul(a, b)  # [[8], [26]]
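All three examples share the commented result [[8], [26]], which follows from the row-by-column rule. Each output entry is the dot product of one row of a with the single column of b:

```latex
\begin{bmatrix} 0 & 1 & 2 \\ 3 & 4 & 5 \end{bmatrix}
\begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}
=
\begin{bmatrix} 0 \cdot 1 + 1 \cdot 2 + 2 \cdot 3 \\ 3 \cdot 1 + 4 \cdot 2 + 5 \cdot 3 \end{bmatrix}
=
\begin{bmatrix} 8 \\ 26 \end{bmatrix}
```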

Parameters:

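| Name | Type   | Description                                               | Default  |
|------|--------|-----------------------------------------------------------|----------|
| a    | Tensor | The first matrix: a tf.Tensor, torch.Tensor, or np.ndarray. | required |
| b    | Tensor | The second matrix, from the same framework as a.          | required |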

Returns:

| Type   | Description                                  |
|--------|----------------------------------------------|
| Tensor | The matrix product of a and b (that is, a @ b). |
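The following pure-Python reference shows what "matrix product" means for the two-dimensional case. It is an illustrative sketch, not FastEstimator code. The name matmul_reference is hypothetical, and the sketch only handles nested lists. Raising ValueError on mismatched inner dimensions is this sketch's own choice, because NumPy, TensorFlow, and PyTorch each report shape errors with their own exception types.

```python
from typing import List

Matrix = List[List[float]]


def matmul_reference(a: Matrix, b: Matrix) -> Matrix:
    """Hypothetical pure-Python 2-D matrix product, for illustration only."""
    rows_a, inner_a = len(a), len(a[0])
    inner_b, cols_b = len(b), len(b[0])
    # The number of columns in `a` must equal the number of rows in `b`
    if inner_a != inner_b:
        raise ValueError(f"Inner dimensions do not match ({inner_a} vs {inner_b})")
    # c[i][j] is the dot product of row i of `a` with column j of `b`
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner_a)) for j in range(cols_b)]
        for i in range(rows_a)
    ]


c = matmul_reference([[0, 1, 2], [3, 4, 5]], [[1], [2], [3]])
print(c)  # [[8], [26]]
```

The shapes follow the usual rule. An (m, k) matrix times a (k, n) matrix gives an (m, n) result, so the (2, 3) and (3, 1) inputs in the examples produce a (2, 1) output.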

Raises:

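| Type       | Description                                                                                   |
|------------|-----------------------------------------------------------------------------------------------|
| ValueError | If a and b are not both TensorFlow tensors, both PyTorch tensors, or both NumPy arrays. |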

Source code in fastestimator/fastestimator/backend/_matmul.py
def matmul(a: Tensor, b: Tensor) -> Tensor:
    """Perform matrix multiplication on `a` and `b`.

    This method can be used with NumPy data:
    ```python
    a = np.array([[0,1,2],[3,4,5]])
    b = np.array([[1],[2],[3]])
    c = fe.backend.matmul(a, b)  # [[8], [26]]
    ```

    This method can be used with TensorFlow tensors:
    ```python
    a = tf.constant([[0,1,2],[3,4,5]])
    b = tf.constant([[1],[2],[3]])
    c = fe.backend.matmul(a, b)  # [[8], [26]]
    ```

    This method can be used with PyTorch tensors:
    ```python
    a = torch.tensor([[0,1,2],[3,4,5]])
    b = torch.tensor([[1],[2],[3]])
    c = fe.backend.matmul(a, b)  # [[8], [26]]
    ```

    Args:
        a: The first matrix.
        b: The second matrix.

    Returns:
        The matrix product of `a` and `b` (that is, `a @ b`).

    Raises:
        ValueError: If `a` and `b` are not both TensorFlow tensors, both PyTorch tensors, or both NumPy arrays.
    """
    if tf.is_tensor(a) and tf.is_tensor(b):
        return tf.matmul(a, b)
    elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return a.matmul(b)
    elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return np.matmul(a, b)
    elif type(a) != type(b):
        raise ValueError(f"Tensor types do not match ({type(a)} and {type(b)})")
    else:
        raise ValueError(f"Unrecognized tensor type ({type(a)} or {type(b)})")
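The if/elif chain above is a type dispatch. The function picks a backend only when both inputs belong to it. Otherwise it reports a type mismatch before reporting an unrecognized type. The sketch below rewrites the same decision order as a lookup table. It is only an illustration of the pattern and not how FastEstimator is implemented. MATMUL_BACKENDS, list_matmul, and dispatch_matmul are hypothetical names. Plain Python lists stand in for a tensor backend so the sketch runs without TensorFlow, PyTorch, or NumPy installed.

```python
from typing import Any, Callable, Dict, List

# Hypothetical registry mapping a tensor type to its matmul implementation
MATMUL_BACKENDS: Dict[type, Callable[[Any, Any], Any]] = {}


def list_matmul(a: List[List[float]], b: List[List[float]]) -> List[List[float]]:
    # Stand-in "backend" for nested lists: dot each row of `a` with each column of `b`
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


MATMUL_BACKENDS[list] = list_matmul


def dispatch_matmul(a: Any, b: Any) -> Any:
    # As in the source, a backend is used only when BOTH inputs belong to it
    for tensor_type, fn in MATMUL_BACKENDS.items():
        if isinstance(a, tensor_type) and isinstance(b, tensor_type):
            return fn(a, b)
    # Same error order as the source: type mismatch first, then unknown type
    if type(a) != type(b):
        raise ValueError(f"Tensor types do not match ({type(a)} and {type(b)})")
    raise ValueError(f"Unrecognized tensor type ({type(a)} or {type(b)})")


print(dispatch_matmul([[0, 1, 2], [3, 4, 5]], [[1], [2], [3]]))  # [[8], [26]]
```

A registry keeps each backend's logic in one place and lets new types be added without editing the dispatcher. The trade-off is that the dispatch order, which the explicit if/elif chain makes obvious, now depends on registration order.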