torch.mm
torch.mm(input, mat2, *, out=None) → Tensor

Performs a matrix multiplication of the matrices input and mat2.

If input is an (n × m) tensor and mat2 is an (m × p) tensor, the output will be an (n × p) tensor.

Note
This function does not broadcast. For broadcasting matrix products, see torch.matmul().

Supports strided and sparse 2-D tensors as inputs, and autograd with respect to strided inputs.
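The sketch below illustrates the behaviors just described: torch.mm rejects a batched (3-D) input where torch.matmul broadcasts, accepts a sparse COO first argument with a strided second argument, and propagates gradients to strided inputs. The shapes are arbitrary choices for illustration, and catching RuntimeError for the 3-D case reflects current PyTorch behavior rather than a documented guarantee.

```python
import torch

# Illustrative shapes: n=2, m=3, p=4
a = torch.randn(2, 3)
b = torch.randn(3, 4)

# 1. No broadcasting: torch.mm only accepts 2-D inputs.
batched = torch.randn(5, 2, 3)  # a batch of five (2 x 3) matrices
try:
    torch.mm(batched, b)
    mm_rejected_batch = False
except RuntimeError:
    mm_rejected_batch = True

# torch.matmul broadcasts b across the leading batch dimension instead.
broadcast_result = torch.matmul(batched, b)  # shape (5, 2, 4)

# 2. Sparse input: a sparse COO first argument times a strided second argument.
a_sparse = a.to_sparse()
sparse_result = torch.mm(a_sparse, b)  # returns a dense (strided) tensor

# 3. Autograd with respect to strided inputs.
x = torch.randn(2, 3, requires_grad=True)
y = torch.randn(3, 4, requires_grad=True)
torch.mm(x, y).sum().backward()
# For L = sum(x @ y): dL/dx = ones(2, 4) @ y.T and dL/dy = x.T @ ones(2, 4)
expected_x_grad = torch.ones(2, 4) @ y.detach().T
expected_y_grad = x.detach().T @ torch.ones(2, 4)

print(mm_rejected_batch, broadcast_result.shape, sparse_result.layout)
```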
This operator supports TensorFloat32.
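One way to control whether float32 matrix products may use TensorFloat32 is the global torch.set_float32_matmul_precision setting, sketched below. The sketch assumes a recent PyTorch release that provides this function; the "high" setting only permits TF32, and it changes results only on hardware with TF32 support (for example, recent NVIDIA GPUs). Elsewhere the two products are typically identical.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(256, 256, device=device)
b = torch.randn(256, 256, device=device)

previous = torch.get_float32_matmul_precision()
try:
    # "highest": keep full float32 precision for matrix multiplications
    torch.set_float32_matmul_precision("highest")
    exact = torch.mm(a, b)
    # "high": allow TF32 internally on hardware that supports it
    torch.set_float32_matmul_precision("high")
    fast = torch.mm(a, b)
finally:
    # Restore the global setting so other code is unaffected
    torch.set_float32_matmul_precision(previous)

# TF32 keeps fewer mantissa bits, so small differences may appear where it is used
max_diff = (exact - fast).abs().max().item()
print(f"max abs difference: {max_diff:.3e}")
```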
On certain ROCm devices, when using float16 inputs, this function will use a different precision for the backward pass.
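Parameters
input (Tensor) – the first matrix to be matrix multiplied
mat2 (Tensor) – the second matrix to be matrix multiplied

Keyword Arguments
out (Tensor, optional) – the output tensor.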
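A minimal sketch of the out keyword argument: the product is written into a preallocated tensor of the expected (n × p) shape, which can help avoid repeated allocations when the same product is computed many times. The buffer name out and the shapes are illustrative.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)

# Preallocate an (n x p) = (2 x 4) output buffer.
out = torch.empty(2, 4)
result = torch.mm(a, b, out=out)

# The returned tensor shares storage with the buffer that was passed in.
same_storage = result.data_ptr() == out.data_ptr()
print(same_storage)
```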
Example:
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])
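To make the shape rule concrete, the example above can be cross-checked against the element-wise definition C[i, j] = Σ_k input[i, k] · mat2[k, j]. The helper naive_mm is hypothetical and written only for illustration; it is far slower than torch.mm, and small floating-point differences due to summation order are expected, hence the tolerance in the comparison.

```python
import torch

def naive_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference for an (n x m) @ (m x p) product, using
    C[i, j] = sum over k of A[i, k] * B[k, j]. For illustration only."""
    n, m = a.shape
    m2, p = b.shape
    if m != m2:
        raise ValueError(f"inner dimensions differ: {m} vs {m2}")
    c = torch.zeros(n, p, dtype=a.dtype)
    for i in range(n):
        for j in range(p):
            for k in range(m):
                c[i, j] += a[i, k] * b[k, j]
    return c

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
expected = naive_mm(mat1, mat2)
actual = torch.mm(mat1, mat2)
print(actual.shape)  # torch.Size([2, 3])
```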