Update train.py
train.py · 472 changed lines
@@ -1,149 +1,323 @@
 from data_process import build_dataloader
-from model import TransClassifier
+from model import TransClassifier, FocalLoss
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import AutoTokenizer
 
 import pandas as pd
 import numpy as np
 import os
 import json
 from datetime import datetime
 import gc
 from tqdm import tqdm
 import warnings
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
 warnings.filterwarnings("ignore")
 
 class EarlyStopping:
-    def __init__(self, patience=5, delta=0, path='checkpoint.pt'):
+    def __init__(self, patience=5, delta=0):
         self.patience = patience
         self.counter = 0
         self.best_score = None
         self.early_stop = False
         self.val_loss_min = np.inf
         self.delta = delta
-        self.path = path
 
-    def __call__(self, val_loss, model):
+    def __call__(self, val_loss, model, best_ckpt_path):
         score = -val_loss
 
         if self.best_score is None:
             self.best_score = score
-            self.save_checkpoint(val_loss, model)
+            self.save_checkpoint(val_loss, model, best_ckpt_path)
         elif score < self.best_score + self.delta:
             self.counter += 1
             print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
             if self.counter >= self.patience:
                 self.early_stop = True
         else:
             self.best_score = score
-            self.save_checkpoint(val_loss, model)
+            self.save_checkpoint(val_loss, model, best_ckpt_path)
             self.counter = 0
 
-    def save_checkpoint(self, val_loss, model):
+    def save_checkpoint(self, val_loss, model, best_ckpt_path):
         print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
-        torch.save(model.state_dict(), self.path)
+        torch.save(model.state_dict(), best_ckpt_path)
         self.val_loss_min = val_loss
 
 def train(backbone_dir, deal_folder, not_deal_folder,
-          batch_size, initial_lr=1e-5, max_epochs=100,
-          best_ckpt_path="best_ckpt.pth", final_ckpt_path="final_ckpt.pth", device="cuda"):
-
-    data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
-    train_loader = data_dict["train"]
-    val_loader = data_dict["val"]
-
-    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
-    model = TransClassifier(backbone_dir, device)
-    model.to(device)
+          batch_size, initial_lr=1e-5, max_epochs=100, threshold: int = 10,
+          device="cuda", use_focal_loss=False, balance=True):
+    best_ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
+
+    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
+    if use_focal_loss:
+        model = TransClassifier(backbone_dir, output_classes=1, device=device)
+        if balance:
+            loss_func = FocalLoss(
+                gamma=2.0,
+                alpha=0.5,
+                reduction='mean',
+                task_type='binary')
+        else:
+            loss_func = FocalLoss(
+                gamma=2.0,
+                alpha=0.8,
+                reduction='mean',
+                task_type='binary')
+    else:
+        model = TransClassifier(backbone_dir, output_classes=2, device=device)
+        loss_func = nn.CrossEntropyLoss()
+        assert balance, "When using CE loss, balance must be True."
+    model.to(device)
+
+    data_dict = build_dataloader(deal_data_folder=deal_folder, not_deal_data_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, balance=balance)
+    train_loader = data_dict["train"]
+    val_loader = data_dict["val"]
+    test_loader = data_dict["test"]
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
 
-    loss_func = nn.CrossEntropyLoss()
-
-    early_stopping = EarlyStopping(path=best_ckpt_path)
+    early_stopping = EarlyStopping(patience=10, delta=0)
     history = {"train_loss": [], "val_loss": [], "epoch": []}
 
     for epoch in range(max_epochs):
         model.train()
         total_loss = 0.0
         train_steps = 0
 
         if epoch == 2:
             for param in model.backbone.parameters():
                 param.requires_grad = True
             print("Unfreeze backbone parameters")
 
         pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{max_epochs} [Train]')
         for batch_idx, (ids, texts, labels) in enumerate(pbar):
             labels = labels.to(device)
 
             texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
             inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
+
             with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                 outputs = model(inputs)
-                loss = loss_func(outputs, labels)
+                if use_focal_loss:
+                    outputs = outputs.squeeze(1)
+                    loss = loss_func(outputs, labels.float())
+                else:
+                    loss = loss_func(outputs, labels)
 
             optimizer.zero_grad()
             loss.backward()
             nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
             optimizer.step()
             scheduler.step()
 
             total_loss += loss.item()
             train_steps += 1
 
             train_loss = total_loss / train_steps
             pbar.set_postfix({"train_loss": train_loss})
 
             del texts, labels, outputs, loss
             torch.cuda.empty_cache()
             gc.collect()
 
-        val_loss = val(val_loader, model, loss_func, tokenizer, device)
-        history["train_loss"].append(total_loss / len(train_loader))
+        val_loss = val(val_loader, model, loss_func, tokenizer, device, use_focal_loss)
+        train_loss_epoch = total_loss / len(train_loader)
+        history["train_loss"].append(train_loss_epoch)
         history["val_loss"].append(val_loss)
         history["epoch"].append(epoch+1)
 
         print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}")
 
-        early_stopping(val_loss, model)
+        early_stopping(val_loss, model, best_ckpt_path)
         if early_stopping.early_stop:
             print("Early stopping")
             break
 
-    torch.save(model.state_dict(), final_ckpt_path)
-    print(f"Final model saved to {final_ckpt_path}")
-
     history_df = pd.DataFrame(history)
-    history_df.to_csv("training_history.csv", index=False)
-    print("Training history saved to training_history.csv")
+    history_df.to_csv(f"training_history_threshold_{threshold}.csv", index=False)
+    print(f"Training history saved to training_history_threshold_{threshold}.csv")
+    return test_loader
 
-def val(val_loader, model, loss_func, tokenizer, device):
+def val(val_loader, model, loss_func, tokenizer, device, use_focal_loss=False):
     model.eval()
     val_loss = 0.0
     with torch.no_grad():
         for batch_idx, (ids, texts, labels) in enumerate(val_loader):
             labels = labels.to(device)
 
             texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
             inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
             with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                 outputs = model(inputs)
-                loss = loss_func(outputs, labels)
+                if use_focal_loss:
+                    outputs = outputs.squeeze(1)
+                    loss = loss_func(outputs, labels.float())
+                else:
+                    loss = loss_func(outputs, labels)
 
             val_loss += loss.item()
+
+            del inputs, outputs, labels, loss
+            gc.collect()
+            torch.cuda.empty_cache()
     return val_loss / len(val_loader)
 
+def test(backbone_dir, test_loader, device, threshold, use_focal_loss=False, balance=True):
+    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
+
+    if use_focal_loss:
+        model = TransClassifier(backbone_dir, output_classes=1, device=device)
+    else:
+        model = TransClassifier(backbone_dir, output_classes=2, device=device)
+    model.to(device)
+
+    ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
+    if os.path.exists(ckpt_path):
+        model.load_state_dict(torch.load(ckpt_path, map_location=device))
+        print(f"Model loaded from {ckpt_path}")
+    else:
+        print(f"Warning: {ckpt_path} not found. Using untrained model.")
+
+    model.eval()
+
+    all_ids = []
+    all_preds = []
+    all_probs = []
+    all_labels = []
+
+    pbar = tqdm(test_loader, desc="Testing")
+    with torch.inference_mode():
+        for batch_idx, (ids, texts, labels) in enumerate(pbar):
+            all_ids.extend(ids)
+            labels = labels.to(device)
+
+            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
+            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
+
+            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+                outputs = model(inputs)
+                if use_focal_loss:
+                    outputs = outputs.squeeze(-1)  # [B, 1] -> [B]
+
+            if use_focal_loss:
+                outputs_float = outputs.float()  # cast to float32 to avoid precision issues
+                probs = torch.sigmoid(outputs_float).cpu().numpy().tolist()  # [B]
+                preds = [1 if p >= 0.5 else 0 for p in probs]
+            else:
+                preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
+
+                outputs_float = outputs.float()  # cast to float32 to avoid precision issues
+                probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
+                probs = probs.cpu().numpy().tolist()
+                probs = [p[1] for p in probs]
+
+            all_preds.extend(preds)
+            all_probs.extend(probs)
+            all_labels.extend(labels.cpu().numpy())
+
+            # free memory
+            del texts, labels, outputs
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    # compute evaluation metrics
+    accuracy = accuracy_score(all_labels, all_preds)
+    precision = precision_score(all_labels, all_preds, average="weighted")
+    recall = recall_score(all_labels, all_preds, average="weighted")
+    f1 = f1_score(all_labels, all_preds, average="weighted")
+    auc = roc_auc_score(all_labels, all_probs)
+    cm = confusion_matrix(all_labels, all_preds)
+
+    # print evaluation results
+    print("\n=== Test Results ===")
+    print(f"Accuracy: {accuracy:.4f}")
+    print(f"Precision: {precision:.4f}")
+    print(f"Recall: {recall:.4f}")
+    print(f"F1 Score: {f1:.4f}")
+    print(f"AUC: {auc:.4f}")
+
+    cm_df = pd.DataFrame(cm,
+                         index=['Actual Not Deal (0)', 'Actual Deal (1)'],
+                         columns=['Predicted Not Deal (0)', 'Predicted Deal (1)'])
+    print("\nConfusion Matrix:")
+    print(cm_df)
+
+    precision_per_class = precision_score(all_labels, all_preds, average=None)
+    recall_per_class = recall_score(all_labels, all_preds, average=None)
+    f1_per_class = f1_score(all_labels, all_preds, average=None)
+    print("\n=== Class-wise Metrics ===")
+    print("Class 0 (Not Deal):")
+    print(f"  Precision: {precision_per_class[0]:.4f}")
+    print(f"  Recall: {recall_per_class[0]:.4f}")
+    print(f"  F1 Score: {f1_per_class[0]:.4f}")
+    print("\nClass 1 (Deal):")
+    print(f"  Precision: {precision_per_class[1]:.4f}")
+    print(f"  Recall: {recall_per_class[1]:.4f}")
+    print(f"  F1 Score: {f1_per_class[1]:.4f}")
+
+    test_results = {
+        "accuracy": accuracy,
+        "precision_weighted": precision,
+        "recall_weighted": recall,
+        "f1_weighted": f1,
+        "auc": auc,
+        "confusion_matrix": cm.tolist(),
+        "class_0_precision": precision_per_class[0],
+        "class_0_recall": recall_per_class[0],
+        "class_0_f1": f1_per_class[0],
+        "class_1_precision": precision_per_class[1],
+        "class_1_recall": recall_per_class[1],
+        "class_1_f1": f1_per_class[1],
+        "test_samples": len(all_labels)
+    }
+
+    with open(f"test_results_threshold_{threshold}.json", "w", encoding="utf-8") as f:
+        json.dump(test_results, f, ensure_ascii=False, indent=4)
+    print(f"\nTest results saved to test_results_threshold_{threshold}.json")
+
+    pred_df = pd.DataFrame({
+        "ids": all_ids,
+        "predictions": all_preds,
+        "probability": all_probs,
+        "true_labels": all_labels
+    })
+    pred_df.to_csv(f"test_predictions_threshold_{threshold}.csv", index=False, encoding="utf-8")
+
 if __name__ == "__main__":
-    backbone_dir = r"C:\Users\GA\Desktop\models\Qwen3-1.7B"
-    deal_folder = "deal"
-    not_deal_folder = "not_deal"
-    batch_size = 8
+    backbone_dir = "Qwen3-1.7B"
+    deal_folder = "filtered_deal"
+    not_deal_folder = "filtered_not_deal"
+    batch_size = 4
     device = "cuda"
 
-    train(backbone_dir, deal_folder, not_deal_folder, batch_size, device=device)
+    """ threshold = 10
+    test_loader = train(backbone_dir=backbone_dir, deal_folder=deal_folder, not_deal_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, device=device, use_focal_loss=False, balance=True)
+
+    test(
+        backbone_dir=backbone_dir,
+        test_loader=test_loader,
+        device=device,
+        threshold=threshold,
+        use_focal_loss=False,
+        balance=True
+    ) """
+
+    max_threshold = 10
+    for i in range(3, 9):
+        print(f"Training with threshold {i}...")
+        test_loader = train(backbone_dir=backbone_dir, deal_folder=deal_folder, not_deal_folder=not_deal_folder, batch_size=batch_size, threshold=i, device=device, use_focal_loss=False, balance=True)
+
+        test(
+            backbone_dir=backbone_dir,
+            test_loader=test_loader,
+            device=device,
+            threshold=i,
+            use_focal_loss=False,
+            balance=True
+        )
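
Note on the `FocalLoss` dependency: the commit imports `FocalLoss` from `model` but does not include that file, so only the call signature is visible (`gamma`, `alpha`, `reduction`, `task_type='binary'`, applied to squeezed logits and float 0/1 targets). Below is a minimal sketch consistent with that usage, assuming a standard binary focal loss over raw logits; it is an illustration, not the actual module from `model.py`.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # Sketch of a binary focal loss: FL = -alpha_t * (1 - p_t)**gamma * log(p_t).
    # The real module in model.py may differ; this only matches the call sites above.
    def __init__(self, gamma=2.0, alpha=0.5, reduction='mean', task_type='binary'):
        super().__init__()
        assert task_type == 'binary', "this sketch covers the binary case only"
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        # logits: [B] (train.py squeezes the [B, 1] head output); targets: [B] floats in {0, 1}
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p_t = torch.exp(-bce)  # probability the model assigns to the true class
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * (1 - p_t) ** self.gamma * bce
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss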
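
Similarly, `build_dataloader` (from `data_process`, also not shown here) must return a dict with "train"/"val"/"test" loaders whose batches unpack as `(ids, texts, labels)`, where `texts` is something `tokenizer.apply_chat_template` accepts, i.e. a list of chat conversations. A hypothetical collate function matching that contract; the `"id"`/`"text"`/`"label"` field names are invented for illustration:

import torch

def collate(batch):
    # batch: list of per-sample dicts; key names here are assumptions
    ids = [item["id"] for item in batch]
    # one single-turn conversation per sample, as apply_chat_template expects
    texts = [[{"role": "user", "content": item["text"]}] for item in batch]
    labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    return ids, texts, labels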