
PyTorch/Pyro: higher-order optimizers use higher-order derivatives, such as second derivatives (the Hessian matrix)

In machine learning and deep learning, an optimizer is the algorithm that updates model parameters to minimize a loss function. Typically, an optimizer computes the first derivatives of the loss with respect to the parameters (the gradient) and uses those gradients to update the parameters. Higher-order optimizers, by contrast, also use higher-order derivatives, such as second derivatives (the Hessian matrix), to guide the update; Newton's method is the classic example, updating the parameters as θ ← θ − H⁻¹g, where g is the gradient and H is the Hessian of the loss.

 

What follows is a walkthrough of Pyro's base class for optimizers that make use of higher-order derivatives, MultiOptimizer.

The key points of the class's docstring are:

  1. Uses torch.autograd.grad rather than torch.Tensor.backward: torch.autograd.grad is a PyTorch function that computes the derivatives of a tensor with respect to other tensors and returns them directly. This differs from torch.Tensor.backward, which is part of the automatic differentiation machinery and is normally used to accumulate gradients into each tensor's .grad attribute. Returning the gradients as tensors (and passing create_graph=True) is what makes it possible to differentiate them again; a short sketch follows this list.

  2. A different interface: because higher-order optimizers need to compute higher-order derivatives, they require a different interface. Here the step method takes a loss tensor as input, and backpropagation is triggered one or more times inside the optimizer.

  3. Derived classes must implement step: any optimizer class derived from this base class must provide its own implementation of step, computing the derivatives and updating the parameters in place. (In the source below, the base class implements step in terms of get_step, so a subclass may override either one.)

  4. Example code: the docstring's example shows how such an optimizer is used. First, obtain a trace of the model via poutine.trace; then compute the negated sum of log probabilities as the loss; finally, extract the parameters from the trace and call the optimizer's step method to update them.
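
To make point 1 concrete, here is a minimal, self-contained sketch of computing a second derivative with torch.autograd.grad; the toy function x ** 3 is made up for illustration:

import torch

# Differentiate a scalar loss twice. torch.Tensor.backward() would accumulate
# the gradient into x.grad and, by default, free the graph; torch.autograd.grad
# returns the gradient as a tensor, and create_graph=True builds a graph
# through it so that it can be differentiated again.
x = torch.tensor(2.0, requires_grad=True)
loss = x ** 3  # dloss/dx = 3x^2, d2loss/dx2 = 6x
(g,) = torch.autograd.grad(loss, x, create_graph=True)
(h,) = torch.autograd.grad(g, x)  # second derivative
print(g.item(), h.item())  # 12.0 12.0 at x = 2.0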

In short, this module defines a base class for advanced optimization that lets developers implement custom optimizers driven by higher-order derivatives. Such optimizers can be more effective than conventional first-order optimizers in some situations, particularly where parameter updates need finer control. A minimal sketch of a derived optimizer follows, and after it the full source of the module (pyro.optim.multi).
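
As a hypothetical illustration of the interface (this class is not part of Pyro), a derived optimizer that needs only first derivatives could implement get_step as plain gradient descent; the inherited step method then copies the detached updates back into the parameters:

import torch

from pyro.optim.multi import MultiOptimizer


class NaiveGradientDescent(MultiOptimizer):
    """Hypothetical subclass: implements get_step with an ordinary
    gradient-descent update (the learning rate lr is made up here)."""

    def __init__(self, lr: float = 0.1):
        self.lr = lr

    def get_step(self, loss: torch.Tensor, params: dict) -> dict:
        names = list(params.keys())
        values = [params[name] for name in names]
        # create_graph=True keeps the returned updates differentiable,
        # matching the contract documented on MultiOptimizer.get_step.
        grads = torch.autograd.grad(loss, values, create_graph=True)
        return {name: value - self.lr * grad
                for name, value, grad in zip(names, values, grads)}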

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List

import torch

from pyro.ops.newton import newton_step
from pyro.optim.optim import PyroOptim


class MultiOptimizer:
    """
    Base class of optimizers that make use of higher-order derivatives.

    Higher-order optimizers generally use :func:`torch.autograd.grad` rather
    than :meth:`torch.Tensor.backward`, and therefore require a different
    interface from usual Pyro and PyTorch optimizers. In this interface,
    the :meth:`step` method inputs a ``loss`` tensor to be differentiated,
    and backpropagation is triggered one or more times inside the optimizer.

    Derived classes must implement :meth:`step` to compute derivatives and
    update parameters in-place.

    Example::

        tr = poutine.trace(model).get_trace(*args, **kwargs)
        loss = -tr.log_prob_sum()
        params = {name: site['value'].unconstrained()
                  for name, site in tr.nodes.items()
                  if site['type'] == 'param'}
        optim.step(loss, params)
    """

    def step(self, loss: torch.Tensor, params: Dict) -> None:
        """
        Performs an in-place optimization step on parameters given a
        differentiable ``loss`` tensor.

        Note that this detaches the updated tensors.

        :param torch.Tensor loss: A differentiable tensor to be minimized.
            Some optimizers require this to be differentiable multiple times.
        :param dict params: A dictionary mapping param name to unconstrained
            value as stored in the param store.
        """
        updated_values = self.get_step(loss, params)
        for name, value in params.items():
            with torch.no_grad():
                # we need to detach because updated_value may depend on value
                value.copy_(updated_values[name].detach())

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        """
        Computes an optimization step of parameters given a differentiable
        ``loss`` tensor, returning the updated values.

        Note that this preserves derivatives on the updated tensors.

        :param torch.Tensor loss: A differentiable tensor to be minimized.
            Some optimizers require this to be differentiable multiple times.
        :param dict params: A dictionary mapping param name to unconstrained
            value as stored in the param store.
        :return: A dictionary mapping param name to updated unconstrained
            value.
        :rtype: dict
        """
        raise NotImplementedError


class PyroMultiOptimizer(MultiOptimizer):
    """
    Facade to wrap :class:`~pyro.optim.optim.PyroOptim` objects
    in a :class:`MultiOptimizer` interface.
    """

    def __init__(self, optim: PyroOptim) -> None:
        if not isinstance(optim, PyroOptim):
            raise TypeError("Expected a PyroOptim object but got a {}".format(type(optim)))
        self.optim = optim

    def step(self, loss: torch.Tensor, params: Dict) -> None:
        values = params.values()
        grads = torch.autograd.grad(loss, values, create_graph=True)  # type: ignore
        for x, g in zip(values, grads):
            x.grad = g
        self.optim(values)


class TorchMultiOptimizer(PyroMultiOptimizer):
    """
    Facade to wrap :class:`~torch.optim.Optimizer` objects
    in a :class:`MultiOptimizer` interface.
    """

    def __init__(self, optim_constructor: torch.optim.Optimizer, optim_args: Dict):
        optim = PyroOptim(optim_constructor, optim_args)
        super().__init__(optim)


class MixedMultiOptimizer(MultiOptimizer):
    """
    Container class to combine different :class:`MultiOptimizer` instances for
    different parameters.

    :param list parts: A list of ``(names, optim)`` pairs, where each
        ``names`` is a list of parameter names, and each ``optim`` is a
        :class:`MultiOptimizer` or :class:`~pyro.optim.optim.PyroOptim` object
        to be used for the named parameters. Together the ``names`` should
        partition up all desired parameters to optimize.
    :raises ValueError: if any name is optimized by multiple optimizers.
    """

    def __init__(self, parts: List) -> None:
        optim_dict: Dict = {}
        self.parts = []
        for names_part, optim in parts:
            if isinstance(optim, PyroOptim):
                optim = PyroMultiOptimizer(optim)
            for name in names_part:
                if name in optim_dict:
                    raise ValueError(
                        "Attempted to optimize parameter '{}' by two different optimizers: "
                        "{} vs {}".format(name, optim_dict[name], optim))
                optim_dict[name] = optim
            self.parts.append((names_part, optim))

    def step(self, loss: torch.Tensor, params: Dict):
        for names_part, optim in self.parts:
            optim.step(loss, {name: params[name] for name in names_part})

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        updated_values = {}
        for names_part, optim in self.parts:
            updated_values.update(
                optim.get_step(loss, {name: params[name] for name in names_part}))
        return updated_values


class Newton(MultiOptimizer):
    """
    Implementation of :class:`MultiOptimizer` that performs a Newton update
    on batched low-dimensional variables, optionally regularizing via a
    per-parameter ``trust_radius``. See :func:`~pyro.ops.newton.newton_step`
    for details.

    The result of :meth:`get_step` will be differentiable, however the
    updated values from :meth:`step` will be detached.

    :param dict trust_radii: a dict mapping parameter name to radius of trust
        region. Missing names will use unregularized Newton update, equivalent
        to infinite trust radius.
    """

    def __init__(self, trust_radii: Dict = {}):
        self.trust_radii = trust_radii

    def get_step(self, loss: torch.Tensor, params: Dict):
        updated_values = {}
        for name, value in params.items():
            trust_radius = self.trust_radii.get(name)  # type: ignore
            updated_value, cov = newton_step(loss, value, trust_radius)
            updated_values[name] = updated_value
        return updated_values
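
Finally, a hedged usage sketch of the Newton optimizer defined above, on a made-up toy problem. Per its docstring, Newton operates on batched low-dimensional variables, so the parameter here is a batch holding one 2-dimensional value:

import torch

from pyro.optim.multi import Newton

# Toy problem: minimize a quadratic over a batch containing one 2-dimensional
# variable. All values here (target, trust radius, step count) are made up.
target = torch.tensor([[1.0, -2.0]])
x = torch.zeros(1, 2, requires_grad=True)

optim = Newton(trust_radii={"x": 1.0})  # clip each update to a radius-1 trust region
for _ in range(5):
    loss = ((x - target) ** 2).sum()  # twice-differentiable, as newton_step requires
    optim.step(loss, {"x": x})  # Hessian-based update, applied in place
print(x)  # approaches target after a few clipped Newton steps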

