当前位置: 首页 > news >正文

CSR和COO实现spgemm

1) COO实现spgemm 

#include <iostream>
#include <vector>
#include <unordered_map>// 定义稀疏矩阵的元素
struct Element {int row, col;double value;
};// 稀疏矩阵乘法函数
std::vector<Element> spgemm(const std::vector<Element>& A, const std::vector<Element>& B, int rowsA, int colsB) {//创建可自由缩放的二维矩阵数组std::unordered_map<int, std::unordered_map<int, double>> result;std::unordered_map<int, std::vector<Element>> B_col;for (const auto& elem : B) {B_col[elem.col].push_back(elem);// 将B矩阵转换为列主序, 为下面的判断做准备}// 进行乘法运算for (const auto& a : A) {//判断矩阵A的列在矩阵B中是否有对应的行相匹配if (B_col.find(a.col) != B_col.end()) {//a是向量A中的一个数据类型为element的元素, a.col代表元素element.col//B_col[a.col]: 确保矩阵B的行与矩阵A的列相匹配for (const auto& b : for (const auto& b : B_col[a.col]) {result[a.row][b.col] += a.value * b.value;}) {result[a.row][b.col] += a.value * b.value;}}}// 将结果转换为COO格式std::vector<Element> C;for (const auto& row : result) {for (const auto& col : row.second) {C.push_back({row.first, col.first, col.second});}}return C;
}int main() {// 示例稀疏矩阵A和Bstd::vector<Element> A = {{0, 0, 1.0}, {0, 1, 2.0}, {1, 0, 3.0}};std::vector<Element> B = {{0, 0, 4.0}, {1, 0, 5.0}, {1, 1, 6.0}};// 计算A * Bstd::vector<Element> C = spgemm(A, B, 2, 2);// 输出结果for (const auto& elem : C) {std::cout << "C(" << elem.row << ", " << elem.col << ") = " << elem.value << std::endl;}return 0;
}

2) CSR实现spgemm

#include <iostream>
#include <vector>
#include <unordered_map>
#include <stdexcept>
using namespace std;// CSR格式的稀疏矩阵
struct CSRMatrix {std::vector<int> row_ptr;  // 每行的起始索引std::vector<int> col_idx;  // 非零元素的列索引std::vector<double> values;  // 非零元素的值int rows, cols;  // 矩阵的行数和列数
};// SpGEMM算法实现
CSRMatrix spgemm(const CSRMatrix& A, const CSRMatrix& B) {// 检查矩阵整体维度是否匹配if (A.cols != B.rows) {throw std::invalid_argument("Matrix dimensions do not match for multiplication.");}CSRMatrix C;//初始化C矩阵的rows和colsC.rows = A.rows;C.cols = B.cols;//初始化c.row_ptr, 并将元素值都设为0C.row_ptr.resize(C.rows + 1, 0);// //定义元素个数为C.rows, 元素类型为字典的向量, 临时存储每行的非零元素// 实现了可以自动扩展大小的二级数组的效果std::vector<std::unordered_map<int, double>> temp(C.rows);// 遍历矩阵A的每一行for (int i = 0; i < A.rows; ++i) {// 遍历A的当前行的每一个非零元素//由于CSR存储格式的特征, row_ptr[i+1]-row_ptr[i]代表第i行元素的个数for (int j = A.row_ptr[i]; j < A.row_ptr[i + 1]; ++j) {int a_col = A.col_idx[j];  // A的列索引double a_val = A.values[j];  // A的值//B.row_ptr[a_col]: 找到与A的列对应的矩阵B的行for (int k = B.row_ptr[a_col]; k < B.row_ptr[a_col + 1]; ++k) {int b_col = B.col_idx[k];  // B的列索引, 与矩阵A中某一行相乘的矩阵B中的某一个列号double b_val = B.values[k];  // B的值temp[i][b_col] += a_val * b_val;  // 累加结果}}}// 将临时存储的结果转换为CSR格式for (int i = 0; i < C.rows; ++i) {for (const auto& pair : temp[i]) {C.col_idx.push_back(pair.first);  // 列索引C.values.push_back(pair.second);  // 值}C.row_ptr[i + 1] = C.col_idx.size();  // 更新行指针}return C;
}int main() {// 示例矩阵A的初始化CSRMatrix A = {{0, 2, 4},  // 行指针{0, 1, 0, 2},  // 列索引{1.0, 2.0, 3.0, 4.0},  // 值2, 3  // 行数和列数};// 示例矩阵B的初始化CSRMatrix B = {{0, 1, 3, 4},  // 行指针{0, 1, 2, 2},  // 列索引{5.0, 6.0, 7.0, 8.0},  // 值3, 3  // 行数和列数};// 计算矩阵C = A * BCSRMatrix C = spgemm(A, B);// 输出结果矩阵Cstd::cout << "C.row_ptr: ";for (int val : C.row_ptr) std::cout << val << " ";std::cout << "\nC.col_idx: ";for (int val : C.col_idx) std::cout << val << " ";std::cout << "\nC.values: ";for (double val : C.values) std::cout << val << " ";std::cout << std::endl;return 0;
}


http://www.mrgr.cn/news/14541.html

相关文章:

  • Java面试宝典-java基础02
  • docker网络+跨主机容器之间的通讯
  • 【Java】—— Java面向对象基础:Java中类的构造器与属性初始化,Student类的实例
  • Kafka 到数据仓库:使用 bend-ingest-kafka 将消息加载到 Databend
  • 【Hot100】LeetCode—79. 单词搜索
  • Linux查看jvm相关参数以及设置调优参数
  • MFC工控项目实例之八选择下拉菜单添加打钩图标
  • BaseCTF-Web-Week2-WP
  • linux修改文件的修改时间
  • 三天吃透Java面试八股文
  • 从0到1框架搭建,Python+Pytest+Allure+Git+Jenkins接口自动化框架(超细整理)
  • 不同搜索引擎蜘蛛的功能、‌抓取策略与技术实现差异探究
  • ArrayList 和 LinkedList 的区别?
  • Android 11.0 关于定制自适应AdaptiveIconDrawable类型的动态时钟图标的功能实现系列一
  • Redis中事务与乐观锁
  • 设计模式之建造者模式
  • 在VB.net中,LINQ有什么方法与属性
  • 代码随想录算法训练营第三十天|452. 用最少数量的箭引爆气球 435. 无重叠区间 763.划分字母区间
  • Prometheus监控Kubernetes ETCD
  • 这款SpringBoot+Vue酒店管理系统,你绝对值得拥有