Upload files to "/"

This commit is contained in:
2026-01-29 19:03:32 +08:00
parent 9801ee14b0
commit e1cdb9d19e
4 changed files with 518 additions and 0 deletions

154
test.py Normal file
View File

@@ -0,0 +1,154 @@
from data_process import build_dataloader
from model import TransClassifier
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
import pandas as pd
import numpy as np
import os
import json
from datetime import datetime
import gc
from tqdm import tqdm
import warnings
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
# Suppress every Python warning for the rest of the process (no category or
# module filter is given, so this applies to all warnings, not just one library).
warnings.filterwarnings("ignore")
def test(backbone_dir, deal_folder, not_deal_folder, batch_size, ckpt_path="best_ckpt.pth", device="cuda"):
    """Evaluate the classifier on the test split and save the results.

    Builds the data loaders, loads the tokenizer and ``TransClassifier``,
    restores weights from ``ckpt_path`` when that file exists, runs inference
    over ``data_dict["test"]``, prints the metrics, and writes
    ``test_predictions.csv`` and ``test_results.json`` to the current working
    directory.

    Args:
        backbone_dir: Directory of the pretrained backbone (used for both the
            tokenizer and the model).
        deal_folder: Folder containing the "deal" samples.
        not_deal_folder: Folder containing the "not deal" samples.
        batch_size: Batch size passed to ``build_dataloader``.
        ckpt_path: Path of the trained checkpoint (a state dict). When the
            file is missing, a warning is printed and the model is evaluated
            with whatever weights ``TransClassifier`` starts with.
        device: Device to run on, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.

    Returns:
        dict: Overall and per-class metrics, the confusion matrix (as nested
        lists), the number of test samples, and a timestamp.

    Raises:
        ValueError: If the test loader yields no samples.
    """
    # Load the test data.
    data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
    test_loader = data_dict["test"]
    print(f"Test data loaded successfully. Test samples: {len(test_loader.dataset)}")

    # Load the tokenizer and the model.
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, device)
    model.to(device)

    # Restore the trained weights.
    if os.path.exists(ckpt_path):
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from a trusted source. If the checkpoint is a plain state dict and the
        # installed torch supports it, consider passing weights_only=True.
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")

    # Run inference.
    model.eval()
    all_ids = []
    all_preds = []
    all_labels = []
    test_loss = 0.0
    loss_func = nn.CrossEntropyLoss()

    # Autocast has to target the device type actually in use; deriving it from
    # `device` keeps "cpu" and indexed devices such as "cuda:1" working.
    autocast_device = torch.device(device).type

    with torch.inference_mode():
        for ids, texts, labels in tqdm(test_loader, desc="Testing"):
            all_ids.extend(ids)
            labels = labels.to(device)

            # Render the chat template to plain text, then tokenize the batch.
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)

            with torch.amp.autocast(device_type=autocast_device, dtype=torch.bfloat16):
                outputs = model(inputs)
                loss = loss_func(outputs, labels)
            test_loss += loss.item()

            # Predicted class = index of the largest logit.
            all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

            # Free per-batch memory before the next batch.
            del texts, labels, outputs, loss
            if autocast_device == "cuda":
                torch.cuda.empty_cache()
            gc.collect()

    if not all_labels:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError("Test loader produced no samples; nothing to evaluate.")

    # Overall metrics (loss is the mean of the per-batch losses).
    avg_loss = test_loss / len(test_loader)
    accuracy = float(accuracy_score(all_labels, all_preds))
    precision = float(precision_score(all_labels, all_preds, average="weighted"))
    recall = float(recall_score(all_labels, all_preds, average="weighted"))
    f1 = float(f1_score(all_labels, all_preds, average="weighted"))

    # Per-class metrics are computed once with the label set pinned to the two
    # classes reported below, so indexing [0] / [1] stays valid (and the
    # confusion matrix stays 2x2) even when one class never appears.
    class_labels = [0, 1]
    cm = confusion_matrix(all_labels, all_preds, labels=class_labels)
    class_precision = precision_score(all_labels, all_preds, labels=class_labels, average=None, zero_division=0)
    class_recall = recall_score(all_labels, all_preds, labels=class_labels, average=None, zero_division=0)
    class_f1 = f1_score(all_labels, all_preds, labels=class_labels, average=None, zero_division=0)

    # Print the evaluation summary.
    print("\n=== Test Results ===")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print("\nConfusion Matrix:")
    print(cm)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f" Precision: {class_precision[0]:.4f}")
    print(f" Recall: {class_recall[0]:.4f}")
    print(f" F1 Score: {class_f1[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f" Precision: {class_precision[1]:.4f}")
    print(f" Recall: {class_recall[1]:.4f}")
    print(f" F1 Score: {class_f1[1]:.4f}")

    # Collect the metrics; values are plain Python floats so json can dump them.
    test_results = {
        "average_loss": avg_loss,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(class_precision[0]),
        "class_0_recall": float(class_recall[0]),
        "class_0_f1": float(class_f1[0]),
        "class_1_precision": float(class_precision[1]),
        "class_1_recall": float(class_recall[1]),
        "class_1_f1": float(class_f1[1]),
        "test_samples": len(all_labels),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }

    # Save the per-sample predictions.
    pred_results = {
        "ids": all_ids,
        "predictions": all_preds,
        "true_labels": all_labels,
    }
    pred_df = pd.DataFrame(pred_results)
    pred_df.to_csv("test_predictions.csv", index=False, encoding="utf-8")

    # Save the metrics as JSON.
    with open("test_results.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=2)
    print("\nTest results saved to test_results.json")
    return test_results
if __name__ == "__main__":
    # Settings for a local evaluation run, gathered in one mapping and passed
    # to test() by keyword.
    run_config = {
        "backbone_dir": r"C:\Users\GA\Desktop\models\Qwen3-1.7B",
        "deal_folder": "deal",
        "not_deal_folder": "not_deal",
        "batch_size": 8,
        "ckpt_path": "best_ckpt.pth",
        "device": "cuda",
    }

    # Run the evaluation.
    test(**run_config)