# Reconstructed from patch 39bd04f0 ("Add offline_test.py", WangZiFan, 2026-02-27);
# original was mangled onto a few physical lines.
"""Offline evaluation for the TransClassifier deal / not-deal model.

Runs a trained checkpoint over an offline dataset and writes:
  * offline_test_result[_filtered]_{threshold}.json  — aggregate metrics
  * offline_test_predictions[_filtered]_{threshold}.csv — per-sample predictions
"""
from data_process import build_offline_dataloader
from model import TransClassifier

import gc
import json
import warnings

import pandas as pd
import torch
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from tqdm import tqdm
from transformers import AutoTokenizer

warnings.filterwarnings("ignore")


def offline_test(
        deal_data_folder, not_deal_data_folder, batch_size, threshold,
        backbone_dir, ckpt_path, device, filtered=False
        ):
    """Evaluate a trained binary classifier checkpoint on the offline test set.

    Args:
        deal_data_folder: folder holding positive ("deal") samples.
        not_deal_data_folder: folder holding negative samples.
        batch_size: dataloader batch size.
        threshold: data-selection threshold; also embedded in output filenames.
        backbone_dir: HuggingFace model directory for tokenizer / backbone.
        ckpt_path: path to the trained state_dict (.pth).
        device: torch device used for inference.
        filtered: when True, output filenames carry a "filtered_" suffix.

    Side effects: writes a JSON metrics report and a predictions CSV to the
    current working directory (names depend on `filtered` and `threshold`).
    """
    offline_loader = build_offline_dataloader(
        deal_data_folder, not_deal_data_folder, batch_size, threshold)

    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, 2, device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    all_ids = []
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.inference_mode():
        for ids, texts, labels in tqdm(offline_loader, desc="Testing"):
            all_ids.extend(ids)

            # NOTE(review): assumes each element of `texts` is a chat-format
            # message list accepted by apply_chat_template — confirm against
            # build_offline_dataloader's output.
            texts = tokenizer.apply_chat_template(
                texts, tokenize=False, add_generation_prompt=True,
                enable_thinking=False)
            inputs = tokenizer(
                texts, padding=True, truncation=True, max_length=2048,
                return_tensors="pt").to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)  # logits, shape [B, 2]

            preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
            # Cast to float32 before softmax to avoid bf16 precision loss;
            # keep only the positive-class probability for AUC.
            probs = torch.softmax(outputs.float(), dim=1)[:, 1].cpu().numpy().tolist()

            all_preds.extend(preds)
            all_probs.extend(probs)
            # Labels never need the GPU round-trip the original did.
            all_labels.extend(labels.cpu().numpy())

            # Free the large per-batch tensors; keeps peak GPU memory low
            # on long evaluation runs.
            del texts, inputs, outputs
            torch.cuda.empty_cache()
            gc.collect()

    # Aggregate metrics.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    auc = roc_auc_score(all_labels, all_probs)
    cm = confusion_matrix(all_labels, all_preds)

    precision_per_class = precision_score(all_labels, all_preds, average=None)
    recall_per_class = recall_score(all_labels, all_preds, average=None)
    f1_per_class = f1_score(all_labels, all_preds, average=None)

    test_results = {
        "accuracy": accuracy,
        "precision_weighted": precision,
        "recall_weighted": recall,
        "f1_weighted": f1,
        "auc": auc,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": precision_per_class[0],
        "class_0_recall": recall_per_class[0],
        "class_0_f1": f1_per_class[0],
        "class_1_precision": precision_per_class[1],
        "class_1_recall": recall_per_class[1],
        "class_1_f1": f1_per_class[1]
    }

    # Single filename-suffix computation replaces the duplicated
    # if/else write branches of the original.
    suffix = f"filtered_{threshold}" if filtered else f"{threshold}"

    with open(f"offline_test_result_{suffix}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)

    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels
    })
    pred_df.to_csv(f"offline_test_predictions_{suffix}.csv",
                   index=False, encoding="utf-8")


if __name__ == "__main__":
    filtered_deal_folder = "not_trained_deal_filtered"
    filtered_not_deal_folder = "not_trained_not_deal_filtered"

    deal_folder = "not_trained_deal"
    not_deal_folder = "not_trained_not_deal"

    batch_size = 8
    backbone_dir = "Qwen3-1.7B"

    threshold = 5
    ckpt_path = f"best_ckpt_threshold_{threshold}_1st.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    offline_test(deal_folder, not_deal_folder, batch_size, threshold,
                 backbone_dir, ckpt_path, device, filtered=False)
    offline_test(filtered_deal_folder, filtered_not_deal_folder, batch_size,
                 threshold, backbone_dir, ckpt_path, device, filtered=True)