# Files
# deal-classification/offline_test.py
# 2026-02-27 11:39:58 +08:00
#
# 130 lines
# 5.1 KiB
# Python

from data_process import build_offline_dataloader
from model import TransClassifier
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import os
import json
import gc
from tqdm import tqdm
import warnings
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
from transformers import AutoTokenizer
warnings.filterwarnings("ignore")
def _compute_metrics(all_labels, all_preds, all_probs):
    """Build the JSON-serialisable metrics report for the binary classifier.

    Args:
        all_labels: Ground-truth class ids, one per sample.
        all_preds: Predicted class ids, aligned with ``all_labels``.
        all_probs: Predicted probability of class 1, aligned with
            ``all_labels``.

    Returns:
        A dict holding accuracy, weighted precision/recall/F1, AUC, the
        confusion matrix and per-class precision/recall/F1. ``"auc"`` is
        ``None`` when fewer than two classes occur in ``all_labels``,
        because ROC AUC is undefined in that case.
    """
    # Pin the class set so the per-class arrays and the confusion matrix
    # always have an entry for both classes, even if one of them never
    # shows up in this particular evaluation set.
    class_ids = [0, 1]

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
    recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
    f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)

    # Depending on the scikit-learn version, single-class ground truth
    # either raises or yields NaN (which is not valid JSON); report null.
    if np.unique(np.asarray(all_labels)).size > 1:
        auc = float(roc_auc_score(all_labels, all_probs))
    else:
        auc = None

    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    precision_per_class = precision_score(
        all_labels, all_preds, labels=class_ids, average=None, zero_division=0
    )
    recall_per_class = recall_score(
        all_labels, all_preds, labels=class_ids, average=None, zero_division=0
    )
    f1_per_class = f1_score(
        all_labels, all_preds, labels=class_ids, average=None, zero_division=0
    )

    return {
        "accuracy": float(accuracy),
        "precision_weighted": float(precision),
        "recall_weighted": float(recall),
        "f1_weighted": float(f1),
        "auc": auc,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(precision_per_class[0]),
        "class_0_recall": float(recall_per_class[0]),
        "class_0_f1": float(f1_per_class[0]),
        "class_1_precision": float(precision_per_class[1]),
        "class_1_recall": float(recall_per_class[1]),
        "class_1_f1": float(f1_per_class[1]),
    }


def offline_test(
    deal_data_folder, not_deal_data_folder, batch_size, threshold,
    backbone_dir, ckpt_path, device, filtered=False
):
    """Evaluate a checkpoint on the offline data and write the results to disk.

    Builds the offline dataloader, restores the 2-class ``TransClassifier``
    from ``ckpt_path``, runs inference over every batch, then writes a
    metrics JSON file and a per-sample predictions CSV into the current
    working directory.

    Args:
        deal_data_folder: Forwarded to ``build_offline_dataloader``.
        not_deal_data_folder: Forwarded to ``build_offline_dataloader``.
        batch_size: Forwarded to ``build_offline_dataloader``.
        threshold: Forwarded to ``build_offline_dataloader`` and embedded
            in the output file names.
        backbone_dir: Location of the pretrained backbone; used for both
            the tokenizer and the classifier.
        ckpt_path: Path of the saved state_dict to load into the model.
        device: Device the checkpoint is mapped to and the tokenized
            inputs are moved to.
        filtered: When true, output file names carry a ``filtered_`` tag.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    offline_loader = build_offline_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold)
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, 2, device)
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs, so a tampered checkpoint cannot run
    # arbitrary code on load.
    state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    all_ids = []
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.inference_mode():
        for ids, texts, labels in tqdm(offline_loader, desc="Testing"):
            all_ids.extend(ids)
            # Labels are only needed on the host for metric computation, so
            # they are never copied to the device. Store plain Python values
            # to match the predictions.
            all_labels.extend(labels.cpu().numpy().tolist())

            texts = tokenizer.apply_chat_template(
                texts, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
            inputs = tokenizer(
                texts, padding=True, truncation=True, max_length=2048, return_tensors="pt"
            ).to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)

            # outputs is treated as [B, 2] class scores.
            all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            # Cast to float32 before softmax to avoid bfloat16 precision
            # loss; keep only the probability of class 1.
            class_1_probs = torch.softmax(outputs.float(), dim=1)[:, 1]
            all_probs.extend(class_1_probs.cpu().tolist())

            # Free per-batch memory before the next batch.
            del texts, labels, inputs, outputs, class_1_probs
            torch.cuda.empty_cache()
            gc.collect()

    if not all_labels:
        raise ValueError("offline dataloader produced no samples; nothing to evaluate")

    # Compute the evaluation metrics.
    test_results = _compute_metrics(all_labels, all_preds, all_probs)

    # Both output files share the same name tag.
    tag = f"filtered_{threshold}" if filtered else f"{threshold}"

    with open(f"offline_test_result_{tag}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)

    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels,
    })
    pred_df.to_csv(f"offline_test_predictions_{tag}.csv", index=False, encoding="utf-8")
if __name__ == "__main__":
    # Script entry point: evaluate one checkpoint on the raw and the
    # filtered offline folders.
    #
    # An earlier variant, kept here for reference, swept thresholds 3..8:
    #   for i in range(3, 9):
    #       threshold = i
    #       ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
    #       ... then ran the same two offline_test calls as below.
    deal_folder = "not_trained_deal"
    not_deal_folder = "not_trained_not_deal"
    filtered_deal_folder = "not_trained_deal_filtered"
    filtered_not_deal_folder = "not_trained_not_deal_filtered"

    backbone_dir = "Qwen3-1.7B"
    batch_size = 8
    threshold = 5
    ckpt_path = f"best_ckpt_threshold_{threshold}_1st.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # (deal folder, not-deal folder, filtered flag): raw data first, then
    # the filtered data, matching the original call order.
    runs = (
        (deal_folder, not_deal_folder, False),
        (filtered_deal_folder, filtered_not_deal_folder, True),
    )
    for deal_dir, not_deal_dir, is_filtered in runs:
        offline_test(
            deal_dir,
            not_deal_dir,
            batch_size,
            threshold,
            backbone_dir,
            ckpt_path,
            device,
            filtered=is_filtered,
        )