"""Evaluate a trained TransClassifier checkpoint on the held-out test split.

Loads the test dataloader, runs the model under bf16 autocast, and reports
aggregate and per-class classification metrics. Per-sample predictions are
written to ``test_predictions.csv`` and the metric summary to
``test_results.json``.

NOTE(review): reconstructed from a mangled deletion diff of ``test.py``;
original comments were in Chinese and have been translated to English.
"""
from data_process import build_dataloader
from model import TransClassifier

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

import pandas as pd
import numpy as np
import os
import json
from datetime import datetime
import gc
from tqdm import tqdm
import warnings
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
warnings.filterwarnings("ignore")


def test(backbone_dir, deal_folder, not_deal_folder, batch_size, ckpt_path="best_ckpt.pth", device="cuda"):
    """Evaluate the model on the test set.

    Args:
        backbone_dir: Directory of the pretrained backbone model.
        deal_folder: Folder containing "deal" (positive-class) samples.
        not_deal_folder: Folder containing "not deal" (negative-class) samples.
        batch_size: Batch size for the test dataloader.
        ckpt_path: Path to the trained model checkpoint.
        device: Device to run inference on.

    Returns:
        dict: Aggregate and per-class test metrics (also saved to
        ``test_results.json``).
    """
    # Load the test split.
    data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
    test_loader = data_dict["test"]
    print(f"Test data loaded successfully. Test samples: {len(test_loader.dataset)}")

    # Load tokenizer and model.
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, device)
    model.to(device)

    # Load trained weights; fall back to the untrained backbone with a warning.
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")

    # Inference loop.
    model.eval()
    all_ids = []
    all_preds = []
    all_labels = []
    test_loss = 0.0

    loss_func = nn.CrossEntropyLoss()

    pbar = tqdm(test_loader, desc="Testing")
    with torch.inference_mode():
        for batch_idx, (ids, texts, labels) in enumerate(pbar):
            all_ids.extend(ids)
            labels = labels.to(device)

            # Render chat-formatted prompts, then tokenize with truncation.
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                loss = loss_func(outputs, labels)

            test_loss += loss.item()

            # Collect hard predictions and ground-truth labels on CPU.
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

            # Free per-batch tensors to keep GPU memory bounded.
            del texts, labels, outputs, loss
            torch.cuda.empty_cache()
            gc.collect()

    # Compute each metric exactly once (the original recomputed the per-class
    # sklearn scores six times over for printing and saving).
    avg_loss = test_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    cm = confusion_matrix(all_labels, all_preds)
    cls_precision = precision_score(all_labels, all_preds, average=None)
    cls_recall = recall_score(all_labels, all_preds, average=None)
    cls_f1 = f1_score(all_labels, all_preds, average=None)

    # Print the evaluation summary.
    print("\n=== Test Results ===")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print("\nConfusion Matrix:")
    print(cm)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f"  Precision: {cls_precision[0]:.4f}")
    print(f"  Recall: {cls_recall[0]:.4f}")
    print(f"  F1 Score: {cls_f1[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f"  Precision: {cls_precision[1]:.4f}")
    print(f"  Recall: {cls_recall[1]:.4f}")
    print(f"  F1 Score: {cls_f1[1]:.4f}")

    # Build the results payload. Coerce every sklearn scalar to a native
    # Python float: sklearn returns numpy.float64, which json.dump rejects
    # with "Object of type float64 is not JSON serializable".
    test_results = {
        "average_loss": avg_loss,
        "accuracy": float(accuracy),
        "precision": float(precision),
        "recall": float(recall),
        "f1_score": float(f1),
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(cls_precision[0]),
        "class_0_recall": float(cls_recall[0]),
        "class_0_f1": float(cls_f1[0]),
        "class_1_precision": float(cls_precision[1]),
        "class_1_recall": float(cls_recall[1]),
        "class_1_f1": float(cls_f1[1]),
        "test_samples": len(all_labels),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    # Save per-sample predictions alongside true labels.
    pred_results = {
        "ids": all_ids,
        "predictions": all_preds,
        "true_labels": all_labels
    }
    pred_df = pd.DataFrame(pred_results)
    pred_df.to_csv("test_predictions.csv", index=False, encoding="utf-8")

    # Save the metric summary as JSON.
    with open("test_results.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=2)
    print("\nTest results saved to test_results.json")
    return test_results


if __name__ == "__main__":
    # Configuration.
    backbone_dir = r"C:\Users\GA\Desktop\models\Qwen3-1.7B"
    deal_folder = "deal"
    not_deal_folder = "not_deal"
    batch_size = 8
    ckpt_path = "best_ckpt.pth"
    device = "cuda"

    # Run evaluation.
    test(backbone_dir, deal_folder, not_deal_folder, batch_size, ckpt_path, device)