当前位置: 首页 > news >正文

深度学习·wandb

wandb

一个好用的可视化训练过程和调参工具,建议在深度学习中使用,语法来说更加方便

前置工作

这里是一些简单的网络结构,用于测试

数据集:

  • Kaggle上HeartDisease的0-1分类问题
    df=pd.read_csv('../data/heart_attack/heart.csv')

数据集的迭代:

  • X=torch.tensor(X.values,device=config.device,dtype=torch.float32) y=torch.tensor(y.values,device=config.device,dtype=torch.float32).reshape(-1,1) dataset=TensorDataset(X,y) dataloader=DataLoader(dataset,batch_size=config.batch_size,shuffle=True)

简单的DNN

class DNN(nn.Module):def __init__(self,input_size,hidden_size,dropout:float):super().__init__()self.input_size=input_sizeself.hidden_size=hidden_sizeself.fc1=nn.Linear(self.input_size,self.hidden_size)self.fc2=nn.Linear(self.hidden_size,self.hidden_size)self.fc3=nn.Linear(self.hidden_size,1)self.dropout=nn.Dropout(dropout)def forward(self,x):x=F.leaky_relu(self.fc1(x))x=self.dropout(x)x=F.leaky_relu(self.fc2(x))x=self.dropout(x)x=self.fc3(x)return x

wandb监视训练过程

使用login()登陆

import os
os.environ["WANDB_API_KEY"] = "xxxx"
wandb.login(key=os.environ['WANDB_API_KEY'])

初始化wandb

  • 建议使用系统时间:
    current_time = datetime.now()standard_time = current_time.strftime("%Y-%m-%d %H:%M:%S")name=standard_time
  • 初始化:
    注意保存wand.run.id方便继续监视该模型
    wandb.init(project=config.project_name,name=name,config=config.__dict__)# 转换为dict    model_run_id=wandb.run.id

训练流程中记录参数

    for epoch in tqdm(range(config.epochs)):for X,y in dataloader:# 反向传播# 评估指标wandb.log({'epoch':epoch+1,'val_acc':val_acc,'best_acc':best_metric})wandb.finish()

wandb.log从接口收到对应参数,wandb.finish()完成记录,主要不要漏掉finish

继续训练

  • 提供run.id并将resume设置为must
    wandb.init(project=config.project_name,id=model_run_id,resume='must')

Artifact工件

工件可以是代码也可以是数据集
第一个参数是名称,第二个是类型

wandb.init(project=config.project_name,id=model_run_id,resume='must')
arti_dataset=wandb.Artifact('HeartDisease',type='dataset')
arti_dataset.add_dir('../data/heart_attack/')
wandb.log_artifact(arti_dataset)
```python
arti_code=wandb.Artifact('ipynb',type='code')
arti_code.add_file('./wand_test.ipynb')
wandb.log_artifact(arti_code)
wandb.finish()

Table

可视化分析

wandb.init(project=config.project_name,id=model_run_id,resume='must')
good_cases=wandb.Table(columns=['id','GroundTrue','Prediction'])
bad_cases=wandb.Table(columns=['id','GroundTrue','Prediction'])

在代码中加入如下:

good_cases.add_data(i,y,prediction)

一般是用于比对feature、label和prediction


http://www.mrgr.cn/news/40375.html

相关文章:

  • LeetCode题练习与总结:二叉树的所有路径--257
  • GEE教程:下载大气环流模型 (GCM)全球数据动态项目 NASA NEX GDDP 产品数据库
  • 2024年使用宝塔面板轻松部署Java Web
  • 关于4G-Cat.1模组Air780E选型,注意事项盘点
  • 【题解】【递推】—— [NOIP2003 普及组] 栈
  • ElasticSearch分词器、相关性详解与聚合查询实战
  • Win10之Ubuntu22.04(主机)与Virtual-BOX(宿主win10)网络互通调试(七十九)
  • Linux: network: /proc/net/sockstat 解读
  • Elasticsearch学习笔记(3)
  • C++14:通过make_index_sequence实现将tuple转换为array
  • Python 并发新境界:探索 `multiprocessing` 模块的无限可能
  • 通信工程学习:什么是DQDB分布式队列双总线
  • 禁止某驱动软件自动联网检测更新
  • 企望制造ERP系统存在RCE漏洞
  • 架构师知识梳理(12):知识产权
  • [贪心 + dp] 疯狂的火神
  • 模型推理实践与工具详解
  • 数据库 - Mongo数据库
  • 付费计量系统通用功能(5)
  • 【拥抱AIGC】通义灵码扩展管理