当前位置: 首页 > news >正文

DeepSpeed笔记--利用Accelerate实现DeepSpeed加速

1--参考文档

Accelerate官方文档

accelerate+deepspeed多机多卡训练-适用集群环境

DeepSpeed & Accelerate

2--安装过程

# 安装accelerate
pip install accelerate

pip install importlib-metadata

# 获取默认配置文件
python -c "from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='fp16')"

# 默认保存地址
# /home/liujinfu/.cache/huggingface/accelerate/default_config.yaml 

# 查看配好的环境
accelerate env

# 查看环境是否配好
accelerate test

3--测试代码

# 加载库
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoaderfrom accelerate import Accelerator, DeepSpeedPlugin# 定义测试网络
class TestNet(nn.Module):def __init__(self, input_dim: int, output_dim: int):super(TestNet, self).__init__()self.fc1 = nn.Linear(in_features = input_dim, out_features = output_dim)self.fc2 = nn.Linear(in_features = output_dim, out_features = output_dim)def forward(self, x: torch.Tensor):x = torch.relu(self.fc1(x))x = torch.fc2(x)return xif __name__ == "__main__":input_dim = 8output_dim = 64batch_size = 8dataset_size = 1000# 随机生成数据input_data = torch.randn(dataset_size, input_dim)labels = torch.randn(dataset_size, output_dim)# 创建数据集dataset = TensorDataset(input_data, labels)dataloader = DataLoader(dataset = dataset, batch_size = batch_size)# 初始化模型model = TestNet(input_dim = input_dim, output_dim = output_dim)# 创建Deepspeed配置deepspeed = DeepSpeedPlugin(zero_stage = 2, gradient_clipping = 1.0) # 使用zero-2accelerator = Accelerator(deepspeed_plugin = deepspeed)# 创建训练配置optimizator = torch.optim.Adam(model.parameters(), lr = 0.001)loss_func = nn.MSELoss()# 初始化model, optimizator, dataloader = accelerator.prepare(model, optimizator, dataloader)# 训练模型for epoch in range(10):model.train()for batch in dataloader:inputs, labels = batch# 清理梯度optimizator.zero_grad()outputs = model(inputs)loss = loss_func(outputs, labels)accelerator.backward(loss) # 核心改动optimizator.step()print(f"Epoch {epoch}, Loss: {loss.item()}")# 保存模型accelerator.wait_for_everyone()accelerator.save(model.state_dict(), "test_model.pth")

4--代码运行

未完待续!


http://www.mrgr.cn/news/38626.html

相关文章:

  • 零基础教你如何开发webman应用插件
  • 【数据挖掘】2023年 Quiz 1-3 整理 带答案
  • (十七)、Mac 安装k8s
  • Redis缓存双写一致性笔记(上)
  • 视频格式转换:avi格式转mp4格式
  • 盘点4款专业高效的数据恢复工具。
  • 基于SpringBoot+Vue+MySQL的甜品店管理系统
  • 衡石分析平台系统管理手册-功能配置之资源管理
  • MyBatis操作数据库(入门)
  • Elasticsearch学习笔记(2)
  • 【python】代码发布前检查- vulture:查找死代码
  • [CKA]CKA预约和考试
  • QT+ESP8266+STM32项目构建三部曲二--阿里云云端处理之云产品流转
  • Vue3动态导入后端路由
  • [SAP ABAP] SELECT-OPTIONS
  • 深入理解Shapefile格式:点、线、面要素的字节结构解析
  • 带动感的海报艺术!用前端技术点燃你的灵感
  • Cluster Explanation via Polyhedral Descriptions
  • 计算机毕业设计 助农产品采购平台的设计与实现 Java实战项目 附源码+文档+视频讲解
  • 【刷点笔试面试题试试水】不使用任何中间变量如何将a、b的值进行交换?