当前位置: 首页 > news >正文

深度学习--------------------长短期记忆网络(LSTM)

目录

  • 长短期记忆网络
    • 候选记忆单元
    • 记忆单元
    • 隐状态
  • 长短期记忆网络代码从零实现
    • 初始化模型参数
    • 初始化
    • 实际模型
    • 训练
  • 简洁实现

长短期记忆网络

忘记门:将值朝0减少
输入门:决定要不要忽略掉输入数据
输出门:决定要不要使用隐状态。



在这里插入图片描述

在这里插入图片描述




候选记忆单元

在这里插入图片描述




记忆单元

记忆单元会把上一个时刻的记忆单元作为状态放进来,所以LSTM和RNN跟GRU不一样的地方是它的状态里面有两个独立的。
如果: F t F_t Ft等于0的话,就是希望不要记住 C t − 1 C_{t-1} Ct1
如果: I t I_t It是1的话,就是希望尽量的去用它,如果 I t I_t It等于0的话,就是把现在的记忆单元丢掉。

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述




隐状态

在这里插入图片描述

在这里插入图片描述


在这里插入图片描述




长短期记忆网络代码从零实现

import torch
from torch import nn
from d2l import torch as d2l# 设置批量大小为32,时间步数为35
batch_size, num_steps = 32, 35
# 使用d2l库中的load_data_time_machine函数加载时间机器数据集,
# 并设置批量大小为32,时间步数为35,将加载的数据集赋值给train_iter和vocab变量
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)



初始化模型参数

def get_lstm_params(vocab_size, num_hiddens, device):# 将词汇表大小赋值给num_inputs和num_outputsnum_inputs = num_outputs = vocab_size# 定义一个辅助函数normal,用于生成具有特定形状的正态分布随机数,并将其初始化为较小的值def normal(shape):return torch.randn(size=shape, device=device) * 0.01# 定义一个辅助函数three,用于生成三个参数:输入到隐藏状态的权重矩阵、隐藏状态到隐藏状态的权重矩阵和隐藏状态的偏置项def three():return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))# 调用three函数获取输入到隐藏状态的权重矩阵W_xi、隐藏状态到隐藏状态的权重矩阵W_hi和隐藏状态的偏置项b_iW_xi, W_hi, b_i = three()# 调用three函数获取输入到隐藏状态的权重矩阵W_xf、隐藏状态到隐藏状态的权重矩阵W_hf和隐藏状态的偏置项b_fW_xf, W_hf, b_f = three()# 调用three函数获取输入到隐藏状态的权重矩阵W_xo、隐藏状态到隐藏状态的权重矩阵W_ho和隐藏状态的偏置项b_oW_xo, W_ho, b_o = three()# 调用three函数获取输入到隐藏状态的权重矩阵W_xc、隐藏状态到隐藏状态的权重矩阵W_hc和隐藏状态的偏置项b_cW_xc, W_hc, b_c = three()# 生成隐藏状态到输出的权重矩阵W_hqW_hq = normal((num_hiddens, num_outputs))# 生成输出的偏置项b_qb_q  = torch.zeros(num_outputs, device=device)# 将所有参数组合成列表paramsparams = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]    # 变量所有参数for param in params:# 将所有参数的requires_grad属性设置为True,表示需要计算梯度param.requires_grad_(True)# 返回所有参数return params



初始化

def init_lstm_state(batch_size, num_hiddens, device):# 返回一个元组,包含两个张量:一个全零张量表示初始的隐藏状态(即:H要有个初始化),和一个全零张量表示初始的记忆细胞状态(即:C要有个初始化)。return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))



实际模型

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params# 解包状态元组state,分别赋值给隐藏状态H和记忆细胞状态C(H, C) = state# 创建一个空列表用于存储每个时间步的输出outputs = []# 对于输入序列中的每个时间步for X in inputs:# 输入门的计算:使用输入、隐藏状态和偏置项,通过线性变换和sigmoid函数计算输入门I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)# 遗忘门的计算:使用输入、隐藏状态和偏置项,通过线性变换和sigmoid函数计算遗忘门F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)# 输出门的计算:使用输入、隐藏状态和偏置项,通过线性变换和sigmoid函数计算输出门O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)# 新的记忆细胞候选值的计算:使用输入、隐藏状态和偏置项,通过线性变换和tanh函数计算新的记忆细胞候选值C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)# 更新记忆细胞状态:将旧的记忆细胞状态与遗忘门和输入门的乘积相加,再与新的记忆细胞候选值的乘积相加,得到新的记忆细胞状态C = F * C + I * C_tilda# 更新隐藏状态:将输出门和经过tanh函数处理的记忆细胞状态的乘积作为新的隐藏状态H = O * torch.tanh(C)# 输出的计算:使用新的隐藏状态和偏置项,通过线性变换得到输出Y = (H @ W_hq) + b_q# 将当前时间步的输出添加到列表中outputs.append(Y)# 将所有时间步的输出在维度0上拼接起来,作为最终的输出结果;# 返回最终的输出结果和更新后的隐藏状态和记忆细胞状态的元组return torch.cat(outputs, dim=0), (H, C)



训练

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# 使用d2l库中的RNNModelScratch类创建一个基于LSTM的模型对象,
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

在这里插入图片描述

在这里插入图片描述




简洁实现

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
# 使用nn.LSTM创建一个LSTM层,输入特征数量为num_inputs,隐藏单元数量为num_hiddens
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
# 使用d2l库中的RNNModel类创建一个基于LSTM的模型对象,传入LSTM层和词汇表大小
model = d2l.RNNModel(lstm_layer, len(vocab))
mode = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

在这里插入图片描述

在这里插入图片描述


http://www.mrgr.cn/news/40589.html

相关文章:

  • C++11智能智能指针解析
  • 8G 显存玩转书生大模型 Demo
  • 原来还有【快速排序】 qsort() 函数
  • 迪杰斯特拉算法 Dijkstra‘s Algorithm 详解
  • 音频内容创作难吗?5分钟了解NotebookLM自动生成播客:让内容创作变得如此简单
  • kubeadm部署k8s集群,版本1.23.6;并设置calico网络BGP模式通信,版本v3.25--未完待续
  • 【数据结构与算法】时间复杂度和空间复杂度例题
  • 【C语言指南】数据类型详解(下)——自定义类型
  • 【JavaEE】——多线程常用类
  • 你的虚拟猫娘女友,快来领取!--文心智能体平台
  • 将onnx模型中的类别信息导出到文本
  • JAVA认识异常
  • 数值计算的程序设计问题举例
  • 51单片机的智能家居【proteus仿真+程序+报告+原理图+演示视频】
  • 排水系统C++
  • 第5篇:勒索病毒自救指南----应急响应篇
  • 构建现代化社区医疗服务:SpringBoot平台
  • 【JavaEE】http/https 超级详解
  • 使用Materialize制作unity的贴图,Materialize的简单教程,Materialize学习日志
  • Raspberry Pi3B+之Rpanion(gst)和ffmpeg验证