Add offline_test.py
This commit is contained in:
130
offline_test.py
Normal file
130
offline_test.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from data_process import build_offline_dataloader
|
||||
from model import TransClassifier
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
import gc
|
||||
from tqdm import tqdm
|
||||
import warnings
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
|
||||
from transformers import AutoTokenizer
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def offline_test(
    deal_data_folder, not_deal_data_folder, batch_size, threshold,
    backbone_dir, ckpt_path, device, filtered=False
):
    """Evaluate a binary TransClassifier checkpoint on the offline dataset.

    Args:
        deal_data_folder: Folder holding the positive ("deal") samples.
        not_deal_data_folder: Folder holding the negative ("not deal") samples.
        batch_size: Batch size for the offline dataloader.
        threshold: Threshold passed to the dataloader builder; also embedded
            in the names of the output files.
        backbone_dir: Directory of the HF backbone (tokenizer + weights).
        ckpt_path: Path to the classifier state-dict checkpoint (.pth).
        device: ``torch.device`` (or device string) to run inference on.
        filtered: Only affects output file names — marks results produced
            on the filtered data folders.

    Side effects:
        Writes ``offline_test_result[_filtered]_{threshold}.json`` with the
        aggregate metrics and ``offline_test_predictions[_filtered]_{threshold}.csv``
        with per-sample predictions.
    """
    offline_loader = build_offline_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold)

    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, 2, device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    all_ids, all_preds, all_probs, all_labels = [], [], [], []

    # Fix: autocast device type was hard-coded to "cuda", which breaks when
    # running on CPU; derive it from the actual device instead.
    device_type = device.type if isinstance(device, torch.device) else (
        "cuda" if "cuda" in str(device) else "cpu"
    )

    pbar = tqdm(offline_loader, desc="Testing")
    with torch.inference_mode():
        for ids, texts, labels in pbar:
            all_ids.extend(ids)
            all_labels.extend(labels.cpu().numpy().tolist())

            texts = tokenizer.apply_chat_template(
                texts, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
            inputs = tokenizer(
                texts, padding=True, truncation=True, max_length=2048, return_tensors="pt"
            ).to(device)

            with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
                outputs = model(inputs)

            all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            # Cast to float32 before softmax to avoid bf16 precision issues,
            # then keep only the positive-class probability.
            all_probs.extend(torch.softmax(outputs.float(), dim=1)[:, 1].cpu().tolist())

            # Drop per-batch tensors (the original leaked `inputs`).
            del texts, labels, inputs, outputs

    # Reclaim memory once after the loop; doing this every batch (as before)
    # adds significant overhead without lowering peak usage.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Aggregate evaluation metrics.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    try:
        auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        # AUC is undefined when the labels contain only one class; record
        # null instead of crashing the whole evaluation run.
        auc = None
    cm = confusion_matrix(all_labels, all_preds)

    precision_per_class = precision_score(all_labels, all_preds, average=None)
    recall_per_class = recall_score(all_labels, all_preds, average=None)
    f1_per_class = f1_score(all_labels, all_preds, average=None)

    # Cast numpy scalars to plain floats so json.dump stays portable.
    test_results = {
        "accuracy": float(accuracy),
        "precision_weighted": float(precision),
        "recall_weighted": float(recall),
        "f1_weighted": float(f1),
        "auc": None if auc is None else float(auc),
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(precision_per_class[0]),
        "class_0_recall": float(recall_per_class[0]),
        "class_0_f1": float(f1_per_class[0]),
        "class_1_precision": float(precision_per_class[1]),
        "class_1_recall": float(recall_per_class[1]),
        "class_1_f1": float(f1_per_class[1]),
    }

    # Single place for the "_filtered" naming; the original duplicated the
    # if/else around both output files.
    suffix = f"_filtered_{threshold}" if filtered else f"_{threshold}"

    with open(f"offline_test_result{suffix}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)

    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels,
    })
    pred_df.to_csv(f"offline_test_predictions{suffix}.csv", index=False, encoding="utf-8")
if __name__ == "__main__":
    # Folders with held-out (not-yet-trained) samples; the "_filtered"
    # variants are the cleaned subsets of the same data.
    filtered_deal_folder = "not_trained_deal_filtered"
    filtered_not_deal_folder = "not_trained_not_deal_filtered"

    deal_folder = "not_trained_deal"
    not_deal_folder = "not_trained_not_deal"

    batch_size = 8
    backbone_dir = "Qwen3-1.7B"

    # NOTE(review): a dead, commented-out threshold sweep (range(3, 9)) that
    # lived here as a no-op string literal has been removed.
    threshold = 5
    ckpt_path = f"best_ckpt_threshold_{threshold}_1st.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Evaluate the same checkpoint on both the raw and the filtered splits;
    # `filtered` only switches the output file names.
    offline_test(deal_folder, not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=False)
    offline_test(filtered_deal_folder, filtered_not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=True)
Reference in New Issue
Block a user