torch.Tensor.scatter_reduce_
Tensor.scatter_reduce_(dim, index, src, reduce, *, include_self=True) → Tensor

Reduces all values from the src tensor into the self tensor at the indices specified in the index tensor, using the reduction selected by the reduce argument ("sum", "prod", "mean", "amax" or "amin"). Each value in src is reduced into one element of self. Along every dimension other than dim, the target position is the value's own position in src. Along dim, it is the corresponding value in index. Only the elements of src at positions covered by index take part. If include_self=True, the values already in self at the targeted positions are included in the reduction.

self, index and src must all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.

For a 3-D tensor with reduce="sum" and include_self=True, the output is given as:

    self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
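To make the indexing rule concrete, the following is a minimal pure-Python sketch that transcribes the three formulas above onto nested lists. It covers only reduce="sum" with include_self=True, and the function name scatter_sum_3d is hypothetical, not part of PyTorch. It loops over the shape of index, so src entries outside that shape are never read, and it returns a new list instead of mutating self in place.

```python
import copy
from typing import List

Tensor3D = List[List[List[float]]]
Index3D = List[List[List[int]]]


def scatter_sum_3d(self_t: Tensor3D, index: Index3D, src: Tensor3D, dim: int) -> Tensor3D:
    """Hypothetical helper: reduce="sum", include_self=True on 3-D nested lists."""
    if dim not in (0, 1, 2):
        raise ValueError("dim must be 0, 1 or 2 for a 3-D tensor")
    out = copy.deepcopy(self_t)  # work on a copy; the real method updates self in place
    # Loop over the shape of index; src may be larger along any dimension.
    for i in range(len(index)):
        for j in range(len(index[i])):
            for k in range(len(index[i][j])):
                t = index[i][j][k]  # target coordinate along `dim`
                if dim == 0:
                    out[t][j][k] += src[i][j][k]
                elif dim == 1:
                    out[i][t][k] += src[i][j][k]
                else:  # dim == 2
                    out[i][j][t] += src[i][j][k]
    return out


# src values 1 and 3 both land on position [1][0][0] along dim 0 and are summed.
print(scatter_sum_3d([[[0, 0]], [[0, 0]]], [[[1, 0]], [[1, 1]]], [[[1, 2]], [[3, 4]]], dim=0))
# prints [[[0, 2]], [[4, 4]]]
```

The colliding writes are what the reduce argument resolves: replacing the += with a product, mean, maximum or minimum over all values that reach the same target gives the other reductions.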
Note
This operation may behave nondeterministically when given tensors on a CUDA device. See Reproducibility for more information.
Note
The backward pass is implemented only for src.shape == index.shape.
Warning
This function is in beta and may change in the near future.
Parameters
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to scatter and reduce
src (Tensor) – the source elements to scatter and reduce
reduce (str) – the reduction operation to apply for non-unique indices ("sum", "prod", "mean", "amax", "amin")
include_self (bool) – whether elements from the self tensor are included in the reduction
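The shape and argument requirements stated above can be collected into a single check. The sketch below is hypothetical (scatter_reduce_arg_problems is not a PyTorch function) and encodes only the documented rules: an allowed reduce name, equal numbers of dimensions, and the per-dimension size limits with no broadcasting. For simplicity it accepts only non-negative dim values, and it is not a substitute for the checks PyTorch itself performs, which may differ in messages and edge cases.

```python
from typing import List, Sequence

# The five reduction names listed for the `reduce` parameter.
ALLOWED_REDUCTIONS = ("sum", "prod", "mean", "amax", "amin")


def scatter_reduce_arg_problems(
    self_shape: Sequence[int],
    index_shape: Sequence[int],
    src_shape: Sequence[int],
    dim: int,
    reduce: str,
) -> List[str]:
    """Hypothetical checker: list every documented constraint the arguments violate."""
    problems: List[str] = []
    if reduce not in ALLOWED_REDUCTIONS:
        problems.append(f"reduce must be one of {ALLOWED_REDUCTIONS}, got {reduce!r}")
    ndim = len(self_shape)
    if not (len(index_shape) == len(src_shape) == ndim):
        problems.append("self, index and src must have the same number of dimensions")
        return problems  # per-dimension checks are meaningless past this point
    if not 0 <= dim < ndim:
        problems.append(f"dim {dim} is out of range for {ndim} dimensions")
        return problems
    # No broadcasting: sizes are compared directly, size-1 dims are not expanded.
    for d in range(ndim):
        if index_shape[d] > src_shape[d]:
            problems.append(f"index.size({d}) > src.size({d})")
        if d != dim and index_shape[d] > self_shape[d]:
            problems.append(f"index.size({d}) > self.size({d})")
    return problems


print(scatter_reduce_arg_problems((4,), (6,), (6,), 0, "sum"))  # prints []
```

For the 1-D example that follows, the check passes even though index has more entries than self: along dim, self is addressed only through the values stored in index, so its size there is not compared with index.size(dim).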
Example:
>>> import torch
>>> src = torch.tensor([1., 2., 3., 4., 5., 6.])
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
>>> input = torch.tensor([1., 2., 3., 4.])
>>> input.scatter_reduce(0, index, src, reduce="sum")
tensor([5., 14., 8., 4.])
>>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False)
tensor([4., 12., 5., 4.])
>>> input2 = torch.tensor([5., 4., 3., 2.])
>>> input2.scatter_reduce(0, index, src, reduce="amax")
tensor([5., 6., 5., 2.])
>>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False)
tensor([3., 6., 5., 2.])
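To see where each number in the example comes from, here is a hypothetical pure-Python model of the 1-D case (scatter_reduce_1d is an illustrative name, not a PyTorch API). It groups the src values by target position, optionally prepends the existing self value, and applies the chosen reduction. Positions that no index points to keep their original value, which is why the last element is unchanged in every output above. The "mean" branch assumes the self value counts as one extra element when include_self=True and uses float division; verify both assumptions against your PyTorch version, especially for integer tensors.

```python
import math
from typing import Dict, List, Sequence


def scatter_reduce_1d(
    self_vals: Sequence[float],
    index: Sequence[int],
    src: Sequence[float],
    reduce: str,
    include_self: bool = True,
) -> List[float]:
    """Hypothetical pure-Python model of scatter_reduce along dim 0 of a 1-D tensor."""
    out = list(self_vals)  # untouched positions keep their original value
    groups: Dict[int, List[float]] = {}
    # Only src[i] for i < len(index) is used; collect the values per target position.
    for i, target in enumerate(index):
        groups.setdefault(target, []).append(src[i])
    for target, vals in groups.items():
        if include_self:
            vals = [out[target]] + vals
        if reduce == "sum":
            out[target] = sum(vals)
        elif reduce == "prod":
            out[target] = math.prod(vals)
        elif reduce == "mean":
            out[target] = sum(vals) / len(vals)  # assumption: self counts as one element
        elif reduce == "amax":
            out[target] = max(vals)
        elif reduce == "amin":
            out[target] = min(vals)
        else:
            raise ValueError(f"unsupported reduce: {reduce!r}")
    return out


src = [1., 2., 3., 4., 5., 6.]
index = [0, 1, 0, 1, 2, 1]
print(scatter_reduce_1d([1., 2., 3., 4.], index, src, "sum"))  # [5.0, 14.0, 8.0, 4.0]
print(scatter_reduce_1d([5., 4., 3., 2.], index, src, "amax", include_self=False))  # [3.0, 6.0, 5.0, 2.0]
```

For instance, position 1 receives src values 2, 4 and 6; with include_self=True the existing 2 is added as well, giving 2 + 2 + 4 + 6 = 14, and with include_self=False the result is 12.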