Upload files to "/"
This commit is contained in:
154
test.py
Normal file
154
test.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from data_process import build_dataloader
|
||||
from model import TransClassifier
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
import gc
|
||||
from tqdm import tqdm
|
||||
import warnings
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
def test(backbone_dir, deal_folder, not_deal_folder, batch_size, ckpt_path="best_ckpt.pth", device="cuda"):
    """Evaluate the trained classifier on the test split.

    Loads the test dataloader, restores model weights from *ckpt_path*
    (if present), runs bf16-autocast inference, prints aggregate and
    per-class metrics, and writes ``test_predictions.csv`` and
    ``test_results.json``.

    Args:
        backbone_dir: Directory of the pretrained backbone model.
        deal_folder: Folder containing "deal" (class 1) samples.
        not_deal_folder: Folder containing "not deal" (class 0) samples.
        batch_size: Batch size for the test dataloader.
        ckpt_path: Path to the trained model checkpoint.
        device: Device to run inference on, e.g. "cuda".

    Returns:
        dict: Metrics summary (the same dict saved to ``test_results.json``).
    """
    # Load test data
    data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
    test_loader = data_dict["test"]
    print(f"Test data loaded successfully. Test samples: {len(test_loader.dataset)}")

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, device)
    model.to(device)

    # Restore trained weights if the checkpoint exists; otherwise warn
    # and evaluate the untrained model.
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")

    # Inference loop
    model.eval()
    all_ids = []
    all_preds = []
    all_labels = []
    test_loss = 0.0

    loss_func = nn.CrossEntropyLoss()

    pbar = tqdm(test_loader, desc="Testing")
    with torch.inference_mode():
        for batch_idx, (ids, texts, labels) in enumerate(pbar):
            all_ids.extend(ids)
            labels = labels.to(device)

            # Render chat template, then tokenize the batch of texts.
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                loss = loss_func(outputs, labels)

            test_loss += loss.item()

            # Predicted class = argmax over logits
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

            # Free per-batch tensors; per-batch empty_cache/gc is slow but
            # presumably deliberate to avoid OOM on a small GPU — kept as-is.
            del texts, labels, outputs, loss
            torch.cuda.empty_cache()
            gc.collect()

    # Aggregate metrics. Compute the per-class (average=None) arrays once
    # instead of re-running each sklearn scorer for every print/dict entry.
    avg_loss = test_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    cm = confusion_matrix(all_labels, all_preds)
    class_precision = precision_score(all_labels, all_preds, average=None)
    class_recall = recall_score(all_labels, all_preds, average=None)
    class_f1 = f1_score(all_labels, all_preds, average=None)

    # Print evaluation results
    print("\n=== Test Results ===")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print("\nConfusion Matrix:")
    print(cm)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f"  Precision: {class_precision[0]:.4f}")
    print(f"  Recall: {class_recall[0]:.4f}")
    print(f"  F1 Score: {class_f1[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f"  Precision: {class_precision[1]:.4f}")
    print(f"  Recall: {class_recall[1]:.4f}")
    print(f"  F1 Score: {class_f1[1]:.4f}")

    # Save test results. The sklearn scorers return numpy scalars
    # (np.float64), which json.dump cannot serialize — it raises
    # "Object of type float64 is not JSON serializable". Convert every
    # metric to a built-in float so the JSON write succeeds.
    test_results = {
        "average_loss": float(avg_loss),
        "accuracy": float(accuracy),
        "precision": float(precision),
        "recall": float(recall),
        "f1_score": float(f1),
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(class_precision[0]),
        "class_0_recall": float(class_recall[0]),
        "class_0_f1": float(class_f1[0]),
        "class_1_precision": float(class_precision[1]),
        "class_1_recall": float(class_recall[1]),
        "class_1_f1": float(class_f1[1]),
        "test_samples": len(all_labels),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    # Save per-sample predictions as CSV
    pred_results = {
        "ids": all_ids,
        "predictions": all_preds,
        "true_labels": all_labels
    }
    pred_df = pd.DataFrame(pred_results)
    pred_df.to_csv("test_predictions.csv", index=False, encoding="utf-8")

    # Save metrics summary as JSON
    with open("test_results.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=2)
    print("\nTest results saved to test_results.json")
    return test_results
|
||||
if __name__ == "__main__":
    # Local run configuration: model location, data folders, and
    # inference settings, passed to test() by keyword.
    config = {
        "backbone_dir": r"C:\Users\GA\Desktop\models\Qwen3-1.7B",
        "deal_folder": "deal",
        "not_deal_folder": "not_deal",
        "batch_size": 8,
        "ckpt_path": "best_ckpt.pth",
        "device": "cuda",
    }

    # Run the evaluation
    test(**config)
Reference in New Issue
Block a user