torch.multinomial
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor
Returns a tensor where each row contains num_samples indices sampled from the multinomial (a stricter definition would be multivariate, refer to torch.distributions.multinomial.Multinomial for more details) probability distribution located in the corresponding row of tensor input.

Note
The rows of input do not need to sum to one (in which case the values are used as weights), but they must be non-negative and finite and must have a non-zero sum.

Indices are ordered from left to right according to when each was sampled (first samples are placed in the first column).
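To illustrate the note on unnormalized weights, the following sketch draws from a row of raw weights and from the same row normalized by hand, then compares the empirical frequencies. The seed, the sample count, and the variable names are arbitrary choices made for this illustration.

```python
import torch

# Unnormalized weights: the row sums to 13, not to 1.
weights = torch.tensor([0.0, 10.0, 3.0, 0.0])
# The same distribution, normalized by hand.
probs = weights / weights.sum()

# A dedicated generator keeps the illustration repeatable (seed is arbitrary).
gen = torch.Generator().manual_seed(0)

n = 10_000
draws_w = torch.multinomial(weights, n, replacement=True, generator=gen)
draws_p = torch.multinomial(probs, n, replacement=True, generator=gen)

# Empirical frequency of each index; both should be close to `probs`.
freq_w = torch.bincount(draws_w, minlength=4) / n
freq_p = torch.bincount(draws_p, minlength=4) / n

print(probs)
print(freq_w)
print(freq_p)
```

Both frequency vectors should land near 10/13 for index 1 and 3/13 for index 2, which is what treating the values as weights implies.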
If input is a vector, out is a vector of size num_samples.

If input is a matrix with m rows, out is a matrix of shape (m × num_samples).

If replacement is True, samples are drawn with replacement.

If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
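The shape rules and the meaning of drawing without replacement can be checked with a small sketch. The tensors, the seed, and the names below are made up for illustration; only the shapes are fixed, while the sampled indices themselves depend on the random state.

```python
import torch

gen = torch.Generator().manual_seed(0)  # arbitrary seed

vec = torch.tensor([1.0, 2.0, 3.0, 4.0])
# m = 3 rows; every row has at least 3 non-zero weights.
mat = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                    [0.5, 0.0, 2.0, 1.0],
                    [3.0, 1.0, 1.0, 5.0]])

out_vec = torch.multinomial(vec, 3, generator=gen)
out_mat = torch.multinomial(mat, 3, generator=gen)
# With replacement, num_samples may exceed the number of categories.
out_rep = torch.multinomial(mat, 6, replacement=True, generator=gen)

print(out_vec.shape)  # torch.Size([3])
print(out_mat.shape)  # torch.Size([3, 3])
print(out_rep.shape)  # torch.Size([3, 6])

# Without replacement, no index appears twice within a row.
rows_unique = all(len(set(row.tolist())) == row.numel() for row in out_mat)
print(rows_unique)  # True
```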
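Note
When drawn without replacement, num_samples must not exceed the number of non-zero elements in input (or the minimum number of non-zero elements across the rows of input if it is a matrix).

- Parameters
input (Tensor) – the input tensor containing probabilities (or non-negative weights)
num_samples (int) – number of samples to draw
replacement (bool, optional) – whether to draw with replacement or not. Default: False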
- Keyword Arguments
generator (torch.Generator, optional) – a pseudorandom number generator for sampling
out (Tensor, optional) – the output tensor.
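As a sketch of how the two keyword arguments might be used: two generators seeded identically should produce the same draws, and out can receive the result in a preallocated LongTensor. The seed value and buffer name are assumptions made for this example.

```python
import torch

weights = torch.tensor([0.0, 10.0, 3.0, 0.0])

# Two independent generators with the same (arbitrary) seed.
g1 = torch.Generator().manual_seed(1234)
g2 = torch.Generator().manual_seed(1234)

a = torch.multinomial(weights, 5, replacement=True, generator=g1)
b = torch.multinomial(weights, 5, replacement=True, generator=g2)
same = torch.equal(a, b)
print(same)  # True

# Write the next draw into a preallocated int64 buffer via `out`.
buf = torch.empty(5, dtype=torch.long)
torch.multinomial(weights, 5, replacement=True, generator=g1, out=buf)
print(buf)
```

Passing an explicit generator leaves the global random state untouched, which can be convenient when sampling has to be reproducible independently of other random operations.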
Example:
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 4) # ERROR!
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False, not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2, 1, 1, 1])
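The error case in the example can also be guarded against before sampling. The helper below is hypothetical (sample_without_replacement is not part of torch); it is a minimal sketch that counts the non-zero weights per row and raises its own ValueError rather than relying on the exact error the library reports.

```python
import torch

def sample_without_replacement(weights, num_samples, generator=None):
    """Hypothetical helper: validate num_samples, then sample without replacement."""
    # Number of non-zero weights in each row (a 0-d tensor for a vector input).
    nonzero_per_row = torch.count_nonzero(weights, dim=-1)
    limit = int(nonzero_per_row.min())
    if num_samples > limit:
        raise ValueError(
            f"num_samples={num_samples} exceeds the {limit} non-zero "
            "weights available in at least one row"
        )
    return torch.multinomial(weights, num_samples,
                             replacement=False, generator=generator)

weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)

idx = sample_without_replacement(weights, 2)  # two non-zero weights: allowed
print(sorted(idx.tolist()))  # [1, 2]

try:
    sample_without_replacement(weights, 4)    # only two non-zero weights
except ValueError as err:
    message = str(err)
    print(message)
```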