Upload files to "/"
inference.py (Normal file, 201 lines)
@@ -0,0 +1,201 @@
from model import TransClassifier
from transformers import AutoTokenizer
from data_process import extract_json_data, Formatter
import torch
import json
from typing import Dict, List, Optional
import os
import random
import warnings
warnings.filterwarnings("ignore")

# English feature keys used throughout the pipeline
valid_keys = [
    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
    "Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
]
# Chinese counterparts of valid_keys, in the same order; used to build en2ch below
ch_valid_keys = [
    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
    "竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
]
all_keys = valid_keys + ["session_id", "label"]
en2ch = {en: ch for en, ch in zip(valid_keys, ch_valid_keys)}
d1_keys = valid_keys[:5]
d2_keys = valid_keys[5:10]
d3_keys = valid_keys[10:15]
d4_keys = valid_keys[15:19]
d5_keys = valid_keys[19:23]


class InferenceEngine:
    def __init__(self, backbone_dir: str, ckpt_path: str = "best_ckpt.pth", device: str = "cuda"):
        self.backbone_dir = backbone_dir
        self.ckpt_path = ckpt_path
        self.device = device

        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
        print(f"Tokenizer loaded from {backbone_dir}")

        # Load the model and, if given, the trained checkpoint
        self.model = TransClassifier(backbone_dir, device)
        self.model.to(device)
        if self.ckpt_path:
            self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
            print(f"Model loaded from {ckpt_path}")
        else:
            print("Warning: No checkpoint path provided. Using untrained model.")
        self.model.eval()
        print("Inference engine initialized successfully.")

        self.formatter = Formatter(en2ch)

    def inference_batch(self, json_list: List[str]) -> dict:
        """
        Batched inference: takes a list of JSON file paths and returns a dict with the
        predicted labels and conversion probabilities for each sample. To avoid OOM,
        at most 8 files may be passed per call.
        Note: each JSON file must contain at least 10 entries.
        """
        assert len(json_list) <= 8, "At most 8 JSON files may be passed per call."
        id2feature = extract_json_data(json_list)  # id2feature: session id -> feature dict

        message_list = []
        for id, feature in id2feature.items():
            messages = self.formatter.get_llm_prompt(feature)
            message_list.append(messages)

        inputs = self.tokenizer.apply_chat_template(
            message_list,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )
        model_inputs = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=2048,
            return_tensors="pt"
        ).to(self.device)

        with torch.inference_mode():
            with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
                outputs = self.model(model_inputs)

        # 1. Predicted class labels (argmax)
        preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()

        # 2. Softmax probabilities (cast to float32, move to CPU, and convert to plain
        #    lists so the result can be serialized)
        outputs_float = outputs.float()  # cast to float32 to avoid precision issues
        probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
        probs = probs.cpu().numpy().tolist()  # two class probabilities per sample

        # labels: predicted class per sample; probs: per-class probabilities per sample
        return {"labels": preds, "probs": probs}

    def inference_sample(self, json_path: str) -> dict:
        """
        Single-sample inference: takes one JSON file path and returns a dict with the
        conversion probability.
        Note: the JSON file must contain at least 10 entries.
        """
        return self.inference_batch([json_path])


if __name__ == "__main__":
    # Configuration
    backbone_dir = "Qwen3-1.7B"
    ckpt_path = "best_ckpt.pth"
    device = "cuda"

    engine = InferenceEngine(backbone_dir, ckpt_path, device)

    from data_process import extract_json_files

    # Collect the JSON file paths for deal and not-deal sessions
    deal_files = extract_json_files("deal")
    not_deal_files = extract_json_files("not_deal")

    def filter_json_files_by_key_count(files: List[str], min_keys: int = 10) -> List[str]:
        """
        Keep only the JSON files whose top-level dict has at least min_keys keys.

        Args:
            files: list of JSON file paths
            min_keys: minimum number of keys required (default 10)

        Returns:
            the file paths that meet the requirement
        """
        valid_files = []

        for file_path in files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                # The file must hold a dict with enough keys
                if isinstance(data, dict) and len(data) >= min_keys:
                    valid_files.append(file_path)
                else:
                    print(f"Skipping {os.path.basename(file_path)}: not enough keys ({len(data)} < {min_keys})")
            except Exception as e:
                print(f"Error reading {file_path}: {e}")

        return valid_files

    deal_files_filtered = filter_json_files_by_key_count(deal_files, min_keys=10)
    not_deal_files_filtered = filter_json_files_by_key_count(not_deal_files, min_keys=10)

    num_samples = 8

    # How many files to draw from each class
    num_deal_needed = min(4, len(deal_files_filtered))          # at most 4 deal files
    num_not_deal_needed = min(4, len(not_deal_files_filtered))  # at most 4 not-deal files

    # If one class falls short, top it up from the other
    if num_deal_needed + num_not_deal_needed < num_samples:
        if len(deal_files_filtered) > num_deal_needed:
            num_deal_needed = min(num_samples, len(deal_files_filtered))
        elif len(not_deal_files_filtered) > num_not_deal_needed:
            num_not_deal_needed = min(num_samples, len(not_deal_files_filtered))

    # Randomly sample the files
    selected_deal_files = random.sample(deal_files_filtered, min(num_deal_needed, len(deal_files_filtered))) if deal_files_filtered else []
    selected_not_deal_files = random.sample(not_deal_files_filtered, min(num_not_deal_needed, len(not_deal_files_filtered))) if not_deal_files_filtered else []

    # Merge the selections
    selected_files = selected_deal_files + selected_not_deal_files

    # If still fewer than 8 files, top up randomly from the unfiltered pool
    if len(selected_files) < num_samples:
        all_files = deal_files + not_deal_files
        # Exclude files that were already selected
        remaining_files = [f for f in all_files if f not in selected_files]
        additional_needed = num_samples - len(selected_files)
        if remaining_files:
            additional_files = random.sample(remaining_files, min(additional_needed, len(remaining_files)))
            selected_files.extend(additional_files)

    true_labels = []
    for i, file_path in enumerate(selected_files):
        folder_type = "not_deal" if "not_deal" in file_path else "deal"
        true_labels.append(folder_type)

    # Run batched inference through the inference_batch interface
    if selected_files:
        print("\nStarting batch inference...")
        try:
            batch_result = engine.inference_batch(selected_files)
            print(batch_result)
            print(true_labels)

        except Exception as e:
            print(f"Error during inference: {e}")
    else:
        print("No suitable files were found for inference.")

    print("\nInference interface test finished!")
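
The demo above samples files from the deal and not_deal folders at random. For a quick manual check, the engine can also be called directly on a hand-picked list of paths. The sketch below is a minimal usage example, assuming the Qwen3-1.7B backbone and best_ckpt.pth are available locally; the two session paths are hypothetical placeholders, not files shipped with this commit.

# Minimal usage sketch for InferenceEngine; the two JSON paths are hypothetical.
from inference import InferenceEngine

engine = InferenceEngine("Qwen3-1.7B", "best_ckpt.pth", device="cuda")

# At most 8 files per call, each with at least 10 entries.
result = engine.inference_batch(["deal/session_001.json", "not_deal/session_002.json"])
print(result["labels"])  # e.g. [1, 0]
print(result["probs"])   # per sample: [P(class 0, not deal), P(class 1, deal)]
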
statis_main.py (Normal file, 14 lines)
@@ -0,0 +1,14 @@
from data_process import StatisticData, Outputer


if __name__ == "__main__":
    deal_folder = "deal"
    not_deal_folder = "not_deal"

    deal_statistic = StatisticData(deal_folder)
    deal_data = deal_statistic.main()
    not_deal_statistic = StatisticData(not_deal_folder)
    not_deal_data = not_deal_statistic.main()

    outputer = Outputer(deal_data, not_deal_data)
    outputer.visualize_priority()
    outputer.save_key2counter_excel()
test.py (Normal file, 154 lines)
@@ -0,0 +1,154 @@
from data_process import build_dataloader
from model import TransClassifier

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

import pandas as pd
import numpy as np
import os
import json
from datetime import datetime
import gc
from tqdm import tqdm
import warnings
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
warnings.filterwarnings("ignore")


def test(backbone_dir, deal_folder, not_deal_folder, batch_size, ckpt_path="best_ckpt.pth", device="cuda"):
    """
    Evaluate the model on the test split.

    Args:
        backbone_dir: directory of the pretrained backbone
        deal_folder: folder with deal sessions
        not_deal_folder: folder with not-deal sessions
        batch_size: batch size
        ckpt_path: path to the model checkpoint
        device: device to run on
    """
    # Load the test data
    data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
    test_loader = data_dict["test"]
    print(f"Test data loaded successfully. Test samples: {len(test_loader.dataset)}")

    # Load the tokenizer and the model
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, device)
    model.to(device)

    # Load the trained weights
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")

    # Evaluate the model
    model.eval()
    all_ids = []
    all_preds = []
    all_labels = []
    test_loss = 0.0

    loss_func = nn.CrossEntropyLoss()

    pbar = tqdm(test_loader, desc="Testing")
    with torch.inference_mode():
        for batch_idx, (ids, texts, labels) in enumerate(pbar):
            all_ids.extend(ids)
            labels = labels.to(device)

            # Prepare the inputs
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                loss = loss_func(outputs, labels)

            test_loss += loss.item()

            # Collect the predictions
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

            # Free memory
            del texts, labels, outputs, loss
            torch.cuda.empty_cache()
            gc.collect()

    # Compute the evaluation metrics
    avg_loss = test_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    cm = confusion_matrix(all_labels, all_preds)

    # Print the results
    print("\n=== Test Results ===")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print("\nConfusion Matrix:")
    print(cm)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f"  Precision: {precision_score(all_labels, all_preds, average=None)[0]:.4f}")
    print(f"  Recall: {recall_score(all_labels, all_preds, average=None)[0]:.4f}")
    print(f"  F1 Score: {f1_score(all_labels, all_preds, average=None)[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f"  Precision: {precision_score(all_labels, all_preds, average=None)[1]:.4f}")
    print(f"  Recall: {recall_score(all_labels, all_preds, average=None)[1]:.4f}")
    print(f"  F1 Score: {f1_score(all_labels, all_preds, average=None)[1]:.4f}")

    # Collect the test metrics
    test_results = {
        "average_loss": avg_loss,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": precision_score(all_labels, all_preds, average=None)[0],
        "class_0_recall": recall_score(all_labels, all_preds, average=None)[0],
        "class_0_f1": f1_score(all_labels, all_preds, average=None)[0],
        "class_1_precision": precision_score(all_labels, all_preds, average=None)[1],
        "class_1_recall": recall_score(all_labels, all_preds, average=None)[1],
        "class_1_f1": f1_score(all_labels, all_preds, average=None)[1],
        "test_samples": len(all_labels),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    # Save the per-sample predictions
    pred_results = {
        "ids": all_ids,
        "predictions": all_preds,
        "true_labels": all_labels
    }
    pred_df = pd.DataFrame(pred_results)
    pred_df.to_csv("test_predictions.csv", index=False, encoding="utf-8")

    # Save the metrics as JSON
    with open("test_results.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=2)
    print("\nTest results saved to test_results.json")
    return test_results


if __name__ == "__main__":
    # Configuration
    backbone_dir = r"C:\Users\GA\Desktop\models\Qwen3-1.7B"
    deal_folder = "deal"
    not_deal_folder = "not_deal"
    batch_size = 8
    ckpt_path = "best_ckpt.pth"
    device = "cuda"

    # Run the evaluation
    test(backbone_dir, deal_folder, not_deal_folder, batch_size, ckpt_path, device)
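
The class-wise numbers printed by test() can be cross-checked against the confusion matrix it saves. The small sanity-check sketch below recomputes the class-1 (Deal) precision and recall from test_results.json; it assumes the scikit-learn layout (rows are true labels, columns are predictions) and that both classes appear in the test split.

# Sanity check: recompute class-1 (Deal) precision and recall from the
# confusion matrix stored in test_results.json by test().
import json

with open("test_results.json", "r", encoding="utf-8") as f:
    results = json.load(f)

cm = results["confusion_matrix"]  # [[TN, FP], [FN, TP]] for classes (0, 1)
tn, fp = cm[0]
fn, tp = cm[1]

precision_1 = tp / (tp + fp) if (tp + fp) else 0.0
recall_1 = tp / (tp + fn) if (tp + fn) else 0.0
print(f"Class 1 precision: {precision_1:.4f}, recall: {recall_1:.4f}")
# These values should match class_1_precision and class_1_recall in the JSON file.
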
train.py (Normal file, 149 lines)
@@ -0,0 +1,149 @@
from data_process import build_dataloader
from model import TransClassifier

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

import pandas as pd
import numpy as np
import os
import json
from datetime import datetime
import gc
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")


class EarlyStopping:
    def __init__(self, patience=5, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


def train(backbone_dir, deal_folder, not_deal_folder,
          batch_size, initial_lr=1e-5, max_epochs=100,
          best_ckpt_path="best_ckpt.pth", final_ckpt_path="final_ckpt.pth", device="cuda"):

    data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
    train_loader = data_dict["train"]
    val_loader = data_dict["val"]

    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, device)
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    loss_func = nn.CrossEntropyLoss()

    early_stopping = EarlyStopping(path=best_ckpt_path)
    history = {"train_loss": [], "val_loss": [], "epoch": []}

    for epoch in range(max_epochs):
        model.train()
        total_loss = 0.0
        train_steps = 0

        # Unfreeze the backbone parameters from the third epoch on
        if epoch == 2:
            for param in model.backbone.parameters():
                param.requires_grad = True
            print("Unfreeze backbone parameters")

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{max_epochs} [Train]')
        for batch_idx, (ids, texts, labels) in enumerate(pbar):
            labels = labels.to(device)

            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                loss = loss_func(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            train_steps += 1

            # Show the running average loss on the progress bar
            train_loss = total_loss / train_steps
            pbar.set_postfix({"train_loss": train_loss})

            # Free memory
            del texts, labels, outputs, loss
            torch.cuda.empty_cache()
            gc.collect()

        val_loss = val(val_loader, model, loss_func, tokenizer, device)
        history["train_loss"].append(total_loss / len(train_loader))
        history["val_loss"].append(val_loss)
        history["epoch"].append(epoch+1)

        print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}")

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    torch.save(model.state_dict(), final_ckpt_path)
    print(f"Final model saved to {final_ckpt_path}")

    history_df = pd.DataFrame(history)
    history_df.to_csv("training_history.csv", index=False)
    print("Training history saved to training_history.csv")


def val(val_loader, model, loss_func, tokenizer, device):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_idx, (ids, texts, labels) in enumerate(val_loader):
            labels = labels.to(device)

            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                loss = loss_func(outputs, labels)

            val_loss += loss.item()
    return val_loss / len(val_loader)


if __name__ == "__main__":
    backbone_dir = r"C:\Users\GA\Desktop\models\Qwen3-1.7B"
    deal_folder = "deal"
    not_deal_folder = "not_deal"
    batch_size = 8
    device = "cuda"

    train(backbone_dir, deal_folder, not_deal_folder, batch_size, device=device)
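
train() writes the per-epoch losses to training_history.csv. As a small follow-up, the sketch below loads that CSV with pandas and reports the epoch with the lowest validation loss, which should correspond to the checkpoint EarlyStopping saved to best_ckpt.pth; it assumes only the columns the loop above actually writes (epoch, train_loss, val_loss).

# Inspect training_history.csv written by train() and report the best epoch.
import pandas as pd

history = pd.read_csv("training_history.csv")
best = history.loc[history["val_loss"].idxmin()]
print(f"Best epoch: {int(best['epoch'])}, "
      f"train loss {best['train_loss']:.4f}, val loss {best['val_loss']:.4f}")
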