
Using Gradient Checkpointing

Table of Contents

  • 1. PyTorch
    • Example 1
    • Example 2
    • Example 3
  • 2. Hugging Face
    • Example
  • 3. DeepSpeed

1. PyTorch

  • Official torch.utils.checkpoint documentation
  • The Gradient Checkpoints section of the PyTorch Training Performance Guide
  • Reference blog post

Example 1

See the blog post.
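Since the example itself lives in the linked blog post, here is a minimal illustrative sketch of the basic torch.utils.checkpoint.checkpoint call (a toy model of my own, not the post's code; the use_reentrant argument assumes a recent PyTorch version):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Toy two-stage network: stage1 is checkpointed, so its intermediate
# activations are freed after the forward pass and recomputed during backward.
stage1 = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU())
stage2 = nn.Linear(256, 10)

x = torch.randn(32, 128, requires_grad=True)
h = checkpoint(stage1, x, use_reentrant=False)  # stage1's activations are not stored
out = stage2(h)
out.sum().backward()  # stage1's forward runs again here to rebuild its activations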

Example 2

import torchvision
from torch.utils.data import DataLoader
import torch
import torch.utils.checkpoint  # needed for torch.utils.checkpoint.checkpoint below
from torch import nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os

transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = torchvision.datasets.CIFAR10("/home/zjma/dataset/cifar10/", train=True, transform=transform_train, download=False)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
test_dataset = torchvision.datasets.CIFAR10("/home/zjma/dataset/cifar10/", train=False, transform=transform_test, download=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

class CIFAR10Model_Original(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_block_1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25),
        )
        self.cnn_block_2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25),
        )
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.head = nn.Sequential(
            nn.Linear(64 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def forward(self, X):
        X = self.cnn_block_1(X)
        X = self.cnn_block_2(X)
        X = self.flatten(X)
        X = self.head(X)
        return X

class CIFAR10Model_Optimized(nn.Module):
    def __init__(self):
        super().__init__()
        # Dropout is kept outside the checkpointed blocks: checkpointing reruns
        # the forward pass during backward, and rerunning dropout would resample
        # its mask and corrupt the gradients.
        self.cnn_block_1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.dropout_1 = nn.Dropout(0.25)
        self.cnn_block_2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.dropout_2 = nn.Dropout(0.25)
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.linearize = nn.Sequential(
            nn.Linear(64 * 8 * 8, 512),
            nn.ReLU(),
        )
        self.dropout_3 = nn.Dropout(0.5)
        self.out = nn.Linear(512, 10)

    def forward(self, X):
        # The two conv blocks are checkpointed: their activations are discarded
        # after the forward pass and recomputed during backward.
        X = torch.utils.checkpoint.checkpoint(self.cnn_block_1, X)
        X = self.dropout_1(X)
        X = torch.utils.checkpoint.checkpoint(self.cnn_block_2, X)
        X = self.dropout_2(X)
        X = self.flatten(X)
        X = self.linearize(X)
        X = self.dropout_3(X)
        X = self.out(X)
        return X

# clf = CIFAR10Model_Original()
clf = CIFAR10Model_Optimized()
start_epoch = 1
clf.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(clf.parameters(), lr=0.0001, weight_decay=1e-6)

def train():
    clf.train()
    NUM_EPOCHS = 10
    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        losses = []
        for i, (X_batch, y_cls) in enumerate(train_dataloader):
            optimizer.zero_grad()
            y = y_cls.cuda()
            X_batch = X_batch.cuda()
            y_pred = clf(X_batch)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss = loss.item()
            losses.append(train_loss)
            # Memory statistics after each batch
            mem_allocated = torch.cuda.memory_allocated()
            mem_reserved = torch.cuda.memory_reserved()
            max_mem_allocated = torch.cuda.max_memory_allocated()
            max_mem_reserved = torch.cuda.max_memory_reserved()
            if i % 10 == 0:
                print(f'Finished epoch {epoch}/{NUM_EPOCHS}, batch {i}. loss: {train_loss:.3f}. '
                      f'Memory allocated: {mem_allocated / (1024 ** 2):.2f} MB, '
                      f'Memory reserved: {mem_reserved / (1024 ** 2):.2f} MB, '
                      f'Max memory allocated: {max_mem_allocated / (1024 ** 2):.2f} MB, '
                      f'Max memory reserved: {max_mem_reserved / (1024 ** 2):.2f} MB.')
            # Reset peak memory stats for the next iteration
            torch.cuda.reset_peak_memory_stats()
        print(f'Finished epoch {epoch}. '
              f'avg loss: {np.mean(losses)}; median loss: {np.median(losses)}')

train()
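A version note (my addition, not from the original post): recent PyTorch releases warn when checkpoint() is called without an explicit use_reentrant argument, and the non-reentrant variant is the recommended one. On such versions the checkpoint calls above would become:

# On recent PyTorch versions, pass use_reentrant explicitly:
X = torch.utils.checkpoint.checkpoint(self.cnn_block_1, X, use_reentrant=False)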

Before checkpoint optimization:

  • Max memory allocated: 69.58 MB
  • Max memory reserved: 96.00 MB

After checkpoint optimization:

  • Max memory allocated: 40.80 MB
  • Max memory reserved: 64.00 MB
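For a model that is literally an nn.Sequential, PyTorch also offers torch.utils.checkpoint.checkpoint_sequential, which splits the module into segments and checkpoints each segment. A minimal illustrative sketch (my own toy model, not the post's code; use_reentrant again assumes a recent PyTorch version):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(
    nn.Linear(128, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 10),
)

x = torch.randn(32, 128, requires_grad=True)
# Split the Sequential into 2 segments; only segment boundaries keep their
# activations, everything inside a segment is recomputed during backward.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()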

Example 3

See the project.

2. Hugging Face

  • Official gradient_checkpointing_enable documentation and usage
  • The Gradient Checkpointing section of "Methods and tools for efficient training on a single GPU"
  • The Gradient Checkpointing section of "Performance and Scalability: How To Fit a Bigger Model and Train It Faster"
  • Reference blog post 1
  • Reference blog post 2

Example

See the project.
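Since the example lives in the linked project, here is a minimal illustrative sketch of the two standard ways to turn gradient checkpointing on in Hugging Face Transformers (the model checkpoint name is a placeholder):

from transformers import AutoModelForSequenceClassification, TrainingArguments

# Placeholder checkpoint, for illustration only.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Option 1: enable it on the model directly.
model.gradient_checkpointing_enable()

# Option 2: let the Trainer enable it via TrainingArguments.
args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,  # trade extra compute for lower activation memory
)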

3. DeepSpeed

  • Reference blog post and code: [ deepSpeed ] running locally on a single machine with a single GPU & running in Docker
  • https://eanyang7.github.io/transformers_docs/main_classes/deepspeed/#activation-checkpointing-gradient-checkpointing
  • https://zhuanlan.zhihu.com/p/644656141
  • https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one#deepspeed-zero
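The linked pages describe two routes: keep using the Hugging Face gradient_checkpointing switch from Section 2, or configure DeepSpeed's own activation checkpointing through the activation_checkpointing block of the DeepSpeed config. A minimal illustrative sketch of that block as a Python dict (values are examples, not recommendations; note this block only affects code that calls deepspeed.checkpointing.checkpoint, so for a stock Transformers model the Section 2 switch is the simpler route):

# Sketch of a DeepSpeed config fragment enabling activation checkpointing.
ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,    # shard checkpointed activations across GPUs
        "cpu_checkpointing": True,        # offload checkpointed activations to CPU
        "contiguous_memory_optimization": False,
        "synchronize_checkpoint_boundary": False,
    },
}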

