from data_process import build_offline_dataloader
from model import TransClassifier
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import os
import json
import gc
import contextlib
from tqdm import tqdm
import warnings
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    roc_auc_score,
)
from transformers import AutoTokenizer

warnings.filterwarnings("ignore")

# The classifier is built with 2 output classes; pin the label set so per-class
# metrics and the confusion matrix always have one entry per class, even when a
# class is absent from both the labels and the predictions.
_CLASS_LABELS = [0, 1]


def _autocast_context(device):
    """Return a bfloat16 autocast context on CUDA devices, a no-op context otherwise."""
    # Requesting device_type="cuda" on a CPU-only machine merely warns and disables
    # autocast; choose explicitly instead of relying on that (silenced) fallback.
    if torch.device(device).type == "cuda":
        return torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()


def _collect_predictions(offline_loader, tokenizer, model, device):
    """Run the model over every batch; return (ids, predictions, class-1 probabilities, labels)."""
    all_ids = []
    all_preds = []
    all_probs = []
    all_labels = []
    with torch.inference_mode():
        for ids, texts, labels in tqdm(offline_loader, desc="Testing"):
            all_ids.extend(ids)
            prompts = tokenizer.apply_chat_template(
                texts, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
            inputs = tokenizer(
                prompts, padding=True, truncation=True, max_length=2048, return_tensors="pt"
            ).to(device)
            with _autocast_context(device):
                outputs = model(inputs)

            all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            # Softmax in float32: bfloat16 logits would lose precision in the probabilities.
            probs = torch.softmax(outputs.float(), dim=1)
            # Keep only the probability of class 1 (the positive class for ROC AUC).
            all_probs.extend(probs[:, 1].cpu().tolist())
            # Labels are only recorded, never fed to the model, so no device transfer is needed.
            all_labels.extend(labels.tolist())

            # Release per-batch tensors and cached CUDA blocks after every batch to limit
            # memory growth; this trades some throughput for a smaller footprint.
            del prompts, inputs, outputs, probs
            torch.cuda.empty_cache()
            gc.collect()
    return all_ids, all_preds, all_probs, all_labels


def _compute_metrics(all_labels, all_preds, all_probs):
    """Return the aggregate and per-class metrics dict that is written to the result JSON."""
    precision_per_class = precision_score(all_labels, all_preds, labels=_CLASS_LABELS, average=None)
    recall_per_class = recall_score(all_labels, all_preds, labels=_CLASS_LABELS, average=None)
    f1_per_class = f1_score(all_labels, all_preds, labels=_CLASS_LABELS, average=None)
    # ROC AUC is undefined when the true labels contain a single class; store null
    # in the JSON instead of aborting after the whole (expensive) inference pass.
    if len(set(all_labels)) > 1:
        auc = float(roc_auc_score(all_labels, all_probs))
    else:
        auc = None
    return {
        "accuracy": float(accuracy_score(all_labels, all_preds)),
        "precision_weighted": float(precision_score(all_labels, all_preds, average="weighted")),
        "recall_weighted": float(recall_score(all_labels, all_preds, average="weighted")),
        "f1_weighted": float(f1_score(all_labels, all_preds, average="weighted")),
        "auc": auc,
        "confusion_matrix": confusion_matrix(all_labels, all_preds, labels=_CLASS_LABELS).tolist(),
        "class_0_precision": float(precision_per_class[0]),
        "class_0_recall": float(recall_per_class[0]),
        "class_0_f1": float(f1_per_class[0]),
        "class_1_precision": float(precision_per_class[1]),
        "class_1_recall": float(recall_per_class[1]),
        "class_1_f1": float(f1_per_class[1]),
    }


def _save_results(test_results, all_ids, all_preds, all_probs, all_labels, threshold, filtered):
    """Write the metrics JSON and the per-sample prediction CSV to the working directory."""
    suffix = f"filtered_{threshold}" if filtered else f"{threshold}"
    with open(f"offline_test_result_{suffix}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)

    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels,
    })
    pred_df.to_csv(f"offline_test_predictions_{suffix}.csv", index=False, encoding="utf-8")


def offline_test(
    deal_data_folder,
    not_deal_data_folder,
    batch_size,
    threshold,
    backbone_dir,
    ckpt_path,
    device,
    filtered=False
):
    """Evaluate a TransClassifier checkpoint on an offline test set and save the results.

    Args:
        deal_data_folder: First data folder forwarded to ``build_offline_dataloader``.
        not_deal_data_folder: Second data folder forwarded to ``build_offline_dataloader``.
        batch_size: Batch size for the offline dataloader.
        threshold: Forwarded to ``build_offline_dataloader`` and embedded in output file names.
        backbone_dir: Path of the pretrained backbone; also used to load the tokenizer.
        ckpt_path: Path of the ``state_dict`` checkpoint to evaluate.
        device: Torch device for the model and inputs (also the checkpoint ``map_location``).
        filtered: When True, output file names get a ``filtered_`` tag.

    Writes ``offline_test_result[_filtered]_{threshold}.json`` (metrics) and
    ``offline_test_predictions[_filtered]_{threshold}.csv`` (per-sample predictions)
    to the current working directory. Returns None.
    """
    offline_loader = build_offline_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold)
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, 2, device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    all_ids, all_preds, all_probs, all_labels = _collect_predictions(offline_loader, tokenizer, model, device)
    test_results = _compute_metrics(all_labels, all_preds, all_probs)
    _save_results(test_results, all_ids, all_preds, all_probs, all_labels, threshold, filtered)


if __name__ == "__main__":
    filtered_deal_folder = "not_trained_deal_filtered"
    filtered_not_deal_folder = "not_trained_not_deal_filtered"
    deal_folder = "not_trained_deal"
    not_deal_folder = "not_trained_not_deal"
    batch_size = 8
    backbone_dir = "Qwen3-1.7B"

    # Sweep over every threshold checkpoint (disabled):
    # for i in range(3, 9):
    #     threshold = i
    #     ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
    #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #     offline_test(deal_folder, not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=False)
    #     offline_test(filtered_deal_folder, filtered_not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=True)

    threshold = 5
    ckpt_path = f"best_ckpt_threshold_{threshold}_1st.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    offline_test(deal_folder, not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=False)
    offline_test(filtered_deal_folder, filtered_not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=True)