当前位置: 首页 > news >正文

c++ libtorch tensor 注意浅拷贝

错误代码示例

torch::Tensor multi_dim_identity = torch::zeros({ 2, 2, 2, 2 }, torch::kComplexDouble);
for (int i = 0; i < 2; ++i) {multi_dim_identity.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(), i, i }, 1);
}torch::Tensor all_Kx = multi_dim_identity;
torch::Tensor all_Ky = multi_dim_identity;for (int i = 0; i < 2; ++i) {torch::Tensor a = torch::zeros({ 2, 2, 2 }, torch::kComplexDouble);torch::Tensor b = torch::rand({ 2, 2, 2 }, torch::kComplexDouble);for (int j = 0; j < dim_x * dim_y; ++j) {all_Kx.index_put_({ i, torch::indexing::Slice(), j, j }, a.index({torch::indexing::Slice(), j, j}));all_Ky.index_put_({ i, torch::indexing::Slice(), j, j }, b.index({torch::indexing::Slice(), j, j}));
}}

结果 all_Kx和all_Ky一样,在每个第1维度上都是一样的随机b,因为all_Kx和all_Ky都是multi_dim_identity的浅拷贝,all_Kx先赋值,其实是赋值给了multi_dim_identity,然后all_Ky再赋值,其实是赋值给了multi_dim_identity,导致all_Kx也跟着变,所以和all_Ky一样

正确代码如下

torch::Tensor multi_dim_identity = torch::zeros({ 2, 2, 2, 2 }, torch::kComplexDouble);
for (int i = 0; i < 2; ++i) {multi_dim_identity.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(), i, i }, 1);
}torch::Tensor all_Kx = multi_dim_identity.clone();
torch::Tensor all_Ky = multi_dim_identity.clone();for (int i = 0; i < 2; ++i) {torch::Tensor a = torch::zeros({ 2, 2, 2 }, torch::kComplexDouble);torch::Tensor b = torch::rand({ 2, 2, 2 }, torch::kComplexDouble);for (int j = 0; j < dim_x * dim_y; ++j) {all_Kx.index_put_({ i, torch::indexing::Slice(), j, j }, a.index({torch::indexing::Slice(), j, j}));all_Ky.index_put_({ i, torch::indexing::Slice(), j, j }, b.index({torch::indexing::Slice(), j, j}));
}}

用clone方法可以深拷贝,这样all_Kx和all_Ky就不一样


http://www.mrgr.cn/news/48911.html

相关文章:

  • C++入门基础知识109—【关于C++ if 语句】
  • OAuth和OpenID Connect原理及认证实现的案例
  • Spring Boot 3 文件管理:上传、下载、预览、查询与删除(全网最全面教程)
  • R语言绘制线性回归图
  • 手写mybatis之解析和使用ResultMap映射参数配置
  • 架构师之路-学渣到学霸历程-11
  • 鸿蒙跨设备协同开发02——RichEditor跨设备获取文件
  • 八大排序--08快速排序
  • 34. 在排序数组中查找元素的第一个和最后一个位置
  • 【网易云音乐】--源代码分享
  • 太阳能电池特性及其应用
  • 24/10/12 算法笔记 NiN
  • Windows环境NodeJS下载配置安装运行
  • 进程与线程
  • 第3关:寻找两个等长有序序列的中位数
  • Linux内核 -- 编译之 Kconfig 字段解析
  • 【功能安全】什么是Aspice?
  • 【C++ 真题】B2078 含 k 个 3 的数
  • Apache Kafka基础认知-Part1
  • Python网络爬虫快速入门指南