Update train.py

This commit is contained in:
2026-02-27 11:40:28 +08:00
parent 39bd04f0e0
commit b00247cfb7

238
train.py
View File

@@ -1,5 +1,5 @@
from data_process import build_dataloader
from model import TransClassifier
from model import TransClassifier, FocalLoss
import torch
import torch.nn as nn
@@ -14,24 +14,25 @@ from datetime import datetime
import gc
from tqdm import tqdm
import warnings
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
warnings.filterwarnings("ignore")
class EarlyStopping:
    """Stop training once validation loss stops improving for `patience` epochs.

    Every time the validation loss reaches a new minimum (beyond `delta`
    tolerance), the model's state_dict is checkpointed to `best_ckpt_path`.
    Reconstructed from a whitespace-stripped diff view of this commit.
    """

    def __init__(self, patience=5, delta=0):
        # patience: number of non-improving epochs tolerated before stopping.
        # delta: minimum decrease in val_loss that counts as an improvement.
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, best_ckpt_path):
        """Record one epoch's val_loss: checkpoint on improvement, else count down."""
        score = -val_loss  # negate so that "higher is better"
        if self.best_score is None:
            # First epoch: always checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model, best_ckpt_path)
        elif score < self.best_score + self.delta:
            # No improvement this epoch.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            # NOTE(review): this guard was hidden in the diff hunk gap and is
            # reconstructed here — it is the standard pattern required for
            # early_stop to be set only after `patience` bad epochs.
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: update the best score, checkpoint, reset the counter.
            self.best_score = score
            self.save_checkpoint(val_loss, model, best_ckpt_path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, best_ckpt_path):
        """Persist model weights and update the best-loss watermark."""
        print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), best_ckpt_path)
        self.val_loss_min = val_loss
def train(backbone_dir, deal_folder, not_deal_folder,
# NOTE(review): this region is a git diff view with the +/- markers stripped —
# old and new revisions of the signature/body are interleaved and indentation is
# lost. Code lines are kept byte-identical; only comments were added.
# -- old signature tail (removed by this commit):
batch_size, initial_lr=1e-5, max_epochs=100,
best_ckpt_path="best_ckpt.pth", final_ckpt_path="final_ckpt.pth", device="cuda"):
# -- old dataloader setup (removed; rebuilt below with threshold/balance):
data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
train_loader = data_dict["train"]
val_loader = data_dict["val"]
# -- new signature tail: adds `threshold` (dataset filter cutoff, also used to
# name output artifacts), `use_focal_loss` (1-logit sigmoid head + FocalLoss vs.
# 2-class CrossEntropy) and `balance` (forwarded to build_dataloader); drops the
# explicit checkpoint-path parameters.
batch_size, initial_lr=1e-5, max_epochs=100, threshold: int = 10,
device="cuda", use_focal_loss=False, balance=True):
# Best-checkpoint path is now derived from the threshold.
best_ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
# -- old model construction (removed):
model = TransClassifier(backbone_dir, device)
# -- new: head size and loss function depend on use_focal_loss.
if use_focal_loss:
model = TransClassifier(backbone_dir, output_classes=1, device=device)
if balance:
# Balanced dataset: neutral class weighting (alpha=0.5).
loss_func = FocalLoss(
gamma=2.0,
alpha=0.5,
reduction='mean',
task_type='binary')
else:
# Imbalanced dataset: up-weight the positive class (alpha=0.8).
loss_func = FocalLoss(
gamma=2.0,
alpha=0.8,
reduction='mean',
task_type='binary')
else:
model = TransClassifier(backbone_dir, output_classes=2, device=device)
loss_func = nn.CrossEntropyLoss()
# NOTE(review): the assert message appears inverted — this IS the CE-loss
# branch, so it should read "When using CE loss, balance must be True."
assert balance == True, "When not using CE loss, balance must be True."
model.to(device)
data_dict = build_dataloader(deal_data_folder=deal_folder, not_deal_data_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, balance=balance)
train_loader = data_dict["train"]
val_loader = data_dict["val"]
test_loader = data_dict["test"]
optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
# -- old loss / early-stopping construction (removed):
loss_func = nn.CrossEntropyLoss()
early_stopping = EarlyStopping(path=best_ckpt_path)
# -- new: checkpoint path is passed per-call instead of stored on the stopper.
early_stopping = EarlyStopping(patience=10, delta=0)
history = {"train_loss": [], "val_loss": [], "epoch": []}
for epoch in range(max_epochs):
@@ -83,8 +101,13 @@ def train(backbone_dir, deal_folder, not_deal_folder,
# (hunk gap: the batch-loop header, label transfer and `total_loss` init are
# not visible in this diff view)
texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
outputs = model(inputs)
# -- new: focal path squeezes the [B, 1] logit and needs float targets.
if use_focal_loss:
outputs = outputs.squeeze(1)
loss = loss_func(outputs, labels.float())
else:
loss = loss_func(outputs, labels)
optimizer.zero_grad()
@@ -103,26 +126,25 @@ def train(backbone_dir, deal_folder, not_deal_folder,
# (hunk gap: backward pass and optimizer/scheduler steps not visible)
torch.cuda.empty_cache()
gc.collect()
# -- old validation call and history append (removed):
val_loss = val(val_loader, model, loss_func, tokenizer, device)
history["train_loss"].append(total_loss / len(train_loader))
# -- new: val() also receives the use_focal_loss flag.
val_loss = val(val_loader, model, loss_func, tokenizer, device, use_focal_loss)
train_loss_epoch = total_loss / len(train_loader)
history["train_loss"].append(train_loss_epoch)
history["val_loss"].append(val_loss)
history["epoch"].append(epoch+1)
print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}")
# -- old call (removed):
early_stopping(val_loss, model)
early_stopping(val_loss, model, best_ckpt_path)
if early_stopping.early_stop:
print("Early stopping")
break
# NOTE(review): `final_ckpt_path` is no longer a parameter of the new
# signature, so these unchanged context lines would raise NameError in the new
# revision — needs a follow-up fix (e.g. derive it from threshold like
# best_ckpt_path).
torch.save(model.state_dict(), final_ckpt_path)
print(f"Final model saved to {final_ckpt_path}")
history_df = pd.DataFrame(history)
# -- old history output (removed):
history_df.to_csv("training_history.csv", index=False)
print("Training history saved to training_history.csv")
# -- new: history file is named per threshold.
history_df.to_csv(f"training_history_threshold_{threshold}.csv", index=False)
print(f"Training history saved to training_history_threshold_{threshold}.csv")
# Returns the test-split loader so the caller can pass it to test().
return test_loader
def val(val_loader, model, loss_func, tokenizer, device):
# NOTE(review): diff view — the line above is the old signature (removed); the
# line below is the new one adding `use_focal_loss`. Code kept byte-identical;
# comments only.
def val(val_loader, model, loss_func, tokenizer, device, use_focal_loss=False):
model.eval()
val_loss = 0.0
with torch.no_grad():
@@ -133,17 +155,169 @@ def val(val_loader, model, loss_func, tokenizer, device):
# (hunk gap: the loop over val_loader, label transfer and chat-template
# application are not visible in this diff view)
inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
outputs = model(inputs)
# -- new: focal path squeezes the [B, 1] logit and uses float targets
# (BCE-style); CE path keeps integer class labels.
if use_focal_loss:
outputs = outputs.squeeze(1)
loss = loss_func(outputs, labels.float())
else:
loss = loss_func(outputs, labels)
val_loss += loss.item()
# Aggressive per-batch cleanup to cap GPU memory during validation.
del inputs, outputs, labels, loss
gc.collect()
torch.cuda.empty_cache()
# Mean validation loss over all batches.
return val_loss / len(val_loader)
def test(backbone_dir, test_loader, device, threshold, use_focal_loss=False, balance=True):
    """Evaluate the best checkpoint for `threshold` on the held-out test set.

    Loads best_ckpt_threshold_{threshold}.pth (warns and uses an untrained model
    if it is missing), runs inference over `test_loader`, prints overall and
    per-class metrics, and writes:
      - test_results_threshold_{threshold}.json   (aggregate metrics)
      - test_predictions_threshold_{threshold}.csv (per-sample predictions)

    Args:
        backbone_dir: path/name of the pretrained backbone for tokenizer/model.
        test_loader: iterable yielding (ids, texts, labels) batches.
        device: torch device string, e.g. "cuda".
        threshold: filtering threshold used to name checkpoint/output files.
        use_focal_loss: True -> single sigmoid logit; False -> 2-class softmax.
        balance: accepted for interface symmetry with train(); unused here.
    """
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    # Head size must match how the model was trained (1 logit focal, 2-class CE).
    if use_focal_loss:
        model = TransClassifier(backbone_dir, output_classes=1, device=device)
    else:
        model = TransClassifier(backbone_dir, output_classes=2, device=device)
    model.to(device)
    ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")
    model.eval()
    all_ids = []
    all_preds = []
    all_probs = []
    all_labels = []
    pbar = tqdm(test_loader, desc="Testing")
    with torch.inference_mode():
        for batch_idx, (ids, texts, labels) in enumerate(pbar):
            all_ids.extend(ids)
            labels = labels.to(device)
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
            if use_focal_loss:
                # Merged the original's two consecutive `if use_focal_loss:`
                # checks into one branch (behavior unchanged).
                outputs = outputs.squeeze(-1)  # [B, 1] -> [B]
                outputs_float = outputs.float()  # cast to float32 to avoid bf16 precision loss
                probs = torch.sigmoid(outputs_float).cpu().numpy().tolist()  # [B]
                preds = [1 if p >= 0.5 else 0 for p in probs]
            else:
                preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
                outputs_float = outputs.float()  # cast to float32 to avoid bf16 precision loss
                probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
                probs = probs.cpu().numpy().tolist()
                probs = [p[1] for p in probs]  # keep P(class 1, "Deal") for AUC
            all_preds.extend(preds)
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy())
            # Free per-batch tensors; also release `inputs`, which the original
            # left alive across iterations.
            del texts, labels, outputs, inputs
            torch.cuda.empty_cache()
            gc.collect()
    # Aggregate evaluation metrics.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    auc = roc_auc_score(all_labels, all_probs)
    cm = confusion_matrix(all_labels, all_preds)
    # Print evaluation results.
    print("\n=== Test Results ===")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"AUC: {auc:.4f}")
    cm_df = pd.DataFrame(cm,
                         index=['Actual Not Deal (0)', 'Actual Deal (1)'],
                         columns=['Predicted Not Deal (0)', 'Predicted Deal (1)'])
    print("\nConfusion Matrix:")
    print(cm_df)
    precision_per_class = precision_score(all_labels, all_preds, average=None)
    recall_per_class = recall_score(all_labels, all_preds, average=None)
    f1_per_class = f1_score(all_labels, all_preds, average=None)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f"  Precision: {precision_per_class[0]:.4f}")
    print(f"  Recall: {recall_per_class[0]:.4f}")
    print(f"  F1 Score: {f1_per_class[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f"  Precision: {precision_per_class[1]:.4f}")
    print(f"  Recall: {recall_per_class[1]:.4f}")
    print(f"  F1 Score: {f1_per_class[1]:.4f}")
    test_results = {
        "accuracy": accuracy,
        "precision_weighted": precision,
        "recall_weighted": recall,
        "f1_weighted": f1,
        "auc": auc,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": precision_per_class[0],
        "class_0_recall": recall_per_class[0],
        "class_0_f1": f1_per_class[0],
        "class_1_precision": precision_per_class[1],
        "class_1_recall": recall_per_class[1],
        "class_1_f1": f1_per_class[1],
        "test_samples": len(all_labels)
    }
    with open(f"test_results_threshold_{threshold}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)
    print(f"\nTest results saved to test_results_threshold_{threshold}.json")
    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels
    })
    pred_df.to_csv(f"test_predictions_threshold_{threshold}.csv", index=False, encoding="utf-8")
if __name__ == "__main__":
    # Sweep filtering thresholds 3..8: train one classifier per threshold, then
    # evaluate it on the matching held-out test split. Each run writes
    # threshold-suffixed checkpoint, history, metrics and prediction files.
    # (Removed: superseded constants from the previous revision and the
    # commented-out single-threshold experiment block — dead code; also removed
    # the unused `max_threshold` variable.)
    backbone_dir = "Qwen3-1.7B"
    deal_folder = "filtered_deal"
    not_deal_folder = "filtered_not_deal"
    batch_size = 4
    device = "cuda"
    for i in range(3, 9):
        print(f"Training with threshold {i}...")
        test_loader = train(backbone_dir=backbone_dir, deal_folder=deal_folder, not_deal_folder=not_deal_folder, batch_size=batch_size, threshold=i, device=device, use_focal_loss=False, balance=True)
        test(
            backbone_dir=backbone_dir,
            test_loader=test_loader,
            device=device,
            threshold=i,
            use_focal_loss=False,
            balance=True
        )