diff --git a/train.py b/train.py index 78b91f6..3d31f47 100644 --- a/train.py +++ b/train.py @@ -1,149 +1,323 @@ -from data_process import build_dataloader -from model import TransClassifier - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import AutoTokenizer - -import pandas as pd -import numpy as np -import os -import json -from datetime import datetime -import gc -from tqdm import tqdm -import warnings -warnings.filterwarnings("ignore") - -class EarlyStopping: - def __init__(self, patience=5, delta=0, path='checkpoint.pt'): - self.patience = patience - self.counter = 0 - self.best_score = None - self.early_stop = False - self.val_loss_min = np.inf - self.delta = delta - self.path = path - - def __call__(self, val_loss, model): - score = -val_loss - - if self.best_score is None: - self.best_score = score - self.save_checkpoint(val_loss, model) - elif score < self.best_score + self.delta: - self.counter += 1 - print(f'EarlyStopping counter: {self.counter} out of {self.patience}') - if self.counter >= self.patience: - self.early_stop = True - else: - self.best_score = score - self.save_checkpoint(val_loss, model) - self.counter = 0 - - def save_checkpoint(self, val_loss, model): - print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model...') - torch.save(model.state_dict(), self.path) - self.val_loss_min = val_loss - -def train(backbone_dir, deal_folder, not_deal_folder, - batch_size, initial_lr=1e-5, max_epochs=100, - best_ckpt_path="best_ckpt.pth", final_ckpt_path="final_ckpt.pth", device="cuda"): - - data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size) - train_loader = data_dict["train"] - val_loader = data_dict["val"] - - tokenizer = AutoTokenizer.from_pretrained(backbone_dir) - model = TransClassifier(backbone_dir, device) - model.to(device) - - optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) - - loss_func = nn.CrossEntropyLoss() - - early_stopping = EarlyStopping(path=best_ckpt_path) - history = {"train_loss": [], "val_loss": [], "epoch": []} - - for epoch in range(max_epochs): - model.train() - total_loss = 0.0 - train_steps = 0 - - if epoch == 2: - for param in model.backbone.parameters(): - param.requires_grad = True - print("Unfreeze backbone parameters") - - pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{max_epochs} [Train]') - for batch_idx, (ids, texts, labels) in enumerate(pbar): - labels = labels.to(device) - - texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False) - inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device) - with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - outputs = model(inputs) - loss = loss_func(outputs, labels) - - optimizer.zero_grad() - loss.backward() - nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - optimizer.step() - scheduler.step() - - total_loss += loss.item() - train_steps += 1 - - train_loss = total_loss / train_steps - pbar.set_postfix({"train_loss": train_loss}) - - del texts, labels, outputs, loss - torch.cuda.empty_cache() - gc.collect() - - val_loss = val(val_loader, model, loss_func, 
tokenizer, device) - history["train_loss"].append(total_loss / len(train_loader)) - history["val_loss"].append(val_loss) - history["epoch"].append(epoch+1) - - print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}") - - early_stopping(val_loss, model) - if early_stopping.early_stop: - print("Early stopping") - break - - torch.save(model.state_dict(), final_ckpt_path) - print(f"Final model saved to {final_ckpt_path}") - - history_df = pd.DataFrame(history) - history_df.to_csv("training_history.csv", index=False) - print("Training history saved to training_history.csv") - -def val(val_loader, model, loss_func, tokenizer, device): - model.eval() - val_loss = 0.0 - with torch.no_grad(): - for batch_idx, (ids, texts, labels) in enumerate(val_loader): - labels = labels.to(device) - - texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False) - inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device) - with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - outputs = model(inputs) - loss = loss_func(outputs, labels) - - val_loss += loss.item() - return val_loss / len(val_loader) - -if __name__ == "__main__": - backbone_dir = r"C:\Users\GA\Desktop\models\Qwen3-1.7B" - deal_folder = "deal" - not_deal_folder = "not_deal" - batch_size = 8 - device = "cuda" - - - train(backbone_dir, deal_folder, not_deal_folder, batch_size, device=device) \ No newline at end of file +from data_process import build_dataloader +from model import TransClassifier, FocalLoss + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoTokenizer + +import pandas as pd +import numpy as np +import os +import json +from datetime import datetime +import gc +from tqdm import tqdm +import warnings +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, 
confusion_matrix, roc_auc_score +warnings.filterwarnings("ignore") + +class EarlyStopping: + def __init__(self, patience=5, delta=0): + self.patience = patience + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.inf + self.delta = delta + + def __call__(self, val_loss, model, best_ckpt_path): + score = -val_loss + + if self.best_score is None: + self.best_score = score + self.save_checkpoint(val_loss, model, best_ckpt_path) + + elif score < self.best_score + self.delta: + self.counter += 1 + print(f'EarlyStopping counter: {self.counter} out of {self.patience}') + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + self.save_checkpoint(val_loss, model, best_ckpt_path) + self.counter = 0 + + def save_checkpoint(self, val_loss, model, best_ckpt_path): + print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...') + torch.save(model.state_dict(), best_ckpt_path) + self.val_loss_min = val_loss + +def train(backbone_dir, deal_folder, not_deal_folder, + batch_size, initial_lr=1e-5, max_epochs=100, threshold: int = 10, + device="cuda", use_focal_loss=False, balance=True): + best_ckpt_path = f"best_ckpt_threshold_{threshold}.pth" + + tokenizer = AutoTokenizer.from_pretrained(backbone_dir) + if use_focal_loss: + model = TransClassifier(backbone_dir, output_classes=1, device=device) + if balance: + loss_func = FocalLoss( + gamma=2.0, + alpha=0.5, + reduction='mean', + task_type='binary') + else: + loss_func = FocalLoss( + gamma=2.0, + alpha=0.8, + reduction='mean', + task_type='binary') + else: + model = TransClassifier(backbone_dir, output_classes=2, device=device) + loss_func = nn.CrossEntropyLoss() + assert balance == True, "When using CE loss, balance must be True." 
+ model.to(device) + + data_dict = build_dataloader(deal_data_folder=deal_folder, not_deal_data_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, balance=balance) + train_loader = data_dict["train"] + val_loader = data_dict["val"] + test_loader = data_dict["test"] + + optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) + + early_stopping = EarlyStopping(patience=10, delta=0) + history = {"train_loss": [], "val_loss": [], "epoch": []} + + for epoch in range(max_epochs): + model.train() + total_loss = 0.0 + train_steps = 0 + + if epoch == 2: + for param in model.backbone.parameters(): + param.requires_grad = True + print("Unfreeze backbone parameters") + + pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{max_epochs} [Train]') + for batch_idx, (ids, texts, labels) in enumerate(pbar): + labels = labels.to(device) + + texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False) + inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device) + + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(inputs) + if use_focal_loss: + outputs = outputs.squeeze(1) + loss = loss_func(outputs, labels.float()) + else: + loss = loss_func(outputs, labels) + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + scheduler.step() + + total_loss += loss.item() + train_steps += 1 + + train_loss = total_loss / train_steps + pbar.set_postfix({"train_loss": train_loss}) + + del texts, labels, outputs, loss + torch.cuda.empty_cache() + gc.collect() + + val_loss = val(val_loader, model, loss_func, tokenizer, device, use_focal_loss) + train_loss_epoch = total_loss / len(train_loader) + history["train_loss"].append(train_loss_epoch) + history["val_loss"].append(val_loss) + 
history["epoch"].append(epoch+1) + + print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}") + + early_stopping(val_loss, model, best_ckpt_path) + if early_stopping.early_stop: + print("Early stopping") + break + + history_df = pd.DataFrame(history) + history_df.to_csv(f"training_history_threshold_{threshold}.csv", index=False) + print(f"Training history saved to training_history_threshold_{threshold}.csv") + return test_loader + +def val(val_loader, model, loss_func, tokenizer, device, use_focal_loss=False): + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for batch_idx, (ids, texts, labels) in enumerate(val_loader): + labels = labels.to(device) + + texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False) + inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device) + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(inputs) + if use_focal_loss: + outputs = outputs.squeeze(1) + loss = loss_func(outputs, labels.float()) + else: + loss = loss_func(outputs, labels) + + val_loss += loss.item() + + del inputs, outputs, labels, loss + gc.collect() + torch.cuda.empty_cache() + return val_loss / len(val_loader) + +def test(backbone_dir, test_loader, device, threshold, use_focal_loss=False, balance=True): + tokenizer = AutoTokenizer.from_pretrained(backbone_dir) + + if use_focal_loss: + model = TransClassifier(backbone_dir, output_classes=1, device=device) + else: + model = TransClassifier(backbone_dir, output_classes=2, device=device) + model.to(device) + + ckpt_path = f"best_ckpt_threshold_{threshold}.pth" + if os.path.exists(ckpt_path): + model.load_state_dict(torch.load(ckpt_path, map_location=device)) + print(f"Model loaded from {ckpt_path}") + else: + print(f"Warning: {ckpt_path} not found. 
Using untrained model.") + + model.eval() + + all_ids = [] + all_preds = [] + all_probs = [] + all_labels = [] + + pbar = tqdm(test_loader, desc="Testing") + with torch.inference_mode(): + for batch_idx, (ids, texts, labels) in enumerate(pbar): + all_ids.extend(ids) + labels = labels.to(device) + + texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False) + inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device) + + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(inputs) + if use_focal_loss: + outputs = outputs.squeeze(-1) # [B, 1] -> [B] + + if use_focal_loss: + outputs_float = outputs.float() # convert to float32 to avoid precision issues + probs = torch.sigmoid(outputs_float).cpu().numpy().tolist() # [B] + preds = [1 if p >= 0.5 else 0 for p in probs] + else: + preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist() + + outputs_float = outputs.float() # convert to float32 to avoid precision issues + probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2] + probs = probs.cpu().numpy().tolist() + probs = [p[1] for p in probs] + + all_preds.extend(preds) + all_probs.extend(probs) + all_labels.extend(labels.cpu().numpy()) + + # free memory + del texts, labels, outputs + torch.cuda.empty_cache() + gc.collect() + + # compute evaluation metrics + accuracy = accuracy_score(all_labels, all_preds) + precision = precision_score(all_labels, all_preds, average="weighted") + recall = recall_score(all_labels, all_preds, average="weighted") + f1 = f1_score(all_labels, all_preds, average="weighted") + auc = roc_auc_score(all_labels, all_probs) + cm = confusion_matrix(all_labels, all_preds) + + # print evaluation results + print("\n=== Test Results ===") + print(f"Accuracy: {accuracy:.4f}") + print(f"Precision: {precision:.4f}") + print(f"Recall: {recall:.4f}") + print(f"F1 Score: {f1:.4f}") + print(f"AUC: {auc:.4f}") + + cm_df = pd.DataFrame(cm, + index=['Actual Not Deal (0)', 'Actual Deal (1)'], + columns=['Predicted Not Deal 
(0)', 'Predicted Deal (1)']) + print("\nConfusion Matrix:") + print(cm_df) + + precision_per_class = precision_score(all_labels, all_preds, average=None) + recall_per_class = recall_score(all_labels, all_preds, average=None) + f1_per_class = f1_score(all_labels, all_preds, average=None) + print("\n=== Class-wise Metrics ===") + print("Class 0 (Not Deal):") + print(f" Precision: {precision_per_class[0]:.4f}") + print(f" Recall: {recall_per_class[0]:.4f}") + print(f" F1 Score: {f1_per_class[0]:.4f}") + print("\nClass 1 (Deal):") + print(f" Precision: {precision_per_class[1]:.4f}") + print(f" Recall: {recall_per_class[1]:.4f}") + print(f" F1 Score: {f1_per_class[1]:.4f}") + + test_results = { + "accuracy": accuracy, + "precision_weighted": precision, + "recall_weighted": recall, + "f1_weighted": f1, + "auc": auc, + "confusion_matrix": cm.tolist(), + "class_0_precision": precision_per_class[0], + "class_0_recall": recall_per_class[0], + "class_0_f1": f1_per_class[0], + "class_1_precision": precision_per_class[1], + "class_1_recall": recall_per_class[1], + "class_1_f1": f1_per_class[1], + "test_samples": len(all_labels) + } + + with open(f"test_results_threshold_{threshold}.json", "w", encoding="utf-8") as f: + json.dump(test_results, f, ensure_ascii=False, indent=4) + print(f"\nTest results saved to test_results_threshold_{threshold}.json") + + pred_df = pd.DataFrame({ + "ids": all_ids, + "predictions": all_preds, + "probability": all_probs, + "true_labels": all_labels + }) + pred_df.to_csv(f"test_predictions_threshold_{threshold}.csv", index=False, encoding="utf-8") + +if __name__ == "__main__": + backbone_dir = "Qwen3-1.7B" + deal_folder = "filtered_deal" + not_deal_folder = "filtered_not_deal" + batch_size = 4 + device = "cuda" + + """ threshold = 10 + test_loader = train(backbone_dir=backbone_dir, deal_folder=deal_folder, not_deal_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, device=device, use_focal_loss=False, balance=True) + + test( + 
backbone_dir=backbone_dir, + test_loader=test_loader, + device=device, + threshold=threshold, + use_focal_loss=False, + balance=True + ) """ + + max_threshold = 10 + for i in range(3, 9): + print(f"Training with threshold {i}...") + test_loader = train(backbone_dir=backbone_dir, deal_folder=deal_folder, not_deal_folder=not_deal_folder, batch_size=batch_size, threshold=i, device=device, use_focal_loss=False, balance=True) + + test( + backbone_dir=backbone_dir, + test_loader=test_loader, + device=device, + threshold=i, + use_focal_loss=False, + balance=True + ) \ No newline at end of file