# Offline evaluation script: loads a trained TransClassifier checkpoint and
# scores held-out "deal" / "not deal" data, writing metrics and predictions.
# Standard library
import gc
import json
import os
import warnings

# Third-party
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from tqdm import tqdm
from transformers import AutoTokenizer

# Local modules
from data_process import build_offline_dataloader
from model import TransClassifier

# Silence library warnings during evaluation. NOTE(review): this is global
# and also hides potentially useful warnings.
warnings.filterwarnings("ignore")
def offline_test(
    deal_data_folder, not_deal_data_folder, batch_size, threshold,
    backbone_dir, ckpt_path, device, filtered=False
):
    """Run offline evaluation of the binary classifier and persist results.

    Args:
        deal_data_folder: folder holding positive ("deal") samples.
        not_deal_data_folder: folder holding negative ("not deal") samples.
        batch_size: evaluation batch size.
        threshold: data-selection threshold; also embedded in output file names.
        backbone_dir: directory of the pretrained backbone / tokenizer.
        ckpt_path: path to the trained classifier state_dict (.pth).
        device: torch device to run inference on.
        filtered: if True, tag the output files as the filtered split.

    Side effects:
        Writes ``offline_test_result[_filtered]_{threshold}.json`` (metrics)
        and ``offline_test_predictions[_filtered]_{threshold}.csv``
        (per-sample id / prediction / probability / label).
    """
    offline_loader = build_offline_dataloader(
        deal_data_folder, not_deal_data_folder, batch_size, threshold
    )

    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, 2, device)
    # weights_only=True blocks arbitrary-code execution from untrusted
    # pickles; a plain state_dict checkpoint loads fine with it.
    model.load_state_dict(
        torch.load(ckpt_path, map_location=device, weights_only=True)
    )
    model.eval()

    all_ids = []
    all_preds = []
    all_probs = []
    all_labels = []

    pbar = tqdm(offline_loader, desc="Testing")
    with torch.inference_mode():
        for ids, texts, labels in pbar:
            all_ids.extend(ids)
            labels = labels.to(device)

            # Assumes `texts` is a batch of chat-format conversations
            # accepted by apply_chat_template — TODO confirm against
            # build_offline_dataloader.
            texts = tokenizer.apply_chat_template(
                texts, tokenize=False, add_generation_prompt=True,
                enable_thinking=False
            )
            inputs = tokenizer(
                texts, padding=True, truncation=True, max_length=2048,
                return_tensors="pt"
            ).to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)

            preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
            # Cast to float32 before softmax to avoid bf16 precision issues;
            # keep only P(class 1) for AUC / reporting.
            probs = torch.softmax(outputs.float(), dim=1)[:, 1].cpu().numpy().tolist()

            all_preds.extend(preds)
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy().tolist())

    # Free GPU memory once, after the loop. Calling empty_cache()/gc.collect()
    # on every batch (as before) synchronizes the GPU each iteration and
    # slows evaluation dramatically without reducing peak memory.
    del model
    torch.cuda.empty_cache()
    gc.collect()

    # Compute evaluation metrics.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    auc = roc_auc_score(all_labels, all_probs)
    cm = confusion_matrix(all_labels, all_preds)

    precision_per_class = precision_score(all_labels, all_preds, average=None)
    recall_per_class = recall_score(all_labels, all_preds, average=None)
    f1_per_class = f1_score(all_labels, all_preds, average=None)

    test_results = {
        "accuracy": accuracy,
        "precision_weighted": precision,
        "recall_weighted": recall,
        "f1_weighted": f1,
        "auc": auc,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": precision_per_class[0],
        "class_0_recall": recall_per_class[0],
        "class_0_f1": f1_per_class[0],
        "class_1_precision": precision_per_class[1],
        "class_1_recall": recall_per_class[1],
        "class_1_f1": f1_per_class[1]
    }

    # One suffix computation replaces the duplicated filtered/unfiltered
    # if/else branches; output file names are unchanged.
    suffix = f"filtered_{threshold}" if filtered else f"{threshold}"

    with open(f"offline_test_result_{suffix}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)

    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels
    })
    pred_df.to_csv(f"offline_test_predictions_{suffix}.csv",
                   index=False, encoding="utf-8")
if __name__ == "__main__":
    # Held-out splits: the "filtered" folders hold the same data after an
    # additional filtering pass — TODO confirm against data_process.
    filtered_deal_folder = "not_trained_deal_filtered"
    filtered_not_deal_folder = "not_trained_not_deal_filtered"

    deal_folder = "not_trained_deal"
    not_deal_folder = "not_trained_not_deal"

    batch_size = 8
    backbone_dir = "Qwen3-1.7B"

    # A disabled threshold sweep (for i in range(3, 9), one checkpoint per
    # threshold) used to sit here as a triple-quoted string; removed as dead
    # code — recover it from version control if the sweep is needed again.
    threshold = 5
    ckpt_path = f"best_ckpt_threshold_{threshold}_1st.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Evaluate both the unfiltered and the filtered held-out splits with the
    # same checkpoint; `filtered` only changes the output file names.
    offline_test(deal_folder, not_deal_folder, batch_size, threshold,
                 backbone_dir, ckpt_path, device, filtered=False)
    offline_test(filtered_deal_folder, filtered_not_deal_folder, batch_size, threshold,
                 backbone_dir, ckpt_path, device, filtered=True)