当前位置: 首页 > news >正文

欺诈文本分类微调(七)—— lora单卡二次调优

1. 前言

模型训练是一个不断调优的过程,这注定了我们的需要多次跑同一个训练过程。在前文欺诈文本分类微调(六):Lora单卡跑的整个训练过程中,基本可以分为几步:

  1. 数据加载
  2. 数据预处理
  3. 模型加载
  4. 定义lora参数
  5. 插入微调矩阵
  6. 定义训练参数
  7. 构建训练器开始训练

这个流程基本是固定的,而训练调优过程中需要调整的主要是以下这些项:

  1. 输入和输出:数据路径,模型路径,输出路径
  2. 参数:lora参数,训练参数

因此,我们将整个训练过程中基本不变的部分提取到trainer.py中。内容如下所示:

def load_jsonl(path):with open(path, 'r') as file:data = [json.loads(line) for line in file]return pd.DataFrame(data)def preprocess(item, tokenizer, max_length=2048):input_ids, attention_mask, labels = [], [], []system_message = "You are a helpful assistant."user_message = item['instruction'] + item['input']assistant_message = json.dumps({"is_fraud":item["label"]}, ensure_ascii=False)instruction = tokenizer(f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n", add_special_tokens=False)  response = tokenizer(assistant_message, add_special_tokens=False)input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # -100是一个特殊的标记,用于指示指令部分的token不应参与损失计算labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  # 对输入长度做一个限制保护,超出截断return {"input_ids": input_ids[:max_length],"attention_mask": attention_mask[:max_length],"labels": labels[:max_length]}def load_dataset(train_path, eval_path, tokenizer):train_df = load_jsonl(train_path)train_ds = Dataset.from_pandas(train_df)train_dataset = train_ds.map(lambda x: preprocess(x, tokenizer), remove_columns=train_ds.column_names)eval_df = load_jsonl(eval_path)eval_ds = Dataset.from_pandas(eval_df)eval_dataset = eval_ds.map(

http://www.mrgr.cn/news/8077.html

相关文章:

  • 机器学习解决方案(Datawhale X 李宏毅苹果书 AI夏令营)
  • Xilinx官方XDMA驱动解析
  • Kubernetes 中如何对 etcd 进行备份和还原
  • java swagger解析解决[malformed or unreadable swagger supplied]
  • fl studio mobile2024最新官方版V4.6.8安卓版+iOS苹果版
  • “2025深圳电子信息展”带您感受科技创新如何改变世界
  • MySQL最左匹配原则
  • 景联文科技:专业人像采集服务,助力人像采集在多领域应用
  • HarmonyOS NEXT 地图服务中‘我的位置’功能全解析
  • 前端面试题整理-webpack
  • 每日OJ_牛客_淘宝网店(日期模拟)
  • 计算机毕业设计Flink+Hadoop广告推荐系统 广告预测 广告数据分析可视化 广告爬虫 大数据毕业设计 Spark Hive 深度学习 机器学
  • Day98:云上攻防-云原生篇K8s安全Config泄漏Etcd存储Dashboard鉴权Proxy暴露
  • 游戏引擎phaser.js3的使用之玩家和静态物理组碰撞
  • JavaScript初级——文档的加载
  • 深入理解 Go 语言并发编程底层原理
  • SSRF漏洞学习
  • MySQL主从复制之GTID模式
  • 未来已来:探索机器学习如何重塑人工智能的未来方向
  • 如何从电脑/外部硬盘驱动器/USB 驱动器/内存卡恢复数据?