当前位置: 首页 > news >正文

大模型是如何把向量解码成文字输出的

hidden state 向量

当我们把一句话输入模型后,例如 “Hello world”:

token IDs: [15496, 995]

经过 Embedding + Transformer 层后,会得到每个 token 的中间表示,形状为:

hidden_states: (batch_size, seq_len, hidden_dim) 比如:  (1, 2, 768)

这是 Transformer 层的输出,即每个 token 的向量表示。

hidden state → logits:映射到词表空间

🔹 使用 输出投影矩阵(通常是 embedding 的转置)
为了从 hidden state 还原出词,我们需要得到它在词表上每个 token 的“分数”,这叫 logits。实现方式如下:

logits = hidden_state @ W_out.T + b

其中:

  • W_out 是词嵌入矩阵(Embedding matrix),形状为 (vocab_size, hidden_dim)
  • @ 是矩阵乘法,hidden_state 形状是 (seq_len, hidden_dim)
  • 得到的 logits 形状是 (seq_len, vocab_size)

所以,每个位置的 hidden state 都被映射成一个 词表大小的分布。

logits → token ID:选出最可能的 token

现在每个位置我们都有了一个 logits 向量,例如:

logits = [2.1, -0.5, 0.3, 6.9, ...]  # 长度 = vocab_size

有几种选择方式:

方法说明
argmax(logits)选最大值,对应 greedy decoding
softmax → sample转成概率分布后随机采样
top-k sampling从 top-k 个中采样,控制多样性
top-p (nucleus)从累计概率在 p 范围内采样

例如:

probs = softmax(logits)
token_id = torch.argmax(probs).item()  

token ID → token 字符串片段

token ID 其实对应的是某个词表里的编号,比如:

tokenizer.convert_ids_to_tokens(50256)  # 输出: <|endoftext|>
tokenizer.convert_ids_to_tokens(15496)  # 输出: "Hello"

如果是多个 token ID,可以:

tokenizer.convert_ids_to_tokens([15496, 995])  # 输出: ["Hello", " world"]

tokens → 拼接成文本(decode)

tokens 是“子词”或“子字符”,例如:

["Hel", "lo", " world", "!"]

通过 tokenizer.decode() 会自动合并它们为字符串:

tokenizer.decode([15496, 995])  # 输出: "Hello world"

它会处理空格、子词连接等细节,恢复为人类可读的句子。

多轮生成:把预测作为输入继续生成

在生成任务(如 GPT)中,模型是逐 token 生成的。
流程如下:

输入: "你好"
↓
tokenize → [token IDs]
↓
送入模型 → 得到下一个 token 的 logits
↓
选出 token ID → decode 成文字
↓
拼接到输入后,继续送入模型 → 下一轮生成
↓
...
直到生成 EOS(终止符)或达到最大长度

总结流程图:

(1) 输入文本 → tokenizer → token IDs
(2) token IDs → Embedding → hidden_states(中间层向量)
(3) hidden_states × W.T → logits(词表得分)
(4) logits → sampling → token ID
(5) token ID → token → decode → 文本
(6) 拼接文本 → 重复生成(自回归)

示例代码

"""
大语言模型解码过程详解
===========================
本示例展示了大语言模型如何将隐藏状态向量解码成文本输出
使用GPT-2模型作为演示,展示从输入文本到预测下一个token的完整流程
"""import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer# 设置随机种子,确保结果可复现
torch.manual_seed(42)def display_token_probabilities(probabilities, tokens, top_k=5):"""可视化展示token的概率分布(仅展示top_k个)"""# 获取前k个最大概率及其索引top_probs, top_indices = torch.topk(probabilities, top_k)top_probs = top_probs.detach().numpy()top_tokens = [tokens[idx] for idx in top_indices]print(f"\n前{top_k}个最可能的下一个token:")for token, prob in zip(top_tokens, top_probs):print(f"  {token:15s}: {prob:.6f} ({prob * 100:.2f}%)")# 可视化概率分布plt.figure(figsize=(10, 6))plt.bar(top_tokens, top_probs)plt.title(f"Top {top_k} The probability distribution of the next token")plt.ylabel("probability")plt.xlabel("Token")plt.xticks(rotation=45)plt.tight_layout()plt.show()def main():print("Step 1: 加载预训练模型和分词器")# 从Hugging Face加载预训练的GPT-2模型和分词器tokenizer = GPT2Tokenizer.from_pretrained("gpt2")model = GPT2LMHeadModel.from_pretrained("gpt2")model.eval()  # 将模型设置为评估模式print("\nStep 2: 准备输入文本")input_text = "Artificial intelligence is"print(f"输入文本: '{input_text}'")# 将输入文本转换为模型需要的格式inputs = tokenizer(input_text, return_tensors="pt")input_ids = inputs["input_ids"]attention_mask = inputs["attention_mask"]# 展示分词结果tokens = tokenizer.convert_ids_to_tokens(input_ids[0])print(f"分词结果: {tokens}")print(f"Token IDs: {input_ids[0].tolist()}")print("\nStep 3: 运行模型前向传播")# 使用torch.no_grad()避免计算梯度,节省内存with torch.no_grad():# output_hidden_states=True 让模型返回所有层的隐藏状态outputs = model(input_ids=input_ids,attention_mask=attention_mask,output_hidden_states=True)# 获取最后一层的隐藏状态# hidden_states的形状: [层数, batch_size, seq_len, hidden_dim]last_layer_hidden_states = outputs.hidden_states[-1]print(f"隐藏状态形状: {last_layer_hidden_states.shape}")# 获取序列中最后一个token的隐藏状态last_token_hidden_state = last_layer_hidden_states[0, -1, :]print(f"最后一个token的隐藏状态形状: {last_token_hidden_state.shape}")print(f"隐藏状态前5个值: {last_token_hidden_state[:5].tolist()}")print("\nStep 4: 手动计算logits")# 从模型中获取输出嵌入矩阵的权重lm_head_weights = model.get_output_embeddings().weight  # [vocab_size, hidden_dim]print(f"语言模型输出嵌入矩阵形状: {lm_head_weights.shape}")# 通过点积计算logits# logits代表每个词汇表中token的分数logits = torch.matmul(last_token_hidden_state, lm_head_weights.T)  # [vocab_size]print(f"Logits形状: {logits.shape}")print(f"Logits值域: [{logits.min().item():.4f}, {logits.max().item():.4f}]")print("\nStep 5: 应用softmax转换为概率")# 使用softmax将logits转换为概率分布probabilities = torch.softmax(logits, dim=0)print(f"概率总和: {probabilities.sum().item():.4f}")  # 应该接近1# 找出概率最高的tokennext_token_id = torch.argmax(probabilities).item()next_token = tokenizer.decode([next_token_id])print(f"预测的下一个token (ID: {next_token_id}): '{next_token}'")# 展示完整的句子complete_text = input_text + next_tokenprint(f"生成的文本: '{complete_text}'")# 展示top-k的概率分布display_token_probabilities(probabilities, tokenizer.convert_ids_to_tokens(range(len(probabilities))), top_k=10)print("\nStep 6: 比较与模型内置解码结果")# 获取模型内置的logits输出model_outputs = model(input_ids=input_ids, attention_mask=attention_mask)model_logits = model_outputs.logitsprint(f"模型输出的logits形状: {model_logits.shape}")# 获取最后一个token位置的logitslast_token_model_logits = model_logits[0, -1, :]# 验证我们手动计算的logits与模型输出的logits是否一致is_close = torch.allclose(logits, last_token_model_logits, rtol=1e-4)print(f"手动计算的logits与模型输出的logits是否一致: {is_close}")# 如果不一致,计算差异if not is_close:diff = torch.abs(logits - last_token_model_logits)print(f"最大差异: {diff.max().item():.8f}")print(f"平均差异: {diff.mean().item():.8f}")print("\nStep 7: 使用模型进行文本生成")# 使用模型的generate方法生成更多文本# 生成时传递 attention_mask 和 pad_token_idgenerated_ids = model.generate(input_ids,max_length=input_ids.shape[1] + 10,  # 生成10个额外的tokentemperature=1.0,do_sample=True,top_k=50,top_p=0.95,num_return_sequences=1,attention_mask=attention_mask,  # 添加 attention_maskpad_token_id=tokenizer.eos_token_id  # 明确设置 pad_token_id 为 eos_token_id)generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)print(f"模型生成的文本:\n'{generated_text}'")if __name__ == "__main__":main()

http://www.mrgr.cn/news/97678.html

相关文章:

  • 畅游Diffusion数字人(21):基于Wan2.1的音频驱动数字人FantasyTalking
  • 急速实现Anaconda/Miniforge虚拟环境的克隆和迁移
  • mysql镜像创建docker容器,及其可能遇到的问题
  • 数据结构和算法(十二)--最小生成树
  • 【组件封装-优化】vue+element plus:二次封装select组件,实现下拉列表有分页、自定义是否可搜索的一系列功能
  • 【杂谈】Godot4.4导出到Android平台(正式导出)
  • 最新版PhpStorm超详细图文安装教程,带补丁包(2025最新版保姆级教程)
  • c语言 文件操作
  • SQL语法进阶篇(二),数据库复杂查询——窗口函数
  • 【蓝桥杯2024省B】好数 三种解法全解析 | C/C++暴力法→剪枝优化→构造法演进
  • GZ036区块链卷一 EtherStore合约漏洞详解
  • React 列表渲染
  • Java 大视界 -- 基于 Java 的大数据分布式缓存技术在电商高并发场景下的性能优化(181)
  • 算法精讲【整数二分】(实战教学)
  • Kotlin学习
  • debian12安装mysql5.7.42(deb)
  • C++中数组的概念
  • 【Linux高级IO(三)】Reactor
  • fastGPT—前端开发获取api密钥调用机器人对话接口(HTML实现)
  • java线程安全-单例模式-线程通信