当前位置: 首页 > news >正文

【深度学习】矩阵操作万能函数 einsum-爱因斯坦求和

ref:https://blog.csdn.net/zhaohongfei_358/article/details/125273126
在学习transformer的时候,看到代码里面有

        values = self.values(values)  # (N, value_len, embed_size)keys = self.keys(keys)  # (N, key_len, embed_size)queries = self.queries(query)  # (N, query_len, embed_size)# Split the embedding into self.heads different piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = queries.reshape(N, query_len, self.heads, self.head_dim)# Einsum does matrix mult. for query*keys for each training example# with every other training example, don't be confused by einsum# it's just how I like doing matrix multiplication & bmmenergy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])# queries shape: (N, query_len, heads, heads_dim),# keys shape: (N, key_len, heads, heads_dim)# energy: (N, heads, query_len, key_len)

把我看蒙了,所以这次正经学习一下,看看咋回事。这个颇有一些只可意会不可言传的感觉,还是人菜瘾大,理解不深啊!

einsum 在numpy和torch中都有,借助了index–>(求和)

import torch
import torch.nn as nn
import torch.optim as optim
x = torch.rand((2, 3))
v = torch.rand((1, 3))
print(torch.einsum('ij,kj->ik', x, v).shape) # 矩阵乘法
print(torch.einsum('ij,kj->ki', x, v).shape) # 矩阵乘法 + T
print(torch.einsum('ij,km->ijkm', x, v).shape) # 这个算是一个拼接吧
x = torch.rand((2, 3))
v = torch.rand((1, 3))
print(torch.einsum('ij,kj->ik', x, v).shape)
print(torch.einsum('ij,kj->ki', x, v).shape)
print(torch.einsum('ij,km->ijkm', x, v).shape)
import torch
x = torch.tensor([[1, 2, 3],[4,5,6]])
y = torch.tensor([[7,8,9]])
x,y
(tensor([[1, 2, 3],[4, 5, 6]]),tensor([[7, 8, 9]]))
result = torch.einsum('ij,km->ijkm', x, y)
result
tensor([[[[ 7,  8,  9]],[[14, 16, 18]],[[21, 24, 27]]],[[[28, 32, 36]],[[35, 40, 45]],[[42, 48, 54]]]])
a = [[[1, 2],   # i=0[3, 4]],  # i=0[[5, 6],   # i=1[7, 8]]   #  i=1
]b = [[[9, 10, 11], #  i=0[12, 13, 14]], #  i=0[[15, 16, 17], # i=1[18, 19, 20]]  # i=1
]
torch.tensor(a[0]).shape,torch.tensor(b[0]).shape

torch.tensor(a[0]).shape,torch.tensor(b[0]).shape

torch.tensor(a[0]) @ torch.tensor(b[0])
torch.tensor(a[0]) @ torch.tensor(b[0])
torch.tensor(a[1]) @ torch.tensor(b[1])
tensor([[183, 194, 205],[249, 264, 279]])
res = []
for i in range(len(a)):a1 = torch.tensor(a[i])b1 = torch.tensor(b[i])res.append(a1@b1)
res1 = torch.stack(res)
print(res,"\n",res1)
res = []
for i in range(len(a)):a1 = torch.tensor(a[i])b1 = torch.tensor(b[i])res.append(a1@b1)
res1 = torch.stack(res)
print(res,"\n",res1)
x = torch.rand(3, 3)
torch.einsum('ii->i', x),x
(tensor([0.7127, 0.3843, 0.2046]),tensor([[0.7127, 0.0171, 0.9940],[0.6781, 0.3843, 0.9031],[0.4963, 0.1581, 0.2046]]))

http://www.mrgr.cn/news/43556.html

相关文章:

  • 如何使用CMD命令启动应用程序(二)
  • C0015.Clion中开发C++时,连接Mysql数据库方法
  • 【英特尔IA-32架构软件开发者开发手册第3卷:系统编程指南】2001年版翻译,1-1
  • 《python语言程序设计》2018版第8章19题几何Rectangle2D类(下)-头疼的几何和数学
  • 传感器模块编程实践(三)舵机+超声波模块融合DIY智能垃圾桶模型
  • 常见的基础系统
  • 今天学的Word小技巧——批量设置图片格式,批量让题注居中
  • 软考系统分析师知识点二:经济管理
  • 每日一道算法题——二分查找
  • clickhouse数据字典
  • 使用SpringBoot自定义注解+拦截器+token机制,实现接口的幂等性
  • 【go入门】流程控制语句
  • 51c视觉~CV~合集3
  • 基于Java的GeoTools对Shapefile文件属性信息深度解析
  • C语言进阶版第16课—自定义类型:结构体
  • webserver
  • 网络基础知识笔记(五)接口管理
  • 已解决-Nacos明明成功运行,但Spring报错连接不上
  • 《Linux从小白到高手》理论篇:一文概览常用Linux重要配置文件
  • 免费论文生成网站有哪些?推荐5款AI自动生成论文的网站