
Real-Time Gesture Recognition (2): Zero-Shot Recognition of Arbitrary Gestures via Keypoint Classification

Contents

Preface

1. Results

2. Keypoint Classification Network

3. Preparing KPNet Training and Test Data

4. Training Results

4.1 Training Process Visualization

4.2 Confusion Matrix on the Validation Set

5. Test Results

5.1 Comparing Models of Different Sizes

5.2 Classification Features Projected onto the First Quadrant

5.3 Confusion Matrix on the Test Set

5.4 Ambiguous Gestures

5.5 Video Test

6. Zero-Shot Recognition of Arbitrary Gestures

6.1 Obtaining Keypoints for Arbitrary Gestures

6.2 Feature Encoding for Arbitrary Gestures

7. Key Training and Testing Code

7.1 dataset.py

7.2 dataloader.py

7.3 engine.py

7.4 train.py

7.5 test.py


Preface

First use YOLOv8 to detect the hand region, then run YOLOv8-pose on the enlarged hand region to detect keypoints, and finally classify the keypoints with PointNet. This yields high-accuracy, real-time recognition of arbitrary gestures.

For unoccluded gestures, about 10,000 parameters are enough to reach 98% accuracy; in the extreme case, as few as roughly 400 parameters still achieve 80% accuracy.

Previous articles in this series:

Hand keypoint dataset preparation: Hand Keypoint Detection with YOLOv8-pose (1): Obtaining a Hand Keypoint Dataset (download, cleaning, processing, and augmentation)

Hand keypoint model training: Hand Keypoint Detection with YOLOv8-pose (2): Model Training, Result Analysis, and Hyperparameter Optimization

Real-time hand keypoint detection: Hand Keypoint Detection with YOLOv8-pose (3): Real-Time Hand Keypoint Detection


1. Results

In the overlays, "hand" is the YOLOv8-m detection, "resnt" the ResNet18 classification result, "shfnt" the ShuffleNet V2 result, "kpnet" the keypoint-network result, and "conf" the confidence.

The class set is shown below; the original 18 classes are mapped to the following 14:

mapping_dict = {'call': 0, 'dislike': 1, 'fist': 2, 'four': 3, 'like': 4, 'mute': 5, 'ok': 6, 'one': 5, 'palm': 7, 'peace': 8, 'peace_inverted': 8, 'rock': 9, 'stop': 10, 'stop_inverted': 10, 'three': 11, 'three2': 12, 'two_up': 13, 'two_up_inverted': 13}
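To see how the 18 labels collapse into 14 classes, the mapping can simply be inverted (plain Python over the dict above); note that mute and one share index 5, and each inverted variant merges with its upright counterpart:

```python
mapping_dict = {'call': 0, 'dislike': 1, 'fist': 2, 'four': 3, 'like': 4, 'mute': 5, 'ok': 6,
                'one': 5, 'palm': 7, 'peace': 8, 'peace_inverted': 8, 'rock': 9, 'stop': 10,
                'stop_inverted': 10, 'three': 11, 'three2': 12, 'two_up': 13, 'two_up_inverted': 13}

# Invert the mapping: class index -> all original labels merged into it
merged = {}
for name, idx in mapping_dict.items():
    merged.setdefault(idx, []).append(name)

print(len(mapping_dict), 'labels ->', len(merged), 'classes')  # 18 labels -> 14 classes
for idx in sorted(merged):
    if len(merged[idx]) > 1:  # the label pairs that were collapsed
        print(idx, merged[idx])
```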


2. Keypoint Classification Network

Paper: PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation

Code: (TensorFlow) https://github.com/charlesq34/pointnet

          (PyTorch) https://github.com/yanx27/Pointnet_Pointnet2_pytorch

PointNet was designed for 3D point-cloud classification; here the hand keypoints are treated as projections of 3D points onto a 2D plane. With a depth estimate (e.g., from MediaPipe), results would be even better, distinguishing the front/back of the hand and left/right hands more reliably.

mlp1_layers and mlp2_layers list the fully connected widths of the encoder and decoder; adjusting them yields models of different sizes. The smallest model, nn, needs only [8, 8, 8] and [8, 8, 8] (48 neurons in total) to reach 80% classification accuracy. PointNet is thus simplified into a keypoint (KeyPoint) classification network, KPNet:

import torch
import torch.nn as nn


class KPNet(nn.Module):
    def __init__(self, num_classes, dropout_rate=0.3):
        super(KPNet, self).__init__()
        # shared-MLP1 in encode layers
        # mlp1_layers = [2, 64, 128, 1024]   # X
        # mlp2_layers = [1024, 512, 256, 128, num_classes]  # X
        # mlp1_layers = [2, 64, 128, 512]    # L
        # mlp2_layers = [512, 256, 128, num_classes]  # L
        mlp1_layers = [2, 64, 128, 256]    # M
        mlp2_layers = [256, 128, 64, num_classes]  # M
        # mlp1_layers = [2, 32, 64, 128]    # S
        # mlp2_layers = [128, 64, 32, num_classes]  # S
        # mlp1_layers = [2, 32, 32, 64]    # n
        # mlp2_layers = [64, 32, 32, num_classes]  # n
        # mlp1_layers = [2, 8, 8, 8]  # nn
        # mlp2_layers = [8, 8, 8, num_classes]  # nn
        # mlp1_layers = [2, 64, 128, 512]    # visual
        # mlp2_layers = [512, 256, 128, 2, num_classes]  # visual
        self.mlp1 = nn.ModuleList()
        self.mlp2 = nn.ModuleList()
        # MLP1 layers (Conv1d + BatchNorm1d + ReLU)
        for i in range(len(mlp1_layers) - 1):
            self.mlp1.append(nn.Conv1d(mlp1_layers[i], mlp1_layers[i + 1], 1))
            self.mlp1.append(nn.BatchNorm1d(mlp1_layers[i + 1]))
            self.mlp1.append(nn.ReLU())
        # MLP2 layers (Linear + BatchNorm1d + ReLU)
        for i in range(len(mlp2_layers) - 2):  # exclude the final linear layer
            self.mlp2.append(nn.Linear(mlp2_layers[i], mlp2_layers[i + 1]))
            self.mlp2.append(nn.BatchNorm1d(mlp2_layers[i + 1]))
            self.mlp2.append(nn.ReLU())
            if i >= 1:  # apply dropout from the second hidden layer onward
                self.mlp2.append(nn.Dropout(p=dropout_rate))
        # Final layer without ReLU, dropout, or batch normalization
        self.mlp2.append(nn.Linear(mlp2_layers[-2], mlp2_layers[-1]))

    def forward(self, x):
        # MLP1
        x = x.transpose(2, 1)  # (B, 2, N)
        for layer in self.mlp1:
            x = layer(x)
        x = torch.max(x, 2)[0]  # (B, C) global feature via max pooling
        # MLP2
        # feat = None
        for i, layer in enumerate(self.mlp2):
            x = layer(x)
            # if x.shape[1] == 2:
            #     feat = x
        return x  # , feat


# Quick test of KPNet
if __name__ == "__main__":
    B, N, C = 32, 100, 2  # batch of 32, 100 points, each with 2 dimensions
    num_classes = 10
    model = KPNet(num_classes)
    x = torch.randn(B, N, C)  # random input
    output = model(x)
    print("Output shape:", output.shape)  # expected: torch.Size([32, 10])

3. Preparing KPNet Training and Test Data

The data are obtained by normalizing the keypoint coordinates of each hand patch. A patch looks like this:

All coordinates of one class are saved together into a single txt file:

Each line stores the keypoint information of one patch:

Visualizing one line of keypoints, the shape features are quite distinctive:
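The preparation step can be sketched as follows (a minimal illustration, not the author's exact script: `normalize_keypoints` and `to_txt_line` are hypothetical helpers, and a real patch would carry all 21 hand keypoints per line):

```python
import numpy as np

def normalize_keypoints(kpts_px, patch_xyxy):
    """Map pixel keypoints of shape (n, 2) into [0, 1] relative to the hand patch box."""
    x1, y1, x2, y2 = patch_xyxy
    kpts = np.asarray(kpts_px, dtype=np.float64).copy()
    kpts[:, 0] = (kpts[:, 0] - x1) / max(x2 - x1, 1e-6)
    kpts[:, 1] = (kpts[:, 1] - y1) / max(y2 - y1, 1e-6)
    return np.clip(kpts, 0.0, 1.0)

def to_txt_line(kpts_norm):
    """Flatten one patch's keypoints into a single whitespace-separated text line."""
    return ' '.join(f'{v:.6f}' for v in kpts_norm.reshape(-1))

# Two toy pixel keypoints inside a patch with corners (100, 60) and (200, 160)
line = to_txt_line(normalize_keypoints([[120.0, 80.0], [160.0, 140.0]], (100, 60, 200, 160)))
print(line)  # 0.200000 0.200000 0.600000 0.800000
# Reading a line back mirrors dataset.py: reshape(-1, 2)
restored = np.array(list(map(float, line.split()))).reshape(-1, 2)
```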

Keypoints are easier to train on than point clouds: point clouds require random sampling (to satisfy translation, rotation, and permutation invariance), whereas keypoints arrive in a fixed input order and can keep a fixed orientation. (Rotation can still be added during training if needed.)
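If rotation augmentation is desired, it can be added alongside the jitter transform; a minimal sketch, rotating normalized keypoints about the patch center (0.5, 0.5) (the helper name is illustrative, not from the original code):

```python
import numpy as np

def rotate_keypoints(points, angle_deg, center=(0.5, 0.5)):
    """Rotate normalized keypoints of shape (n, 2) around `center` by `angle_deg` degrees."""
    theta = np.deg2rad(angle_deg)
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    # Shift to the center, rotate, shift back
    return (np.asarray(points) - center) @ rot.T + center

# A 90-degree rotation carries (0.5, 0.0) to approximately (1.0, 0.5)
print(rotate_keypoints([[0.5, 0.0]], 90))
```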


4. Training Results

4.1 Training Process Visualization

Training converges after about 40 epochs. Each epoch takes about 13 seconds (20,000 keypoint samples), so a full run finishes in roughly 10 minutes.

4.2 Confusion Matrix on the Validation Set

The main errors: three2 predicted as two (16 cases), palm predicted as stop (8 cases), and two predicted as two_up. This is expected, since these gestures are intrinsically similar and easily confused by viewpoint. Every other class reaches over 99% accuracy.


5. Test Results

5.1 Comparing Models of Different Sizes

Model X corresponds to the original PointNet design; nn is the attempt with the fewest neurons per layer, and with only just over 400 parameters in total it still reaches 80+% classification accuracy.

Model | Size (KB) | Params    | Instances (test) | P (test) | R (test) | mAP (50:95) | Loss (test) | Dropout
------|-----------|-----------|------------------|----------|----------|-------------|-------------|--------
nn    | 14        | 438       | 122,720          | 0.8614   | 0.8258   | 0.8474      | 0.4693      | 0
n     | 43        | 7,862     | 122,720          | 0.9804   | 0.9761   | 0.9784      | 0.0997      | 0.3
S     | 103       | 18,722    | 122,720          | 0.9842   | 0.9803   | 0.9826      | 0.0912      | 0.3
M     | 357       | 40,834    | 122,720          | 0.9865   | 0.9840   | 0.9860      | 0.0762      | 0.3
L     | 992       | 224,218   | 122,720          | 0.9868   | 0.9836   | 0.9863      | 0.0731      | 0.3
X     | 3,400     | 1,701,514 | 122,720          | 0.9848   | 0.9848   | 0.9863      | 0.0737      | 0.3

5.2 Classification Features Projected onto the First Quadrant

Mapping the features down to two dimensions and projecting them into the first quadrant, the 14 classes separate clearly. However, because negative values are forced into the first quadrant, points pile up at the origin (which is also why the last fully connected layer of a classification network should not be followed by a ReLU):

5.3 Confusion Matrix on the Test Set

Results on the test set are similar to those on the validation set:

In the normalized confusion matrix below, nearly all gestures exceed 99% accuracy; fist, with only about 200 training samples, has the lowest accuracy:

5.4 Ambiguous Gestures

As shown below, for an ambiguous gesture the image classifiers, thrown off by lighting and similar factors, predict stop and three, whereas the keypoint classifier predicts four:

5.5 Video Test

The image classification networks can distinguish the front and back of the hand and learn rotation-related cues: call samples are all upright during training, yet ResNet still recognizes the gesture when the hand is held horizontally, while the keypoint classifier does not. (No rotation augmentation was used for the keypoint classifier, which in turn lets it separate more gesture semantics.)


6. Zero-Shot Recognition of Arbitrary Gestures

6.1 Obtaining Keypoints for Arbitrary Gestures

Approach 1:

Since the standard gesture is known in advance, there is no need to capture images first and then extract keypoints. We can simply draw a few points on a whiteboard to represent the keypoints, then add bounded random jitter to generate a large number of gesture keypoint samples.
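Approach 1 can be sketched as follows (the 5-point template is a hypothetical "whiteboard" drawing, and the 0.02 jitter bound mirrors the transform argument used later in dataset.py):

```python
import numpy as np

def synthesize_samples(template, n_samples=1000, jitter=0.02, seed=0):
    """Make jittered copies of a keypoint template of shape (n, 2), clipped to [0, 1]."""
    rng = np.random.default_rng(seed)
    noise = rng.uniform(-jitter, jitter, size=(n_samples,) + template.shape)
    return np.clip(template[None, :, :] + noise, 0.0, 1.0)

# A hypothetical 5-point template drawn on a whiteboard (normalized coordinates)
template = np.array([[0.5, 0.9], [0.5, 0.6], [0.3, 0.4], [0.5, 0.3], [0.7, 0.4]])
samples = synthesize_samples(template)
print(samples.shape)  # (1000, 5, 2)
```

Each row of `samples` can then be flattened and appended to the class txt file exactly like a detected patch.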

Approach 2:

With the YOLOv8-pose hand keypoint network already trained, a webcam is all that is needed: by varying distance, angle, and viewpoint, large numbers of standard-gesture samples can be auto-labeled. (An image classifier would require varied backgrounds and hand appearances; keypoints need neither.)

6.2 Feature Encoding for Arbitrary Gestures

After training, during forward inference we take the activations of the layer just before the classification output as the feature encoding. Given a standard gesture, compute its encoded feature (collect hundreds or thousands of samples of that gesture, then cluster their features to obtain a more general representation); then, for a gesture to be identified, compute the cosine similarity between its feature vector and the standard gesture's encoded feature vector.

class KPNet(nn.Module):
    def forward(self, x):
        # MLP1
        x = x.transpose(2, 1)  # (B, 2, N)
        for layer in self.mlp1:
            x = layer(x)
        x = torch.max(x, 2)[0]  # (B, C) global feature
        # MLP2: also capture the features entering the final linear layer
        feat = None
        for i, layer in enumerate(self.mlp2):
            x = layer(x)
            if i == len(self.mlp2) - 2:  # output of the layer feeding the final classifier
                feat = x
        return x, feat
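With feat in hand, zero-shot matching reduces to nearest-prototype search under cosine similarity. A minimal sketch (the prototype vectors below are toy values; in practice each prototype would come from clustering the encoded features of one standard gesture):

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two 1-D feature vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

def match_gesture(feat, prototypes, threshold=0.9):
    """Return (name, score) of the closest prototype; name is None below the threshold."""
    best_name, best_score = None, -1.0
    for name, proto in prototypes.items():
        score = cosine_similarity(feat, proto)
        if score > best_score:
            best_name, best_score = name, score
    return (best_name if best_score >= threshold else None), best_score

prototypes = {'ok': np.array([1.0, 0.0, 1.0]), 'fist': np.array([0.0, 1.0, 0.0])}  # toy features
name, score = match_gesture(np.array([0.9, 0.1, 1.1]), prototypes)
print(name, round(score, 3))  # ok 0.993
```

The threshold decides when a query is declared "none of the registered gestures", which is what makes newly enrolled gestures work without retraining.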

7. Key Training and Testing Code

7.1 dataset.py

import os
import numpy as np
from torch.utils.data import Dataset


class KPNetDataset(Dataset):
    def __init__(self, data_dir, mapping_dict, reshape_dim=2, transform=None):
        """
        :param data_dir: root directory of the data files
        :param mapping_dict: class mapping dict; key is the file name, value is the class index
        :param reshape_dim: dimension used to reshape each line, default 2, i.e. reshape to [-1, 2]
        :param transform: jitter magnitude for augmentation (0-10%); None disables augmentation
        """
        self.data_dir = data_dir
        self.mapping_dict = mapping_dict
        self.reshape_dim = reshape_dim
        self.transform = transform
        self.file_list = []
        self.labels = []
        # Build the file_list and labels indices
        self._prepare_file_index()

    def _prepare_file_index(self):
        """Build file_list and labels, storing the file path and label of every data line."""
        for file_name, label in self.mapping_dict.items():
            file_path = os.path.join(self.data_dir, file_name + '.txt')
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"File {file_path} does not exist.")
            with open(file_path, 'r') as file:
                lines = file.readlines()
            for _ in lines:
                self.file_list.append(file_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        # Read the corresponding line of the file
        with open(file_path, 'r') as file:
            lines = file.readlines()
        line = lines[idx - self.file_list.index(file_path)].strip()
        points = np.array(list(map(float, line.split()))).reshape(-1, self.reshape_dim)
        # Augmentation
        if self.transform:
            points = self._apply_transform(points)
        return points, label

    def _apply_transform(self, points):
        """
        Apply random jitter augmentation; points pushed out of range are reset to (0, 0).
        :param points: keypoint array of shape [n, 2]
        :return: augmented keypoint array
        """
        jitter = np.random.uniform(-self.transform, self.transform, points.shape)
        points += jitter
        # Reset points outside [0, 1] to (0, 0)
        mask = (points < 0) | (points > 1)
        points[np.any(mask, axis=1)] = [0, 0]
        return points


# Main program
if __name__ == "__main__":
    import cv2

    data_dir = r'./datasets/hagrid/yolo_pose_point/val'
    mapping_dict = {'call': 0, }
    # Initialize the dataset
    dataset = KPNetDataset(data_dir, mapping_dict, reshape_dim=2, transform=0.02)
    # Print the dataset size
    print(f'Total items in dataset: {len(dataset)}')
    # Read and visualize the first 5 items
    for i in range(min(5, len(dataset))):
        data, label = dataset[i]
        print(f'Item {i}: Data shape: {data.shape}, Label: {label}')
        # print(data.tolist())
        # Create a white canvas
        canvas_size = 224
        img = np.ones((canvas_size, canvas_size, 3), dtype=np.uint8) * 255
        # Convert normalized coordinates to canvas coordinates and draw blue dots
        for point in data:
            x, y = point
            x = int(x * canvas_size)
            y = int(y * canvas_size)
            cv2.circle(img, (x, y), radius=3, color=(255, 0, 0), thickness=-1)  # blue dot
        # Show the image
        cv2.imshow(f'Item {i}', img)
        cv2.waitKey(0)  # press any key to continue
    cv2.destroyAllWindows()

7.2 dataloader.py

import torch
from torch.utils.data import DataLoader
from dataset import KPNetDataset


def create_dataloader(data_dir, mapping_dict, phase='train', reshape_dim=2, batch_size=32, num_workers=4,
                      transform=None):
    """
    Create and return a DataLoader.
    :param data_dir: data root directory.
    :param mapping_dict: class mapping dict; key is the file name, value is the class index.
    :param phase: dataset phase (train, val, test).
    :param reshape_dim: reshape each sample to [n, reshape_dim].
    :param batch_size: batch size.
    :param num_workers: number of worker processes for loading.
    :param transform: jitter magnitude for augmentation (0-10%); None disables augmentation.
    :return: the DataLoader.
    """
    shuffle = True if phase == 'train' else False
    transform = transform if phase == 'train' else None
    dataset = KPNetDataset(data_dir, mapping_dict, reshape_dim, transform)
    # Use the default collate_fn
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return dataloader


# Main program (for testing)
if __name__ == "__main__":
    data_dir = r'./datasets/hagrid/yolo_pose_point/val'
    mapping_dict = {
        'call': 0,
        # add more mappings as needed
    }
    # Training DataLoader (with 2% jitter augmentation)
    train_dataloader = create_dataloader(data_dir, mapping_dict, phase='train', transform=0.02, batch_size=4)
    for i, (data, label) in enumerate(train_dataloader):
        print(f'Train Batch {i}:')
        print(f'Data shape: {data.shape}, Labels: {label.shape}')
        if i == 2:  # only check the first three batches
            break
    # Validation DataLoader (no augmentation)
    val_dataloader = create_dataloader(data_dir, mapping_dict, phase='val', batch_size=4)
    for i, (data, label) in enumerate(val_dataloader):
        print(f'Validation Batch {i}:')
        print(f'Data shape: {data.shape}, Labels: {label.shape}')
        if i == 2:  # only check the first three batches
            break

7.3 engine.py

import torch
import torch.nn as nn
from sklearn.metrics import classification_report
from tqdm import tqdm


def train_one_epoch(model, dataloader, optimizer, criterion, device, label_dict):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    all_labels = []
    all_predictions = []
    for data_batches, label_batches in tqdm(dataloader, desc="Training", unit="batch"):
        data_batches = data_batches.to(device).float()
        label_batches = label_batches.to(device)
        optimizer.zero_grad()
        outputs = model(data_batches)
        loss = criterion(outputs, label_batches)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * data_batches.size(0)
        _, predicted = torch.max(outputs, 1)
        correct_predictions += (predicted == label_batches).sum().item()
        total_samples += label_batches.size(0)
        all_labels.extend(label_batches.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())
    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples
    # Map label indices back to the names in label_dict
    target_names = [label_dict[i] for i in sorted(label_dict.keys())]
    classification_metrics = classification_report(all_labels, all_predictions, target_names=target_names,
                                                   output_dict=True, zero_division=0, digits=3)
    overall_recall = classification_metrics["macro avg"]["recall"]
    return epoch_loss, epoch_accuracy, overall_recall, classification_metrics


def test_one_epoch(model, dataloader, criterion, device, label_dict):
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for data_batches, label_batches in tqdm(dataloader, desc="Validation", unit="batch"):
            data_batches = data_batches.to(device).float()
            label_batches = label_batches.to(device)
            outputs = model(data_batches)
            loss = criterion(outputs, label_batches)
            running_loss += loss.item() * data_batches.size(0)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == label_batches).sum().item()
            total_samples += label_batches.size(0)
            all_labels.extend(label_batches.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples
    # Map label indices back to the names in label_dict
    target_names = [label_dict[i] for i in sorted(label_dict.keys())]
    classification_metrics = classification_report(all_labels, all_predictions, target_names=target_names,
                                                   output_dict=True, zero_division=0, digits=3)
    overall_recall = classification_metrics["macro avg"]["recall"]
    return epoch_loss, epoch_accuracy, overall_recall, classification_metrics


if __name__ == "__main__":
    from KPNet import KPNet
    from dataloader import create_dataloader

    train_data_dir = r'./datasets/hagrid/yolo_pose_point/test'
    val_data_dir = r'./datasets/hagrid/yolo_pose_point/val'
    mapping_dict = {'call': 0, 'three': 1, 'palm': 2}
    label_dict = {0: 'call', 1: 'three', 2: 'palm'}  # index -> class name for the report
    num_classes = 3
    batch_size = 1024
    num_workers = 1
    learning_rate = 0.001
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Model, loss function, optimizer
    model = KPNet(num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Data loaders
    train_dataloader = create_dataloader(train_data_dir, mapping_dict, phase='train', batch_size=batch_size,
                                         num_workers=num_workers, transform=0.02)
    val_dataloader = create_dataloader(val_data_dir, mapping_dict, phase='val', batch_size=batch_size,
                                       num_workers=num_workers)
    # One training epoch
    train_loss, train_accuracy, train_recall, train_metrics = train_one_epoch(model, train_dataloader, optimizer,
                                                                              criterion, device, label_dict)
    # One validation epoch
    val_loss, val_accuracy, val_recall, val_metrics = test_one_epoch(model, val_dataloader, criterion, device,
                                                                     label_dict)

7.4 train.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
import os
import logging
import time
from KPNet import KPNet
from dataloader import create_dataloader
from engine import train_one_epoch, test_one_epoch
from torch.utils.tensorboard import SummaryWriter
from tabulate import tabulate
import matplotlib.pyplot as plt


def train_pipeline(train_data_dir, val_data_dir, mapping_dict, label_dict, num_classes, batch_size=256, num_workers=4,
                   initial_lr=0.01, num_epochs=100, min_lr=0.00005, optimizer_choice='adam'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Model, loss function, optimizer
    model = KPNet(num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    if optimizer_choice.lower() == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    else:
        optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
    # Exponentially decaying learning-rate scheduler
    scheduler = ExponentialLR(optimizer, gamma=0.95)
    # Data loaders
    train_dataloader = create_dataloader(train_data_dir, mapping_dict, phase='train', batch_size=batch_size,
                                         num_workers=num_workers, transform=0.01)
    val_dataloader = create_dataloader(val_data_dir, mapping_dict, phase='val', batch_size=batch_size // 4,
                                       num_workers=num_workers)
    # Logging and model-saving setup
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    model_save_dir = os.path.join('model_save', timestamp)
    os.makedirs(model_save_dir, exist_ok=True)
    log_dir = os.path.join('run_log', timestamp)
    os.makedirs(log_dir, exist_ok=True)
    log_filename = os.path.join(log_dir, 'training.log')
    logging.basicConfig(filename=log_filename, level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')
    # Record the run arguments
    args_log_filename = os.path.join(log_dir, 'args.log')
    with open(args_log_filename, 'w') as f:
        f.write(f"train_data_dir: {train_data_dir}\n")
        f.write(f"val_data_dir: {val_data_dir}\n")
        f.write(f"mapping_dict: {mapping_dict}\n")
        f.write(f"label_dict: {label_dict}\n")
        f.write(f"num_classes: {num_classes}\n")
        f.write(f"batch_size: {batch_size}\n")
        f.write(f"num_workers: {num_workers}\n")
        f.write(f"initial_lr: {initial_lr}\n")
        f.write(f"num_epochs: {num_epochs}\n")
        f.write(f"min_lr: {min_lr}\n")
        f.write(f"optimizer_choice: {optimizer_choice}\n")
    # TensorBoard writer
    writer = SummaryWriter(log_dir=log_dir)
    best_accuracy = 0.0
    best_recall = 0.0
    for epoch in range(num_epochs):
        train_loss, train_accuracy, train_recall, train_metrics = train_one_epoch(model, train_dataloader, optimizer,
                                                                                  criterion, device, label_dict)
        val_loss, val_accuracy, val_recall, val_metrics = test_one_epoch(model, val_dataloader, criterion, device,
                                                                         label_dict)
        # Learning-rate schedule with a lower bound
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        if current_lr < min_lr:
            for param_group in optimizer.param_groups:
                param_group['lr'] = min_lr
        # TensorBoard scalars
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', train_accuracy, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)
        writer.add_scalar('Recall/train', train_recall, epoch)
        writer.add_scalar('Recall/val', val_recall, epoch)
        # Per-class precision and recall in TensorBoard
        for category_index, category_name in label_dict.items():
            writer.add_scalar(f'Accuracy/train_{category_name}', train_metrics[category_name]['precision'], epoch)
            writer.add_scalar(f'Recall/train_{category_name}', train_metrics[category_name]['recall'], epoch)
            writer.add_scalar(f'Accuracy/val_{category_name}', val_metrics[category_name]['precision'], epoch)
            writer.add_scalar(f'Recall/val_{category_name}', val_metrics[category_name]['recall'], epoch)
        # Logging
        message = (f'Epoch {epoch + 1}/{num_epochs}, '
                   f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, '
                   f'Train Recall: {train_recall:.4f}, '
                   f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, '
                   f'Validation Recall: {val_recall:.4f}, '
                   f'Learning Rate: {current_lr:.6f}')
        logging.info(message)
        print(message)
        # Save the best model: compare accuracy first, then recall
        if val_accuracy > best_accuracy or (val_accuracy == best_accuracy and val_recall > best_recall):
            best_accuracy = val_accuracy
            best_recall = val_recall
            best_model_wts = model.state_dict()
            best_model_path = os.path.join(model_save_dir, 'best_model.pth')
            torch.save(best_model_wts, best_model_path)
            logging.info(f'Best model saved with accuracy: {best_accuracy:.4f} and recall: {best_recall:.4f}')
            print(f'Best model saved with accuracy: {best_accuracy:.4f} and recall: {best_recall:.4f}')
        time.sleep(0.3)  # keep tqdm output from interleaving
    # Save the last-epoch model
    last_model_path = os.path.join(model_save_dir, 'last_model.pth')
    torch.save(model.state_dict(), last_model_path)
    logging.info('Last model saved.')
    print('Last model saved.')
    # Close the TensorBoard writer
    writer.close()
    # Load and validate the best model
    model.load_state_dict(torch.load(best_model_path))
    val_loss, val_accuracy, val_recall, val_metrics = test_one_epoch(model, val_dataloader, criterion, device,
                                                                     label_dict)
    print(f'Best Model Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, '
          f'Validation Recall: {val_recall:.4f}')
    print("Best Model Validation Metrics:\n")
    # Show the validation metrics as a table in the console
    print_metrics_table(val_metrics, label_dict.values())
    # Save the validation metrics to a log file
    val_best_model_log_filename = os.path.join(log_dir, 'val_best_model.log')
    with open(val_best_model_log_filename, 'w') as f:
        f.write(f'Best Model Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, '
                f'Validation Recall: {val_recall:.4f}\n')
        f.write("Best Model Validation Metrics:\n")
        log_content = tabulate([[category,
                                 f"{metrics['precision']:.3f}" if isinstance(metrics, dict) else 'N/A',
                                 f"{metrics['recall']:.3f}" if isinstance(metrics, dict) else 'N/A',
                                 f"{metrics['f1-score']:.3f}" if isinstance(metrics, dict) else 'N/A']
                                for category, metrics in val_metrics.items()],
                               headers=['Category', 'Precision', 'Recall', 'F1-Score'], tablefmt='grid')
        f.write(log_content)  # write to the log file
    logging.info("\n" + log_content)  # also record in the training log
    # Plot and save the training curves
    plot_and_save_separate_graphs(log_dir, num_epochs)


def print_metrics_table(metrics, class_names):
    """Print per-class validation metrics as a table in the console."""
    table = []
    for category in class_names:
        precision = metrics[category]['precision']
        recall = metrics[category]['recall']
        f1_score = metrics[category]['f1-score']
        table.append([category, f"{precision:.3f}", f"{recall:.3f}", f"{f1_score:.3f}"])
    # Print the table
    print(tabulate(table, headers=['Category', 'Precision', 'Recall', 'F1-Score'], tablefmt='grid'))


def plot_and_save_separate_graphs(log_dir, num_epochs):
    """Plot and save the loss curve and the accuracy/recall curves separately."""
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    # Load the TensorBoard logs
    event_acc = EventAccumulator(log_dir)
    event_acc.Reload()
    steps = range(num_epochs)
    # Loss curves
    train_loss = [scalar_event.value for scalar_event in event_acc.Scalars('Loss/train')]
    val_loss = [scalar_event.value for scalar_event in event_acc.Scalars('Loss/val')]
    plt.figure(figsize=(10, 6))
    plt.plot(steps, train_loss, label='Train Loss')
    plt.plot(steps, val_loss, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Train and Validation Loss Over Epochs')
    plt.legend()
    plt.savefig(os.path.join(log_dir, 'loss_curve.png'))
    # Accuracy and recall curves
    train_accuracy = [scalar_event.value for scalar_event in event_acc.Scalars('Accuracy/train')]
    val_accuracy = [scalar_event.value for scalar_event in event_acc.Scalars('Accuracy/val')]
    train_recall = [scalar_event.value for scalar_event in event_acc.Scalars('Recall/train')]
    val_recall = [scalar_event.value for scalar_event in event_acc.Scalars('Recall/val')]
    plt.figure(figsize=(10, 6))
    plt.plot(steps, train_accuracy, label='Train Accuracy')
    plt.plot(steps, val_accuracy, label='Validation Accuracy')
    plt.plot(steps, train_recall, label='Train Recall')
    plt.plot(steps, val_recall, label='Validation Recall')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy/Recall')
    plt.title('Accuracy and Recall Over Epochs')
    plt.legend()
    plt.savefig(os.path.join(log_dir, 'accuracy_recall_curve.png'))


if __name__ == "__main__":
    train_data_dir = r'./datasets/hagrid/yolo_pose_point/test'
    val_data_dir = r'./datasets/hagrid/yolo_pose_point/val'
    # mapping_dict = {'call': 0, 'dislike': 1, 'fist': 2, 'four': 3, 'like': 4, 'mute': 5, 'ok': 6, 'one': 5,
    #                 'palm': 7, 'peace': 8, 'peace_inverted': 8, 'rock': 9, 'stop': 10, 'stop_inverted': 10,
    #                 'three': 11, 'three2': 12, 'two_up': 13, 'two_up_inverted': 13, 'no_gesture': 14}
    # label_dict = {0: 'six', 1: 'dislike', 2: 'fist', 3: 'four', 4: 'like', 5: 'one', 6: 'ok', 7: 'palm', 8: 'two',
    #               9: 'rock', 10: 'stop', 11: 'three', 12: 'three2', 13: 'two_up', 14: 'no_gesture'}
    mapping_dict = {'call': 0, 'dislike': 1, 'fist': 2, 'four': 3, 'like': 4, 'mute': 5, 'ok': 6,
                    'one': 5, 'palm': 7, 'peace': 8, 'peace_inverted': 8, 'rock': 9, 'stop': 10,
                    'stop_inverted': 10, 'three': 11, 'three2': 12, 'two_up': 13, 'two_up_inverted': 13}
    label_dict = {0: 'six', 1: 'dislike', 2: 'fist', 3: 'four', 4: 'like', 5: 'one', 6: 'ok', 7: 'palm', 8: 'two',
                  9: 'rock', 10: 'stop', 11: 'three', 12: 'three2', 13: 'two_up'}
    num_classes = len(label_dict)
    batch_size = 512
    num_workers = 4
    initial_lr = 0.01
    num_epochs = 100
    min_lr = 0.0001
    optimizer_choice = 'adam'  # or 'sgd'
    train_pipeline(train_data_dir, val_data_dir, mapping_dict, label_dict, num_classes, batch_size, num_workers,
                   initial_lr, num_epochs, min_lr, optimizer_choice)

7.5 test.py

import torch
import torch.nn as nn
import os
import numpy as np
import time
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from KPNet import KPNet
from dataloader import create_dataloader
from engine import test_one_epoch
from tabulate import tabulate


def test_pipeline(test_data_dir, model_path, mapping_dict, label_dict, num_classes, batch_size=256, num_workers=4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the model
    model = KPNet(num_classes).to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    # Data loader
    test_dataloader = create_dataloader(test_data_dir, mapping_dict, phase='val', batch_size=batch_size // 4,
                                        num_workers=num_workers)
    # Derive the model's timestamp directory
    model_timestamp = os.path.basename(os.path.dirname(model_path))
    output_dir = os.path.join('output', model_timestamp)
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    # Find the next test-folder number
    test_folders = [f for f in os.listdir(output_dir) if f.startswith('test_')]
    test_numbers = [int(f.split('_')[1]) for f in test_folders if f.split('_')[1].isdigit()]
    next_test_number = max(test_numbers) + 1 if test_numbers else 1
    test_output_dir = os.path.join(output_dir, f'test_{next_test_number:02d}')
    os.makedirs(test_output_dir, exist_ok=True)
    # Run the test
    criterion = nn.CrossEntropyLoss()
    test_loss, test_accuracy, test_recall, test_metrics = test_one_epoch(model, test_dataloader, criterion, device,
                                                                         label_dict)
    # Print and save the test results
    test_results_log_filename = os.path.join(test_output_dir, 'test_results.log')
    with open(test_results_log_filename, 'w') as f:
        f.write(f'Test Loss: {test_loss:.4f}\n')
        f.write(f'Test Accuracy: {test_accuracy:.4f}\n')
        f.write(f'Test Recall: {test_recall:.4f}\n')
        f.write("Test Metrics:\n")
        log_content = tabulate([[category,
                                 f"{metrics['precision']:.3f}" if isinstance(metrics, dict) else 'N/A',
                                 f"{metrics['recall']:.3f}" if isinstance(metrics, dict) else 'N/A',
                                 f"{metrics['f1-score']:.3f}" if isinstance(metrics, dict) else 'N/A',
                                 metrics['support']]
                                for category, metrics in test_metrics.items() if category != 'accuracy'],
                               headers=['Category', 'Precision', 'Recall', 'F1-Score', 'Instance'], tablefmt='grid')
        f.write(log_content)
    # Compute and save the confusion matrices
    all_labels = []
    all_predictions = []
    for data_batches, label_batches in test_dataloader:
        data_batches = data_batches.to(device).float()
        label_batches = label_batches.to(device)
        with torch.no_grad():
            outputs = model(data_batches)
            _, predicted = torch.max(outputs, 1)
        all_labels.extend(label_batches.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())
    cm = confusion_matrix(all_labels, all_predictions, labels=list(label_dict.keys()))
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=label_dict.values(),
                yticklabels=label_dict.values())
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.savefig(os.path.join(test_output_dir, 'confusion_matrix.png'))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues', xticklabels=label_dict.values(),
                yticklabels=label_dict.values())
    plt.title('Normalized Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.savefig(os.path.join(test_output_dir, 'normalized_confusion_matrix.png'))
    # Rough mAP-style summaries derived from per-class precision
    precision_values = []
    recall_values = []
    for category in label_dict.values():
        if category in test_metrics:
            precision_values.append(test_metrics[category]['precision'])
            recall_values.append(test_metrics[category]['recall'])
    mAP50 = np.mean([p >= 0.5 for p in precision_values])
    mAP75 = np.mean([p >= 0.75 for p in precision_values])
    mAP50_95 = np.mean([p for p in precision_values])
    with open(test_results_log_filename, 'a') as f:
        f.write(f"\nmAP50: {mAP50:.4f}\n")
        f.write(f"mAP75: {mAP75:.4f}\n")
        f.write(f"mAP50:95: {mAP50_95:.4f}\n")


if __name__ == "__main__":
    test_data_dir = r'./datasets/hagrid/yolo_pose_point/train'
    model_path = r'./KPNet/model_save/20240816-172047/best_model.pth'
    mapping_dict = {'call': 0, 'dislike': 1, 'fist': 2, 'four': 3, 'like': 4, 'mute': 5, 'ok': 6,
                    'one': 5, 'palm': 7, 'peace': 8, 'peace_inverted': 8, 'rock': 9, 'stop': 10,
                    'stop_inverted': 10, 'three': 11, 'three2': 12, 'two_up': 13, 'two_up_inverted': 13}
    label_dict = {0: 'six', 1: 'dislike', 2: 'fist', 3: 'four', 4: 'like', 5: 'one', 6: 'ok', 7: 'palm', 8: 'two',
                  9: 'rock', 10: 'stop', 11: 'three', 12: 'three2', 13: 'two_up'}
    num_classes = len(label_dict)
    batch_size = 512
    num_workers = 4
    test_pipeline(test_data_dir, model_path, mapping_dict, label_dict, num_classes, batch_size, num_workers)

