# deal-classification/train.py — training / evaluation script for the deal classifier
# (export metadata: 2026-02-27 11:40:28 +08:00, 323 lines, 12 KiB, Python)
from data_process import build_dataloader
from model import TransClassifier, FocalLoss
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
import pandas as pd
import numpy as np
import os
import json
from datetime import datetime
import gc
from tqdm import tqdm
import warnings
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
warnings.filterwarnings("ignore")
class EarlyStopping:
    """Stop training once validation loss has not improved for `patience` epochs.

    Tracks the best (lowest) validation loss seen so far; every time a new best
    is observed the model weights are checkpointed to disk.  `delta` is the
    minimum improvement required to reset the patience counter.
    """

    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0          # epochs since last improvement
        self.best_score = None    # negated best val loss (higher is better)
        self.early_stop = False   # flag polled by the training loop
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model, best_ckpt_path):
        current = -val_loss  # negate so "bigger is better"
        if self.best_score is None:
            # First call: adopt the score and checkpoint unconditionally.
            self.best_score = current
            self.save_checkpoint(val_loss, model, best_ckpt_path)
            return
        if current < self.best_score + self.delta:
            # No sufficient improvement: burn one unit of patience.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: record it, checkpoint, and reset patience.
            self.best_score = current
            self.save_checkpoint(val_loss, model, best_ckpt_path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, best_ckpt_path):
        """Persist model weights and remember the new best validation loss."""
        print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), best_ckpt_path)
        self.val_loss_min = val_loss
def train(backbone_dir, deal_folder, not_deal_folder,
          batch_size, initial_lr=1e-5, max_epochs=100, threshold: int = 10,
          device="cuda", use_focal_loss=False, balance=True):
    """Fine-tune a TransClassifier on deal / not-deal chat data.

    Args:
        backbone_dir: path of the pretrained backbone + tokenizer.
        deal_folder, not_deal_folder: data folders forwarded to build_dataloader.
        batch_size: training batch size.
        initial_lr: AdamW learning rate.
        max_epochs: upper bound on epochs; early stopping may end sooner.
        threshold: filtering threshold forwarded to build_dataloader; also used
            to name the checkpoint and history files.
        device: torch device string.
        use_focal_loss: if True use a 1-logit head with FocalLoss, otherwise a
            2-class head with CrossEntropyLoss.
        balance: whether build_dataloader balances the two classes.

    Returns:
        The test DataLoader, so the caller can evaluate the saved checkpoint.
    """
    best_ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    if use_focal_loss:
        # Single-logit head; focal loss handles the binary decision.
        model = TransClassifier(backbone_dir, output_classes=1, device=device)
        if balance:
            loss_func = FocalLoss(
                gamma=2.0,
                alpha=0.5,
                reduction='mean',
                task_type='binary')
        else:
            # Unbalanced data: shift alpha to up-weight the rarer class.
            loss_func = FocalLoss(
                gamma=2.0,
                alpha=0.8,
                reduction='mean',
                task_type='binary')
    else:
        model = TransClassifier(backbone_dir, output_classes=2, device=device)
        loss_func = nn.CrossEntropyLoss()
        # CE here has no class weighting, so the loader must balance classes.
        # (Fixed: the original message said "When not using CE loss", but this
        # branch is exactly the CE-loss path.)
        assert balance == True, "When using CE loss, balance must be True."
    model.to(device)
    data_dict = build_dataloader(deal_data_folder=deal_folder, not_deal_data_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, balance=balance)
    train_loader = data_dict["train"]
    val_loader = data_dict["val"]
    test_loader = data_dict["test"]
    optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    early_stopping = EarlyStopping(patience=10, delta=0)
    history = {"train_loss": [], "val_loss": [], "epoch": []}
    for epoch in range(max_epochs):
        model.train()
        total_loss = 0.0
        train_steps = 0
        if epoch == 2:
            # Assumes TransClassifier freezes the backbone at construction
            # time — TODO confirm against model.py.
            for param in model.backbone.parameters():
                param.requires_grad = True
            print("Unfreeze backbone parameters")
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{max_epochs} [Train]')
        for batch_idx, (ids, texts, labels) in enumerate(pbar):
            labels = labels.to(device)
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                if use_focal_loss:
                    # [B, 1] logits -> [B]; focal loss expects float targets.
                    outputs = outputs.squeeze(1)
                    loss = loss_func(outputs, labels.float())
                else:
                    loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # NOTE(review): stepping a CosineAnnealingLR(T_max=10) per *batch*
            # completes a cosine half-cycle every 10 batches; confirm this was
            # not meant to be a per-epoch step.
            scheduler.step()
            total_loss += loss.item()
            train_steps += 1
            train_loss = total_loss / train_steps
            pbar.set_postfix({"train_loss": train_loss})
            # Free batch tensors before the next iteration (also drop `inputs`,
            # which the original leaked until the next rebinding).
            del inputs, texts, labels, outputs, loss
            torch.cuda.empty_cache()
            gc.collect()
        val_loss = val(val_loader, model, loss_func, tokenizer, device, use_focal_loss)
        train_loss_epoch = total_loss / len(train_loader)
        history["train_loss"].append(train_loss_epoch)
        history["val_loss"].append(val_loss)
        history["epoch"].append(epoch+1)
        print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}")
        early_stopping(val_loss, model, best_ckpt_path)
        if early_stopping.early_stop:
            print("Early stopping")
            break
    history_df = pd.DataFrame(history)
    history_df.to_csv(f"training_history_threshold_{threshold}.csv", index=False)
    print(f"Training history saved to training_history_threshold_{threshold}.csv")
    return test_loader
def val(val_loader, model, loss_func, tokenizer, device, use_focal_loss=False):
    """Return the mean validation loss of `model` over `val_loader`.

    Mirrors the training forward pass (chat-template rendering, tokenization,
    bf16 autocast) but with gradients disabled.
    """
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for batch_idx, (ids, texts, labels) in enumerate(val_loader):
            labels = labels.to(device)
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                if use_focal_loss:
                    # [B, 1] logits -> [B]; focal loss takes float targets.
                    loss = loss_func(outputs.squeeze(1), labels.float())
                else:
                    loss = loss_func(outputs, labels)
            running_loss += loss.item()
            # Release per-batch tensors eagerly to cap GPU memory.
            del inputs, outputs, labels, loss
            gc.collect()
            torch.cuda.empty_cache()
    return running_loss / len(val_loader)
def test(backbone_dir, test_loader, device, threshold, use_focal_loss=False, balance=True):
    """Evaluate the best checkpoint for `threshold` on `test_loader`.

    Loads best_ckpt_threshold_{threshold}.pth (warns and uses untrained
    weights if missing), computes accuracy / precision / recall / F1 / AUC and
    per-class metrics, prints them, and writes:
      - test_results_threshold_{threshold}.json  (metric summary)
      - test_predictions_threshold_{threshold}.csv  (per-sample predictions)

    Args:
        backbone_dir: path of the pretrained backbone + tokenizer.
        test_loader: DataLoader yielding (ids, texts, labels) batches.
        device: torch device string.
        threshold: threshold used to locate checkpoint / name output files.
        use_focal_loss: must match the setting used at training time (it
            determines the classifier head size).
        balance: unused here; kept for signature symmetry with train().
    """
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    if use_focal_loss:
        model = TransClassifier(backbone_dir, output_classes=1, device=device)
    else:
        model = TransClassifier(backbone_dir, output_classes=2, device=device)
    model.to(device)
    ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")
    model.eval()
    all_ids = []
    all_preds = []
    all_probs = []
    all_labels = []
    pbar = tqdm(test_loader, desc="Testing")
    with torch.inference_mode():
        for batch_idx, (ids, texts, labels) in enumerate(pbar):
            all_ids.extend(ids)
            labels = labels.to(device)
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                if use_focal_loss:
                    outputs = outputs.squeeze(-1)  # [B, 1] -> [B]
            if use_focal_loss:
                # Cast to float32 before sigmoid to avoid bf16 precision loss.
                outputs_float = outputs.float()
                probs = torch.sigmoid(outputs_float).cpu().numpy().tolist()  # [B]
                preds = [1 if p >= 0.5 else 0 for p in probs]
            else:
                preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
                # Cast to float32 before softmax to avoid bf16 precision loss.
                outputs_float = outputs.float()
                probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
                probs = probs.cpu().numpy().tolist()
                probs = [p[1] for p in probs]  # keep P(class 1)
            all_preds.extend(preds)
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy())
            # Free per-batch tensors (including `inputs`, previously leaked).
            del inputs, texts, labels, outputs
            torch.cuda.empty_cache()
            gc.collect()
    # Aggregate evaluation metrics.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    auc = roc_auc_score(all_labels, all_probs)
    cm = confusion_matrix(all_labels, all_preds)
    # Print evaluation results.
    print("\n=== Test Results ===")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"AUC: {auc:.4f}")
    cm_df = pd.DataFrame(cm,
                         index=['Actual Not Deal (0)', 'Actual Deal (1)'],
                         columns=['Predicted Not Deal (0)', 'Predicted Deal (1)'])
    print("\nConfusion Matrix:")
    print(cm_df)
    precision_per_class = precision_score(all_labels, all_preds, average=None)
    recall_per_class = recall_score(all_labels, all_preds, average=None)
    f1_per_class = f1_score(all_labels, all_preds, average=None)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f"  Precision: {precision_per_class[0]:.4f}")
    print(f"  Recall: {recall_per_class[0]:.4f}")
    print(f"  F1 Score: {f1_per_class[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f"  Precision: {precision_per_class[1]:.4f}")
    print(f"  Recall: {recall_per_class[1]:.4f}")
    print(f"  F1 Score: {f1_per_class[1]:.4f}")
    # Fixed: sklearn metrics return numpy.float64, which json.dump cannot
    # serialize ("Object of type float64 is not JSON serializable") — cast
    # every scalar to a plain Python float.
    test_results = {
        "accuracy": float(accuracy),
        "precision_weighted": float(precision),
        "recall_weighted": float(recall),
        "f1_weighted": float(f1),
        "auc": float(auc),
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(precision_per_class[0]),
        "class_0_recall": float(recall_per_class[0]),
        "class_0_f1": float(f1_per_class[0]),
        "class_1_precision": float(precision_per_class[1]),
        "class_1_recall": float(recall_per_class[1]),
        "class_1_f1": float(f1_per_class[1]),
        "test_samples": len(all_labels)
    }
    with open(f"test_results_threshold_{threshold}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)
    print(f"\nTest results saved to test_results_threshold_{threshold}.json")
    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels
    })
    pred_df.to_csv(f"test_predictions_threshold_{threshold}.csv", index=False, encoding="utf-8")
if __name__ == "__main__":
    backbone_dir = "Qwen3-1.7B"
    deal_folder = "filtered_deal"
    not_deal_folder = "filtered_not_deal"
    batch_size = 4
    device = "cuda"
    # Sweep the data-filtering threshold; each run trains a fresh model and
    # evaluates its best checkpoint on that run's test split.
    # (Removed: a dead commented-out single-threshold run and an unused
    # `max_threshold` variable.)
    for i in range(3, 9):
        print(f"Training with threshold {i}...")
        test_loader = train(backbone_dir=backbone_dir, deal_folder=deal_folder, not_deal_folder=not_deal_folder, batch_size=batch_size, threshold=i, device=device, use_focal_loss=False, balance=True)
        test(
            backbone_dir=backbone_dir,
            test_loader=test_loader,
            device=device,
            threshold=i,
            use_focal_loss=False,
            balance=True
        )