当前位置: 首页 > news >正文

PyTorch gather与scatter_详解

PyTorch gather与scatter_详解

在 PyTorch 常用的算子中,有两个理解巅峰的存在,那就是 torch.gathertorch.scatter_,在 Seq2SeqAttentioncrf viterbi等结构的源码中,都可以看到这两个算子的身影,今天来详细讲解一下这两个函数。

torch.gather

使用

torch.gather 函数用于从输入张量的指定维度收集元素。收集的索引由 index 张量提供。

使用语法:torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

核心参数

  • input:输入张量
  • dim:指定的维度
  • index:索引张量,包含收集元素的索引

注意

  • inputindex 必须要有相同的维度
  • 对于所有的 d != dim,都必须要有 index.size(d) <= input.size(d)以及out 的形状和 index形状相同
  • inputindex 之间没有广播机制
  • 只有在 src.shape == index.shape 时实现了反向传播
说明

以一个三维的张量为例

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

看一个示意图(这里index和dim都是从1开始,转换成代码时 -1 即可
img

再看一个示意图,应该懂了

  • dim=0
    img
dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])
# 将 index 的 dim=0 处固定 然后其他位置按顺序填充
# [['0'-0, '1'-1, '2'-2],  ['1'-0, '2'-1, '0'-2]]
# [[(0, 0), (1, 1), (2, 2)], [(1, 0), (2, 1), (0, 2)]]output = torch.gather(input, dim, index)
# tensor([[10, 14, 18],
#         [13, 17, 12]])
  • dim=1
    img
dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]])
# 将 index 的 dim=1 处固定 然后其他位置按顺序填充
# [[0-'0', 0-'1'], [1-'1', 1-'2'], [2-'2', 2-'0']]
# [[(0, 0), (0, 1)], [(1, 1), (1, 2)], [(2, 2), (2, 0)]]output = torch.gather(input, dim, index)
# tensor([[10, 11],
#         [14, 15],
#         [18, 16]])
案例

假设我们有一个 2D 张量 data,我们希望根据索引张量 indexdata 中提取特定位置的值。

import torch# 创建一个 2D 张量 data
data = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
print("Data tensor:")
print(data)# 创建一个索引张量 index
index = torch.tensor([[0, 2],[1, 0],[2, 1]])
print("\nIndex tensor:")
print(index)# 使用 gather 函数
result = torch.gather(data, 1, index)
print("\nGathered result:")
print(result)

我们对上面案例进行逐步解释

  1. 初始张量 data:

    data = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
    

    这是一个 3x3 的张量:

    tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
    
  2. 索引张量 index:

    index = torch.tensor([[0, 2],[1, 0],[2, 1]])
    

    这是一个 3x2 的张量,表示要从 data 中提取的索引。

  3. 使用 gather 函数:

    result = torch.gather(data, 1, index)
    

    这个操作会根据 index 张量中的索引,从 data 张量中提取相应位置的值。具体操作如下:

    • 对于 data 的第 0 行:

      • index[0, 0] = 0,所以 result[0, 0] = data[0, 0] = 1
      • index[0, 1] = 2,所以 result[0, 1] = data[0, 2] = 3
    • 对于 data 的第 1 行:

      • index[1, 0] = 1,所以 result[1, 0] = data[1, 1] = 5
      • index[1, 1] = 0,所以 result[1, 1] = data[1, 0] = 4
    • 对于 data 的第 2 行:

      • index[2, 0] = 2,所以 result[2, 0] = data[2, 2] = 9
      • index[2, 1] = 1,所以 result[2, 1] = data[2, 1] = 8

最终,result 张量为:

tensor([[1, 3],[5, 4],[9, 8]])

torch.scatter_

使用

torch.scatter_ 是 PyTorch 中一个用于在特定维度上根据索引将值写入张量的原地操作函数。

使用语法:Tensor.scatter_(dim, index, src, *, reduce=None) → Tensor

核心参数

  • dim:指定沿着哪个维度进行散射操作
  • index:一个包含索引的张量,指定 src 中的值要写入 tensor 的位置
  • src:包含要写入 tensor 的值的张量

注意

  • self, indexsrc必须有相同的维度
  • 对于所有的维度 d 必须有 index.size(d) <= src.size(d)以及index.size(d) <= self.size(d)
  • indexsrc 不会进行广播
说明

torch.scatter_ 其实就是torch.gather 的一个逆运算

以一个三维的张量为例

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

另外,需要注意,scatter_ 是一个 inplace 算子

案例

先来看 dim=0 的情况

import torch
import numpy as np
src = torch.arange(1, 11).view(2, 5)
print(src)
> tensor([[ 1,  2,  3,  4,  5],[ 6,  7,  8,  9, 10]])input_tensor = torch.zeros(3, 5).long()
print(input_tensor)
> tensor([[0, 0, 0, 0, 0],[0, 0, 0, 0, 0],[0, 0, 0, 0, 0]])index_tensor = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
print(index_tensor)
> tensor([[0, 1, 2, 0, 0],[2, 0, 0, 1, 2]])## try to manually work out the result 
dim = 0
input_tensor.scatter_(dim, index_tensor, src)
print(input_tensor)
> ...
  • step1:将 src 的第1列分散到 input _tensor 的第1列。与指数张量的第1列相匹配。我们把1分散到0排,6分散到2排。

img

  • step2:将 src 的第2列分散到 input _ tensor 的第2列。与指数张量第2列匹配。我们把2分散到第1排,把7分散到第0排。

img

  • step3/4/5:以此类推,继续对其他列做散射。最后,我们将得到如下图。

img

运行代码,检查最终结果

> tensor([[ 1,  7,  8,  4,  5],[ 0,  2,  0,  9,  0],[ 6,  0,  3,  0, 10]])

再来看 dim=1 的情况

origin data

import torchsrc = torch.arange(1, 11).view(2, 5)
input_tensor = torch.zeros(3, 5).long()
index_tensor = torch.tensor([[3, 0, 2, 1, 4], [2, 0, 1, 3, 1]])
dim = 1
input_tensor.scatter_(dim, index_tensor, src)
print(input_tensor)
  • step1:将 src 的第一行散布到 input _ tensor 的第一行。1到 col3,2到 col0,3到 col2,4到 col1,5到 col4。

img

  • step2:将 src 的第2行散布到 input _ tensor 的第2行。

注意:index _ tensor 的第二行有两个1。为了使更新更清晰,我将这一步分为两个子步骤。

  • step2.1:分散6到 col2,7到 col0,8到 col1,9到 col3。

img

  • step2.2:对10进行分散,相应的索引是1,但是该位置8已经存在了,我们需要用10来覆盖8。

img

运行代码,检查最终结果

> tensor([[ 2,  4,  3,  1,  5],[ 7, 10,  6,  9,  0],[ 0,  0,  0,  0,  0]])

参考

PyTorch torch.gather

PyTorch torch.scatter_

What does gather() do in PyTorch

Understand torch.scatter_()


http://www.mrgr.cn/news/49187.html

相关文章:

  • 视频格式转换
  • Spring Boot知识管理:智能搜索与分析
  • 常用STL容器(c++)
  • 【畅捷通-注册安全分析报告】
  • 【python】内置装饰器-@property
  • unordered_set、unordered_multiset、unordered_map、unordered_multimap的介绍及使用
  • gitlab的基本用法之创建用户和组
  • 美国的社会分裂延续至AI领域
  • 02.07.链表相交 最简方法之一
  • 【解锁AI潜能:如何通过精确Prompt撰写引导智能对话】
  • JavaWeb 18.过滤器
  • 一文通透OpenAI o1:从CoT、Self-Correct/STaR、Self-play RL、MCST等技术细节到工程复现
  • 一些硬件知识【20241013】
  • 股市投资,如何应对人性挑战与把握关键策略?
  • 【JavaScript】Array的unshift的实现
  • 文件和目录的权限管理
  • Kaggle竞赛——森林覆盖类型分类
  • Android开发首页底部tab切换图标有动画效果
  • Linux——DNS服务器正向解析搭建教程
  • Java中的类加载器(ClassLoader)机制是怎样的?Java中的JNI(Java Native Interface)是什么?它是如何工作的?