torch.mm

torch.mm(input, mat2, *, out=None) → Tensor

Performs a matrix multiplication of the matrices input and mat2.

If input is a (n×m) tensor and mat2 is a (m×p) tensor, out will be a (n×p) tensor.
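As a minimal sketch of the shape rule above, the following multiplies a 2×3 matrix by a 3×4 matrix and checks the result against an explicit row-by-column computation. The sizes n, m, p and the variable names are illustrative choices, not part of the API.

```python
import torch

n, m, p = 2, 3, 4           # illustrative sizes
a = torch.randn(n, m)       # plays the role of input
b = torch.randn(m, p)       # plays the role of mat2

c = torch.mm(a, b)
print(c.shape)              # torch.Size([2, 4])

# Entry c[i, j] is the dot product of row i of a with column j of b.
manual = torch.stack(
    [torch.stack([(a[i] * b[:, j]).sum() for j in range(p)]) for i in range(n)]
)
print(torch.allclose(c, manual, atol=1e-5))

# The inner dimensions must agree; a (2, 3) by (4, 3) product is rejected.
try:
    torch.mm(a, torch.randn(p, m))
    shape_error = None
except RuntimeError as err:
    shape_error = err
print("mismatched shapes raised:", shape_error is not None)
```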

Note

This function does not broadcast. For broadcasting matrix products, see torch.matmul().
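To illustrate the difference, here is a small sketch contrasting the two functions on a batched input. The tensor names and sizes are made up for the example.

```python
import torch

batch = torch.randn(5, 2, 3)   # a batch of five 2x3 matrices
mat = torch.randn(3, 4)        # a single 3x4 matrix

# torch.matmul broadcasts mat across the leading batch dimension.
out = torch.matmul(batch, mat)
print(out.shape)               # torch.Size([5, 2, 4])

# torch.mm accepts only 2-D tensors, so the same arguments are rejected.
try:
    torch.mm(batch, mat)
    mm_failed = False
except RuntimeError as err:
    mm_failed = True
    print("torch.mm raised:", err)
```

When every operand is already 2-D, the two functions should give the same result; torch.mm simply makes the no-broadcast intent explicit.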

Supports strided and sparse 2-D tensors as inputs, and autograd with respect to strided inputs.
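A short sketch of what this can look like in practice, assuming a sparse COO first operand and a dense (strided) second operand that requires gradients. The values are arbitrary and chosen only for illustration.

```python
import torch

# Sparse COO matrix of shape (2, 3); it does not require gradients.
sparse = torch.tensor([[0.0, 2.0, 0.0],
                       [1.0, 0.0, 0.0]]).to_sparse()

# Dense (strided) matrix of shape (3, 2) that participates in autograd.
dense = torch.randn(3, 2, requires_grad=True)

result = torch.mm(sparse, dense)   # shape (2, 2)
result.sum().backward()            # gradient flows to the strided input

# For a sum loss, the gradient w.r.t. dense is sparse^T @ ones(2, 2).
expected_grad = sparse.to_dense().t() @ torch.ones(2, 2)
print(torch.allclose(dense.grad, expected_grad))
```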

This operator supports TensorFloat32.

On certain ROCm devices, when using float16 inputs, this function will use a different precision for the backward pass.

Parameters
  • input (Tensor) – the first matrix to be matrix multiplied

  • mat2 (Tensor) – the second matrix to be matrix multiplied

Keyword Arguments
  • out (Tensor, optional) – the output tensor.
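As a brief sketch of the out keyword, the following writes the product into a preallocated buffer. The name buf is hypothetical, and the buffer is assumed to have the correct (n×p) shape and dtype up front.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)

# Preallocate the destination with the expected result shape.
buf = torch.empty(2, 4)
ret = torch.mm(a, b, out=buf)

# The returned tensor shares storage with the buffer that was passed in.
print(ret.data_ptr() == buf.data_ptr())
print(torch.allclose(buf, torch.mm(a, b)))
```

Reusing a buffer this way can avoid a fresh allocation on each call, which may matter in tight loops.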

Example:

>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])
