当前位置: 首页 > news >正文

【mmengine】优化器封装(OptimWrapper)(入门)优化器封装 vs 优化器

  • MMEngine 实现了优化器封装,为用户提供了统一的优化器访问接口。优化器封装支持不同的训练策略,包括混合精度训练、梯度累加和梯度截断。用户可以根据需求选择合适的训练策略。优化器封装还定义了一套标准的参数更新流程,用户可以基于这一套流程,实现同一套代码,不同训练策略的切换。
  • 分别基于 Pytorch 内置的优化器和 MMEngine 的优化器封装(OptimWrapper)进行单精度训练混合精度训练梯度累加,对比二者实现上的区别。

一、 基于 Pytorch 的 SGD 优化器实现单精度训练

import torch
from torch.optim import SGD
import torch.nn as nn
import torch.nn.functional as Finputs = [torch.zeros(10, 1, 1)] * 10
targets = [torch.ones(10, 1, 1)] * 10
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()for input, target in zip(inputs, targets):output = model(input)loss = F.l1_loss(output, target)loss.backward()optimizer.step()optimizer.zero_grad()

二、 使用 MMEngine 的优化器封装实现单精度训练

from mmengine.optim import OptimWrapperoptim_wrapper = OptimWrapper(optimizer=optimizer)for input, target in zip(inputs, targets):output = model(input)loss = F.l1_loss(output, target)optim_wrapper.update_params(loss)

优化器封装的 update_params 实现了标准的梯度计算、参数更新和梯度清零流程,可以直接用来更新模型参数。
在这里插入图片描述

三、 基于 Pytorch 的 SGD 优化器实现混合精度训练

在这里插入图片描述

  • 混合精度训练:单精度 float和半精度 float16 混合,其优势为:
    • 内存占用更少
    • 计算更快
from torch.cuda.amp import autocastmodel = model.cuda()
inputs = [torch.zeros(10, 1, 1, 1)] * 10
targets = [torch.ones(10, 1, 1, 1)] * 10for input, target in zip(inputs, targets):with autocast():output = model(input.cuda())loss = F.l1_loss(output, target.cuda())loss.backward()optimizer.step()optimizer.zero_grad()

四、 基于 MMEngine 的 优化器封装实现混合精度训练

from mmengine.optim import AmpOptimWrapperoptim_wrapper = AmpOptimWrapper(optimizer=optimizer)for input, target in zip(inputs, targets):with optim_wrapper.optim_context(model):output = model(input.cuda())loss = F.l1_loss(output, target.cuda())optim_wrapper.update_params(loss)

在这里插入图片描述

  • 混合精度训练需要使用 AmpOptimWrapper,他的 optim_context 接口类似 autocast,会开启混合精度训练的上下文。除此之外他还能加速分布式训练时的梯度累加,这个我们会在下一个示例中介绍

五、 基于 Pytorch 的 SGD 优化器实现混合精度训练和梯度累加

for idx, (input, target) in enumerate(zip(inputs, targets)):with autocast():output = model(input.cuda())loss = F.l1_loss(output, target.cuda())loss.backward()if idx % 2 == 0:optimizer.step()optimizer.zero_grad()

六、基于 MMEngine 的优化器封装实现混合精度训练和梯度累加

optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=2)for input, target in zip(inputs, targets):with optim_wrapper.optim_context(model):output = model(input.cuda())loss = F.l1_loss(output, target.cuda())optim_wrapper.update_params(loss)

在这里插入图片描述
只需要配置 accumulative_counts 参数,并调用 update_params 接口就能实现梯度累加的功能。除此之外,分布式训练情况下,如果我们配置梯度累加的同时开启了 optim_wrapper 上下文,可以避免梯度累加阶段不必要的梯度同步。

七、 获取学习率/动量

优化器封装提供了 get_lrget_momentum 接口用于获取优化器的一个参数组的学习率:

import torch.nn as nn
from torch.optim import SGDfrom mmengine.optim import OptimWrappermodel = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)
# 封装器
optim_wrapper = OptimWrapper(optimizer)print("get info from optimizer ------")
print(optimizer.param_groups[0]['lr'])  # 0.01
print(optimizer.param_groups[0]['momentum'])  # 0
print("get info from wrapper ------")
print(optim_wrapper.get_lr())  # {'lr': [0.01]}
print(optim_wrapper.get_momentum())  # {'momentum': [0]}

在这里插入图片描述

八、 导出/加载状态字典

优化器封装和优化器一样,提供了 state_dictload_state_dict 接口,用于导出/加载优化器状态,对于 AmpOptimWrapper,优化器封装还会额外导出混合精度训练相关的参数:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, AmpOptimWrappermodel = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)# ---- 导出 ---- #
print("print state_dict")
# 单精度封装器
optim_wrapper = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)# 导出状态字典
optim_state_dict = optim_wrapper.state_dict()
amp_optim_state_dict = amp_optim_wrapper.state_dict()
print(optim_state_dict)
print(amp_optim_state_dict)# ---- 加载 ---- #
print("load state_dict")
# 单精度封装器
optim_wrapper_new = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper_new = AmpOptimWrapper(optimizer=optimizer)# 加载状态字典
amp_optim_wrapper_new.load_state_dict(amp_optim_state_dict)
optim_wrapper_new.load_state_dict(optim_state_dict)

在这里插入图片描述

九、 使用多个优化器

OptimWrapperDict 的核心功能是支持批量导出/加载所有优化器封装的状态字典;支持获取多个优化器封装的学习率、动量。如果没有 OptimWrapperDict,MMEngine 就需要在很多位置对优化器封装的类型做 if else 判断,以获取所有优化器封装的状态。

from torch.optim import SGD
import torch.nn as nnfrom mmengine.optim import OptimWrapper, OptimWrapperDict# model1
gen = nn.Linear(1, 1)
# model2
disc = nn.Linear(1, 1)# optimizer1
optimizer_gen = SGD(gen.parameters(), lr=0.01)
# optimizer2
optimizer_disc = SGD(disc.parameters(), lr=0.01)# wrapper1
optim_wapper_gen = OptimWrapper(optimizer=optimizer_gen)
# wrapper2
optim_wapper_disc = OptimWrapper(optimizer=optimizer_disc)# wrapper_dict = wrapper1 + wrapper2
optim_dict = OptimWrapperDict(gen=optim_wapper_gen, disc=optim_wapper_disc)print("wrapper_dict = wrapper1 + wrapper2")
print(optim_dict.get_lr())  # {'gen.lr': [0.01], 'disc.lr': [0.01]}
print(optim_dict.get_momentum())  # {'gen.momentum': [0], 'disc.momentum': [0]}

在这里插入图片描述


http://www.mrgr.cn/news/41089.html

相关文章:

  • 试用模方软件时,在编辑模型视图下操作较卡,模型分辨率是3厘米,重建时设置平面划分的瓦片大小是450米,划分瓦片的时候大小设置多少比较合适?
  • 如何使用ssm实现基于JSP的高校听课评价系统
  • WPF下使用FreeRedis操作RedisStream实现简单的消息队列
  • 适用于 Windows 10 的最佳 PDF 编辑器列表,可帮助更改 PDF 文件。
  • ConcurrentLinkedQueue的核心方法有哪些?
  • 记一次炉石传说记牌器 Crash 排查经历
  • 【C++前缀和 状态压缩】1371. 每个元音包含偶数次的最长子字符串|2040
  • 解决银河麒麟服务器操作系统中`/etc/bashrc`环境变量不生效的问题
  • Python机器学习基础前置库学习:NumPy、Pandas、Matplotlib、Seaborn
  • Windows 11 的 24H2 更新將帶來全新 Copilot+ AI PC 功能
  • 从零到一构建解释器-【1-基础概念】
  • 【洛谷】P2357 守墓人 的题解
  • 编程参考 - 动态链接库中的变量实例化
  • 【C++】第二节:类与对象(上)
  • 如何使用ssm实现基于web的网站的设计与实现+vue
  • vulnhub-Replay 1靶机
  • SpringBoot实现的师生健康信息管理平台
  • 一本应用《软件方法》的书《软件需求分析和设计实践指南》
  • 单细胞scMetabolism代谢相关通路分析学习和整理
  • 提升工作效率的秘密武器大揭露