当前位置: 首页 > news >正文

gpt-2语言模型训练

一、通过下载对应的语言模型数据集 

1.1 根据你想让回答的内容,针对性下载对应的数据集,我下载的是个医疗问答数据集

1.2 针对你要用到的字段信息进行处理,然后把需要处理的数据丢给模型去训练,这个模型我是直接从GPT2的网站下载下来的依赖的必要文件截图如下:

二、具体代码样例实现:

import os
import pandas as pd
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, TextDataset, \DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelForCausalLM# 读取CSV文件
data_path = '内科500.csv'  # 替换为你的CSV文件路径
df = pd.read_csv(data_path, encoding='ISO-8859-1')# 将数据集转换为适合训练的格式
def preprocess_dialogues(df):conversations = []for index, row in df.iterrows():department = row['department']title = row['title']ask = row['ask']answer = row['answer']# 将每条问答对转换为连续的对话context = f"科室: {department}\n问题: {title}\n提问: {ask}\n回答: {answer}\n"conversations.append(context)return conversationsconversations = preprocess_dialogues(df)# 保存对话数据到文本文件
train_file_path = 'train_data.txt'
with open(train_file_path, 'w', encoding='utf-8') as file:for conversation in conversations:file.write(conversation + '\n')# 加载预训练模型和tokenizer
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-model')
model = GPT2LMHeadModel.from_pretrained('./gpt2-model')# 准备数据集
def load_dataset(file_path, tokenizer, block_size=128):return TextDataset(tokenizer=tokenizer,file_path=file_path,block_size=block_size)train_dataset = load_dataset(train_file_path, tokenizer)# 数据整理器
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,mlm=False
)# 训练参数
training_args = TrainingArguments(output_dir='./results',overwrite_output_dir=True,num_train_epochs=3,per_device_train_batch_size=4,save_steps=10_000,save_total_limit=2,resume_from_checkpoint=True  # 从检查点恢复训练
)# 创建Trainer
trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=train_dataset
)last_checkpoint = None
if os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir):last_checkpoint = training_args.output_dir
# 开始训练
trainer.train(resume_from_checkpoint=last_checkpoint)

http://www.mrgr.cn/news/5277.html

相关文章:

  • 物联网设备心跳源码-SAAS本地化及未来之窗行业应用跨平台架构
  • 标准库标头 <string_view> (C++17)学习
  • 5步掌握Python Django结合K-means算法进行豆瓣书籍可视化分析
  • LabVIEW深度监测系统
  • 数据结构--单链表
  • 多功能秒达工具箱全开源源码,可自部署且完全开源的中文工具箱
  • 投资伦敦银一般看什么点位做单?
  • sqlite3基本操作/数据库编程
  • uniapp中 使用 VUE3 组合式API 怎么接收上一个页面传递的参数
  • XSS-games
  • Java TCP练习2
  • 【系统架构设计】软件架构设计(1)
  • LeeCode Practice Journal | Day50_Graph01
  • 【STM32】C语言基础补充
  • [mongodb][查询]MongoDb 模糊查询
  • 开闭原则(Open-Closed Principle, OCP)详解
  • RabbitMQ的基础概念介绍
  • dp题目集合
  • Windows Microsoft Edge 浏览器 配置【密码】
  • Python实战:如何使用K-means算法进行餐馆满意度NLP情感分析