Update train.py

Commit b00247cfb7 (parent 39bd04f0e0)
Date: 2026-02-27 11:40:28 +08:00

train.py — 472 changed lines
View File
@@ -1,149 +1,323 @@
import gc
import json
import os
import warnings
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
from tqdm import tqdm
from transformers import AutoTokenizer

from data_process import build_dataloader
from model import TransClassifier, FocalLoss

warnings.filterwarnings("ignore")
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs.

    Tracks the best (lowest) validation loss seen so far; every call with a
    non-improving loss increments a counter, and once the counter reaches
    `patience` the `early_stop` flag is raised. The best model weights are
    checkpointed each time the loss improves.
    """

    def __init__(self, patience=5, delta=0):
        # patience: number of consecutive non-improving epochs tolerated.
        # delta: minimum decrease in val_loss required to count as improvement.
        self.patience = patience
        self.counter = 0
        self.best_score = None          # negated best val_loss (higher is better)
        self.early_stop = False
        self.val_loss_min = np.inf      # best raw val_loss, for logging only
        self.delta = delta

    def __call__(self, val_loss, model, best_ckpt_path):
        """Record one epoch's validation loss; checkpoint on improvement.

        Args:
            val_loss: validation loss for the current epoch.
            model: model whose state_dict is saved on improvement.
            best_ckpt_path: file path for the best-model checkpoint.
        """
        score = -val_loss  # negate so that "larger score" == "lower loss"
        if self.best_score is None:
            # First call: always treat as an improvement.
            self.best_score = score
            self.save_checkpoint(val_loss, model, best_ckpt_path)
        elif score < self.best_score + self.delta:
            # No (sufficient) improvement: count toward early stop.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: reset the patience counter and checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model, best_ckpt_path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, best_ckpt_path):
        """Persist model weights and update the tracked minimum loss."""
        print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), best_ckpt_path)
        self.val_loss_min = val_loss
def train(backbone_dir, deal_folder, not_deal_folder,
          batch_size, initial_lr=1e-5, max_epochs=100, threshold: int = 10,
          device="cuda", use_focal_loss=False, balance=True):
    """Train a TransClassifier on deal / not-deal text data.

    Args:
        backbone_dir: path to the pretrained backbone (also used for the tokenizer).
        deal_folder: folder with positive-class ("deal") data.
        not_deal_folder: folder with negative-class ("not deal") data.
        batch_size: dataloader batch size.
        initial_lr: AdamW learning rate.
        max_epochs: upper bound on training epochs (early stopping may end sooner).
        threshold: data-filtering threshold forwarded to build_dataloader; also
            used to name the checkpoint and history files.
        device: torch device string.
        use_focal_loss: if True, train a 1-logit head with FocalLoss (binary);
            otherwise a 2-logit head with CrossEntropyLoss.
        balance: whether build_dataloader balances the classes; also selects the
            FocalLoss alpha when use_focal_loss is True.

    Returns:
        The test dataloader, so the caller can evaluate the saved checkpoint.
    """
    best_ckpt_path = f"best_ckpt_threshold_{threshold}.pth"

    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    if use_focal_loss:
        # Single-logit head; FocalLoss alpha differs with class balancing.
        model = TransClassifier(backbone_dir, output_classes=1, device=device)
        if balance:
            loss_func = FocalLoss(
                gamma=2.0,
                alpha=0.5,
                reduction='mean',
                task_type='binary')
        else:
            # Unbalanced data: weight the positive class more heavily.
            loss_func = FocalLoss(
                gamma=2.0,
                alpha=0.8,
                reduction='mean',
                task_type='binary')
    else:
        # Two-logit head with plain cross-entropy.
        model = TransClassifier(backbone_dir, output_classes=2, device=device)
        loss_func = nn.CrossEntropyLoss()
        # NOTE(review): message text seems inverted relative to this branch
        # (it fires in the CE path) — confirm intended constraint.
        assert balance == True, "When not using CE loss, balance must be True."
    model.to(device)

    data_dict = build_dataloader(deal_data_folder=deal_folder, not_deal_data_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, balance=balance)
    train_loader = data_dict["train"]
    val_loader = data_dict["val"]
    test_loader = data_dict["test"]

    optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    early_stopping = EarlyStopping(patience=10, delta=0)
    history = {"train_loss": [], "val_loss": [], "epoch": []}

    for epoch in range(max_epochs):
        model.train()
        total_loss = 0.0
        train_steps = 0

        # After two warm-up epochs of head-only training, unfreeze the backbone.
        if epoch == 2:
            for param in model.backbone.parameters():
                param.requires_grad = True
            print("Unfreeze backbone parameters")

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{max_epochs} [Train]')
        for batch_idx, (ids, texts, labels) in enumerate(pbar):
            labels = labels.to(device)
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)

            # bf16 autocast for the forward pass / loss computation.
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                if use_focal_loss:
                    outputs = outputs.squeeze(1)  # [B, 1] -> [B] for binary loss
                    loss = loss_func(outputs, labels.float())
                else:
                    loss = loss_func(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # NOTE(review): stepping a CosineAnnealingLR(T_max=10) per *batch* is
            # unusual (typically per epoch) — confirm this is intentional.
            scheduler.step()

            total_loss += loss.item()
            train_steps += 1

            train_loss = total_loss / train_steps
            pbar.set_postfix({"train_loss": train_loss})
            # NOTE(review): empty_cache()/gc.collect() every batch is costly;
            # kept to preserve the original memory behavior.
            del texts, labels, outputs, loss
            torch.cuda.empty_cache()
            gc.collect()

        val_loss = val(val_loader, model, loss_func, tokenizer, device, use_focal_loss)
        train_loss_epoch = total_loss / len(train_loader)
        history["train_loss"].append(train_loss_epoch)
        history["val_loss"].append(val_loss)
        history["epoch"].append(epoch+1)

        print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}")

        early_stopping(val_loss, model, best_ckpt_path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    history_df = pd.DataFrame(history)
    history_df.to_csv(f"training_history_threshold_{threshold}.csv", index=False)
    print(f"Training history saved to training_history_threshold_{threshold}.csv")
    return test_loader
def val(val_loader, model, loss_func, tokenizer, device, use_focal_loss=False):
model.eval()
train(backbone_dir, deal_folder, not_deal_folder, batch_size, device=device) val_loss = 0.0
with torch.no_grad():
for batch_idx, (ids, texts, labels) in enumerate(val_loader):
labels = labels.to(device)
texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
outputs = model(inputs)
if use_focal_loss:
outputs = outputs.squeeze(1)
loss = loss_func(outputs, labels.float())
else:
loss = loss_func(outputs, labels)
val_loss += loss.item()
del inputs, outputs, labels, loss
gc.collect()
torch.cuda.empty_cache()
return val_loss / len(val_loader)
def test(backbone_dir, test_loader, device, threshold, use_focal_loss=False, balance=True):
    """Evaluate the best checkpoint for `threshold` on the test set.

    Loads best_ckpt_threshold_{threshold}.pth (if present), runs inference,
    prints/saves aggregate and per-class metrics, and writes per-sample
    predictions to CSV.

    Args:
        backbone_dir: path to the pretrained backbone / tokenizer.
        test_loader: iterable yielding (ids, texts, labels) batches.
        device: torch device string.
        threshold: threshold used during training (selects checkpoint/output names).
        use_focal_loss: must match the training head (1 logit vs 2 logits).
        balance: unused here; kept for signature symmetry with train().
    """
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    # Head size must match the training configuration.
    if use_focal_loss:
        model = TransClassifier(backbone_dir, output_classes=1, device=device)
    else:
        model = TransClassifier(backbone_dir, output_classes=2, device=device)
    model.to(device)

    ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")
    model.eval()

    all_ids = []
    all_preds = []
    all_probs = []   # probability of the positive ("deal") class
    all_labels = []
    pbar = tqdm(test_loader, desc="Testing")
    with torch.inference_mode():
        for batch_idx, (ids, texts, labels) in enumerate(pbar):
            all_ids.extend(ids)
            labels = labels.to(device)
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
            # (Merged the two duplicate `if use_focal_loss` checks into one.)
            if use_focal_loss:
                outputs = outputs.squeeze(-1)  # [B, 1] -> [B]
                outputs_float = outputs.float()  # cast to float32 to avoid bf16 precision issues
                probs = torch.sigmoid(outputs_float).cpu().numpy().tolist()  # [B]
                preds = [1 if p >= 0.5 else 0 for p in probs]
            else:
                preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
                outputs_float = outputs.float()  # cast to float32 to avoid bf16 precision issues
                probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
                probs = probs.cpu().numpy().tolist()
                probs = [p[1] for p in probs]  # keep P(class=1) only
            all_preds.extend(preds)
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy().tolist())
            # Free per-batch memory.
            del texts, labels, outputs
            torch.cuda.empty_cache()
            gc.collect()

    # Aggregate evaluation metrics.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    auc = roc_auc_score(all_labels, all_probs)
    cm = confusion_matrix(all_labels, all_preds)

    print("\n=== Test Results ===")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"AUC: {auc:.4f}")

    cm_df = pd.DataFrame(cm,
                         index=['Actual Not Deal (0)', 'Actual Deal (1)'],
                         columns=['Predicted Not Deal (0)', 'Predicted Deal (1)'])
    print("\nConfusion Matrix:")
    print(cm_df)

    precision_per_class = precision_score(all_labels, all_preds, average=None)
    recall_per_class = recall_score(all_labels, all_preds, average=None)
    f1_per_class = f1_score(all_labels, all_preds, average=None)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f"  Precision: {precision_per_class[0]:.4f}")
    print(f"  Recall: {recall_per_class[0]:.4f}")
    print(f"  F1 Score: {f1_per_class[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f"  Precision: {precision_per_class[1]:.4f}")
    print(f"  Recall: {recall_per_class[1]:.4f}")
    print(f"  F1 Score: {f1_per_class[1]:.4f}")

    # Cast numpy scalars to native floats: json.dump raises TypeError on
    # np.float64 values (sklearn metrics return numpy types).
    test_results = {
        "accuracy": float(accuracy),
        "precision_weighted": float(precision),
        "recall_weighted": float(recall),
        "f1_weighted": float(f1),
        "auc": float(auc),
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(precision_per_class[0]),
        "class_0_recall": float(recall_per_class[0]),
        "class_0_f1": float(f1_per_class[0]),
        "class_1_precision": float(precision_per_class[1]),
        "class_1_recall": float(recall_per_class[1]),
        "class_1_f1": float(f1_per_class[1]),
        "test_samples": len(all_labels)
    }
    with open(f"test_results_threshold_{threshold}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)
    print(f"\nTest results saved to test_results_threshold_{threshold}.json")

    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels
    })
    pred_df.to_csv(f"test_predictions_threshold_{threshold}.csv", index=False, encoding="utf-8")
if __name__ == "__main__":
    # Script entry point: train and evaluate one model per threshold value.
    backbone_dir = "Qwen3-1.7B"
    deal_folder = "filtered_deal"
    not_deal_folder = "filtered_not_deal"
    batch_size = 4
    device = "cuda"

    # (Removed a commented-out single-threshold run and an unused
    # `max_threshold` variable.)
    for i in range(3, 9):
        print(f"Training with threshold {i}...")
        test_loader = train(backbone_dir=backbone_dir, deal_folder=deal_folder, not_deal_folder=not_deal_folder, batch_size=batch_size, threshold=i, device=device, use_focal_loss=False, balance=True)
        test(
            backbone_dir=backbone_dir,
            test_loader=test_loader,
            device=device,
            threshold=i,
            use_focal_loss=False,
            balance=True
        )