Flatten
class torch.nn.Flatten(start_dim=1, end_dim=-1)

Flattens a contiguous range of dimensions of the input tensor into a single dimension. For use with torch.nn.Sequential.

Shape:
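Input: $(*, S_{\text{start}}, \ldots, S_i, \ldots, S_{\text{end}}, *)$, where $S_i$ is the size at dimension $i$ and $*$ means any number of dimensions, including none.
Output: $(*, \prod_{i=\text{start}}^{\text{end}} S_i, *)$.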
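The shape rule above can be sketched in plain Python. This is a minimal illustration, not PyTorch's implementation: the helper name flatten_shape is hypothetical, and it assumes start_dim and end_dim have already been converted to non-negative indices with start_dim <= end_dim.

```python
from math import prod
from typing import Sequence, Tuple


def flatten_shape(shape: Sequence[int], start_dim: int, end_dim: int) -> Tuple[int, ...]:
    """Hypothetical helper: output shape after flattening dims start_dim..end_dim (inclusive).

    Assumes 0 <= start_dim <= end_dim < len(shape).
    """
    if not 0 <= start_dim <= end_dim < len(shape):
        raise ValueError("expected 0 <= start_dim <= end_dim < len(shape)")
    # Dimensions before start_dim are kept unchanged.
    head = tuple(shape[:start_dim])
    # Dimensions start_dim..end_dim collapse into the product of their sizes.
    merged = prod(shape[start_dim:end_dim + 1])
    # Dimensions after end_dim are kept unchanged.
    tail = tuple(shape[end_dim + 1:])
    return head + (merged,) + tail


print(flatten_shape((32, 1, 5, 5), 1, 3))  # (32, 25)
print(flatten_shape((32, 1, 5, 5), 0, 2))  # (160, 5)
```

When start_dim equals end_dim, the product covers a single dimension, so the shape is unchanged.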
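Parameters:
start_dim (int) – first dim to flatten, inclusive (default = 1).
end_dim (int) – last dim to flatten, inclusive (default = -1). Negative values count back from the last dimension, so -1 refers to the last dimension.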
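Because end_dim defaults to -1, the two parameters behave like Python-style indices into the input's dimensions. A minimal sketch of how they might resolve for an input of a given rank is below; normalize_dims is a hypothetical helper, and its error types only loosely mirror PyTorch's own checks.

```python
from typing import Tuple


def normalize_dims(ndim: int, start_dim: int = 1, end_dim: int = -1) -> Tuple[int, int]:
    """Hypothetical helper: map possibly negative dims to non-negative indices."""

    def wrap(dim: int) -> int:
        # A negative dim counts back from the last dimension, e.g. -1 -> ndim - 1.
        if not -ndim <= dim < ndim:
            raise IndexError(f"dim {dim} out of range for a {ndim}-d input")
        return dim + ndim if dim < 0 else dim

    start, end = wrap(start_dim), wrap(end_dim)
    if start > end:
        raise ValueError("start_dim cannot come after end_dim")
    return start, end


# Defaults (start_dim=1, end_dim=-1) on a 4-d input such as (32, 1, 5, 5):
print(normalize_dims(4))          # (1, 3)
print(normalize_dims(4, 0, 2))    # (0, 2)
print(normalize_dims(4, -3, -2))  # (1, 2)
```

With the defaults, dimension 0 (typically the batch dimension) is kept and everything after it is merged.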
Examples:

>>> import torch
>>> from torch import nn
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
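Flattening does not reorder elements: the merged dimensions are laid out in row-major order, so the last merged dimension varies fastest. The plain-Python sketch below computes where an element lands inside the merged dimension; merged_index is a hypothetical helper written only for illustration.

```python
from typing import Sequence


def merged_index(sizes: Sequence[int], indices: Sequence[int]) -> int:
    """Hypothetical helper: position inside the merged dim, in row-major order.

    sizes:   sizes of the dims being flattened, e.g. (1, 5, 5)
    indices: one index per flattened dim, e.g. (0, 2, 3)
    """
    if len(sizes) != len(indices):
        raise ValueError("sizes and indices must have the same length")
    flat = 0
    for size, idx in zip(sizes, indices):
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} out of range for size {size}")
        # Horner-style accumulation: the last dim varies fastest.
        flat = flat * size + idx
    return flat


# With nn.Flatten() on a (32, 1, 5, 5) input, element [n, 0, 2, 3]
# ends up at output[n, 13], since 0*25 + 2*5 + 3 == 13.
print(merged_index((1, 5, 5), (0, 2, 3)))  # 13
```

For nn.Flatten(0, 2) on the same input, element [1, 0, 0, w] lands in row merged_index((32, 1, 5), (1, 0, 0)) == 5 of the (160, 5) output.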