Update train.py
This commit is contained in:
238
train.py
238
train.py
@@ -1,5 +1,5 @@
|
||||
from data_process import build_dataloader
|
||||
from model import TransClassifier
|
||||
from model import TransClassifier, FocalLoss
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -14,24 +14,25 @@ from datetime import datetime
|
||||
import gc
|
||||
from tqdm import tqdm
|
||||
import warnings
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
class EarlyStopping:
|
||||
def __init__(self, patience=5, delta=0, path='checkpoint.pt'):
|
||||
def __init__(self, patience=5, delta=0):
|
||||
self.patience = patience
|
||||
self.counter = 0
|
||||
self.best_score = None
|
||||
self.early_stop = False
|
||||
self.val_loss_min = np.inf
|
||||
self.delta = delta
|
||||
self.path = path
|
||||
|
||||
def __call__(self, val_loss, model):
|
||||
def __call__(self, val_loss, model, best_ckpt_path):
|
||||
score = -val_loss
|
||||
|
||||
if self.best_score is None:
|
||||
self.best_score = score
|
||||
self.save_checkpoint(val_loss, model)
|
||||
self.save_checkpoint(val_loss, model, best_ckpt_path)
|
||||
|
||||
elif score < self.best_score + self.delta:
|
||||
self.counter += 1
|
||||
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
|
||||
@@ -39,32 +40,49 @@ class EarlyStopping:
|
||||
self.early_stop = True
|
||||
else:
|
||||
self.best_score = score
|
||||
self.save_checkpoint(val_loss, model)
|
||||
self.save_checkpoint(val_loss, model, best_ckpt_path)
|
||||
self.counter = 0
|
||||
|
||||
def save_checkpoint(self, val_loss, model):
|
||||
def save_checkpoint(self, val_loss, model, best_ckpt_path):
|
||||
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
|
||||
torch.save(model.state_dict(), self.path)
|
||||
torch.save(model.state_dict(), best_ckpt_path)
|
||||
self.val_loss_min = val_loss
|
||||
|
||||
def train(backbone_dir, deal_folder, not_deal_folder,
|
||||
batch_size, initial_lr=1e-5, max_epochs=100,
|
||||
best_ckpt_path="best_ckpt.pth", final_ckpt_path="final_ckpt.pth", device="cuda"):
|
||||
|
||||
data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
|
||||
train_loader = data_dict["train"]
|
||||
val_loader = data_dict["val"]
|
||||
batch_size, initial_lr=1e-5, max_epochs=100, threshold: int = 10,
|
||||
device="cuda", use_focal_loss=False, balance=True):
|
||||
best_ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
|
||||
model = TransClassifier(backbone_dir, device)
|
||||
if use_focal_loss:
|
||||
model = TransClassifier(backbone_dir, output_classes=1, device=device)
|
||||
if balance:
|
||||
loss_func = FocalLoss(
|
||||
gamma=2.0,
|
||||
alpha=0.5,
|
||||
reduction='mean',
|
||||
task_type='binary')
|
||||
else:
|
||||
loss_func = FocalLoss(
|
||||
gamma=2.0,
|
||||
alpha=0.8,
|
||||
reduction='mean',
|
||||
task_type='binary')
|
||||
else:
|
||||
model = TransClassifier(backbone_dir, output_classes=2, device=device)
|
||||
loss_func = nn.CrossEntropyLoss()
|
||||
assert balance == True, "When not using CE loss, balance must be True."
|
||||
model.to(device)
|
||||
|
||||
data_dict = build_dataloader(deal_data_folder=deal_folder, not_deal_data_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, balance=balance)
|
||||
train_loader = data_dict["train"]
|
||||
val_loader = data_dict["val"]
|
||||
test_loader = data_dict["test"]
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
|
||||
|
||||
loss_func = nn.CrossEntropyLoss()
|
||||
|
||||
early_stopping = EarlyStopping(path=best_ckpt_path)
|
||||
early_stopping = EarlyStopping(patience=10, delta=0)
|
||||
history = {"train_loss": [], "val_loss": [], "epoch": []}
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
@@ -83,8 +101,13 @@ def train(backbone_dir, deal_folder, not_deal_folder,
|
||||
|
||||
texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
||||
inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
|
||||
|
||||
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
outputs = model(inputs)
|
||||
if use_focal_loss:
|
||||
outputs = outputs.squeeze(1)
|
||||
loss = loss_func(outputs, labels.float())
|
||||
else:
|
||||
loss = loss_func(outputs, labels)
|
||||
|
||||
optimizer.zero_grad()
|
||||
@@ -103,26 +126,25 @@ def train(backbone_dir, deal_folder, not_deal_folder,
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
val_loss = val(val_loader, model, loss_func, tokenizer, device)
|
||||
history["train_loss"].append(total_loss / len(train_loader))
|
||||
val_loss = val(val_loader, model, loss_func, tokenizer, device, use_focal_loss)
|
||||
train_loss_epoch = total_loss / len(train_loader)
|
||||
history["train_loss"].append(train_loss_epoch)
|
||||
history["val_loss"].append(val_loss)
|
||||
history["epoch"].append(epoch+1)
|
||||
|
||||
print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}")
|
||||
|
||||
early_stopping(val_loss, model)
|
||||
early_stopping(val_loss, model, best_ckpt_path)
|
||||
if early_stopping.early_stop:
|
||||
print("Early stopping")
|
||||
break
|
||||
|
||||
torch.save(model.state_dict(), final_ckpt_path)
|
||||
print(f"Final model saved to {final_ckpt_path}")
|
||||
|
||||
history_df = pd.DataFrame(history)
|
||||
history_df.to_csv("training_history.csv", index=False)
|
||||
print("Training history saved to training_history.csv")
|
||||
history_df.to_csv(f"training_history_threshold_{threshold}.csv", index=False)
|
||||
print(f"Training history saved to training_history_threshold_{threshold}.csv")
|
||||
return test_loader
|
||||
|
||||
def val(val_loader, model, loss_func, tokenizer, device):
|
||||
def val(val_loader, model, loss_func, tokenizer, device, use_focal_loss=False):
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
@@ -133,17 +155,169 @@ def val(val_loader, model, loss_func, tokenizer, device):
|
||||
inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
|
||||
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
outputs = model(inputs)
|
||||
if use_focal_loss:
|
||||
outputs = outputs.squeeze(1)
|
||||
loss = loss_func(outputs, labels.float())
|
||||
else:
|
||||
loss = loss_func(outputs, labels)
|
||||
|
||||
val_loss += loss.item()
|
||||
|
||||
del inputs, outputs, labels, loss
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return val_loss / len(val_loader)
|
||||
|
||||
def test(backbone_dir, test_loader, device, threshold, use_focal_loss=False, balance=True):
    """Evaluate the best checkpoint for *threshold* on the held-out test set.

    Loads ``best_ckpt_threshold_{threshold}.pth`` (falls back to the untrained
    model with a warning if it is missing), runs inference over *test_loader*,
    prints accuracy / precision / recall / F1 / AUC plus a confusion matrix and
    per-class metrics, and writes:

    - ``test_results_threshold_{threshold}.json`` — aggregate metrics
    - ``test_predictions_threshold_{threshold}.csv`` — per-sample predictions

    Args:
        backbone_dir: Path or hub id of the backbone; used for both the
            tokenizer and the classifier head.
        test_loader: DataLoader yielding ``(ids, texts, labels)`` batches.
        device: Torch device string, e.g. ``"cuda"``.
        threshold: Labelling threshold used to locate the checkpoint and to
            name the output files.
        use_focal_loss: If True the model has a single sigmoid output
            (binary head); otherwise a 2-class softmax head.
        balance: Unused here; kept for signature parity with ``train``.
    """
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)

    # Head width must match how the model was trained (see train()).
    if use_focal_loss:
        model = TransClassifier(backbone_dir, output_classes=1, device=device)
    else:
        model = TransClassifier(backbone_dir, output_classes=2, device=device)
    model.to(device)

    ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")

    model.eval()

    all_ids = []
    all_preds = []
    all_probs = []
    all_labels = []

    pbar = tqdm(test_loader, desc="Testing")
    with torch.inference_mode():
        for ids, texts, labels in pbar:
            all_ids.extend(ids)
            labels = labels.to(device)

            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                if use_focal_loss:
                    outputs = outputs.squeeze(-1)  # [B, 1] -> [B]

            if use_focal_loss:
                # Cast to float32 to avoid bf16 precision issues before sigmoid.
                outputs_float = outputs.float()
                probs = torch.sigmoid(outputs_float).cpu().numpy().tolist()  # [B]
                preds = [1 if p >= 0.5 else 0 for p in probs]
            else:
                preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()

                # Cast to float32 to avoid bf16 precision issues before softmax.
                outputs_float = outputs.float()
                probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
                probs = probs.cpu().numpy().tolist()
                probs = [p[1] for p in probs]  # probability of the positive class

            all_preds.extend(preds)
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy())

            # Free GPU memory each batch; include `inputs` (the tokenized
            # tensors), which was previously leaked until loop exit.
            del inputs, texts, labels, outputs
            torch.cuda.empty_cache()
            gc.collect()

    # Aggregate evaluation metrics.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    auc = roc_auc_score(all_labels, all_probs)
    cm = confusion_matrix(all_labels, all_preds)

    # Print evaluation results.
    print("\n=== Test Results ===")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"AUC: {auc:.4f}")

    cm_df = pd.DataFrame(cm,
                         index=['Actual Not Deal (0)', 'Actual Deal (1)'],
                         columns=['Predicted Not Deal (0)', 'Predicted Deal (1)'])
    print("\nConfusion Matrix:")
    print(cm_df)

    precision_per_class = precision_score(all_labels, all_preds, average=None)
    recall_per_class = recall_score(all_labels, all_preds, average=None)
    f1_per_class = f1_score(all_labels, all_preds, average=None)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f"  Precision: {precision_per_class[0]:.4f}")
    print(f"  Recall: {recall_per_class[0]:.4f}")
    print(f"  F1 Score: {f1_per_class[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f"  Precision: {precision_per_class[1]:.4f}")
    print(f"  Recall: {recall_per_class[1]:.4f}")
    print(f"  F1 Score: {f1_per_class[1]:.4f}")

    test_results = {
        "accuracy": accuracy,
        "precision_weighted": precision,
        "recall_weighted": recall,
        "f1_weighted": f1,
        "auc": auc,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": precision_per_class[0],
        "class_0_recall": recall_per_class[0],
        "class_0_f1": f1_per_class[0],
        "class_1_precision": precision_per_class[1],
        "class_1_recall": recall_per_class[1],
        "class_1_f1": f1_per_class[1],
        "test_samples": len(all_labels)
    }

    with open(f"test_results_threshold_{threshold}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)
    print(f"\nTest results saved to test_results_threshold_{threshold}.json")

    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels
    })
    pred_df.to_csv(f"test_predictions_threshold_{threshold}.csv", index=False, encoding="utf-8")
|
||||
|
||||
if __name__ == "__main__":
    backbone_dir = "Qwen3-1.7B"
    deal_folder = "filtered_deal"
    not_deal_folder = "filtered_not_deal"
    batch_size = 4
    device = "cuda"

    # Sweep the labelling threshold: for each value, train a fresh classifier
    # (train() returns the matching test DataLoader) and evaluate it.
    # Removed: a dead commented-out single-threshold run and the unused
    # `max_threshold` variable.
    for threshold in range(3, 9):
        print(f"Training with threshold {threshold}...")
        test_loader = train(
            backbone_dir=backbone_dir,
            deal_folder=deal_folder,
            not_deal_folder=not_deal_folder,
            batch_size=batch_size,
            threshold=threshold,
            device=device,
            use_focal_loss=False,
            balance=True,
        )

        test(
            backbone_dir=backbone_dir,
            test_loader=test_loader,
            device=device,
            threshold=threshold,
            use_focal_loss=False,
            balance=True,
        )
|
||||
Reference in New Issue
Block a user