Merge commit '2c72c000cd4d81febd83e3049c82a9d2174fd41c'
This commit is contained in:
@@ -9,10 +9,9 @@
|
||||
--model/
|
||||
--__init__.py
|
||||
--modelling.py
|
||||
--focal_loss.py
|
||||
--inference.py # 推理接口
|
||||
--train.py
|
||||
--test.py
|
||||
--statis_main.py
|
||||
--Qwen3-1.7B/
|
||||
--best_ckpt.pth
|
||||
```
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
import os
|
||||
import glob
|
||||
import json
|
||||
from typing import List, Dict
|
||||
|
||||
valid_keys = [
|
||||
"Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
|
||||
"Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
|
||||
"Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
|
||||
"Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Last_Interaction", "Referral_Potential"
|
||||
]
|
||||
ch_valid_keys = [
|
||||
"核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
|
||||
"社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
|
||||
"沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
|
||||
"低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
|
||||
"竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
|
||||
"竞争者心态", "认知阶段", "最后互动时间", "推荐潜力"
|
||||
]
|
||||
all_keys = valid_keys + ["session_id", "label"]
|
||||
en2ch = {en:ch for en, ch in zip(valid_keys, ch_valid_keys)}
|
||||
@@ -22,7 +23,7 @@ d1_keys = valid_keys[:5]
|
||||
d2_keys = valid_keys[5:10]
|
||||
d3_keys = valid_keys[10:15]
|
||||
d4_keys = valid_keys[15:19]
|
||||
d5_keys = valid_keys[19:23]
|
||||
d5_keys = valid_keys[19:22]
|
||||
|
||||
def extract_json_files(folder: str):
|
||||
json_files = glob.glob(os.path.join(folder, "*.json"))
|
||||
@@ -47,17 +48,17 @@ def try_match_error_key(error_key: str):
|
||||
else:
|
||||
return None
|
||||
|
||||
def filt_json_data(json_data: dict, threshold: int = 10) -> dict:
    """Keep only the sessions that carry at least ``threshold`` extracted fields.

    Args:
        json_data: Mapping of session id -> mapping of field name -> value.
        threshold: Minimum number of fields a session needs to be kept.
            Defaults to 10, the value that used to be hard-coded.

    Returns:
        A new dict holding the sessions that pass the filter. The input
        mapping is not modified.
    """
    # Empty sessions are always dropped, even when threshold <= 0.
    new_json_data = {
        session_id: fields
        for session_id, fields in json_data.items()
        if len(fields) != 0 and len(fields) >= threshold
    }
    print(f"Total {len(new_json_data)} json data after filter with threshold {threshold}")
    return new_json_data
|
||||
|
||||
|
||||
def extract_json_data(json_files: list) -> dict:
|
||||
def extract_json_data(json_files: list, threshold: int = 10) -> dict:
|
||||
data = {}
|
||||
for json_file in json_files:
|
||||
session_id = os.path.basename(json_file).split(".")[0]
|
||||
@@ -66,24 +67,37 @@ def extract_json_data(json_files: list) -> dict:
|
||||
json_data = json.load(f)
|
||||
for key, value in json_data.items():
|
||||
if key in valid_keys:
|
||||
data[session_id][key] = value['value']
|
||||
data[session_id][key] = value.get("value", None)
|
||||
elif key == "Follow_up_Priority":
|
||||
continue
|
||||
else:
|
||||
match_key = try_match_error_key(key)
|
||||
if match_key:
|
||||
data[session_id][match_key] = value['value']
|
||||
data[session_id][match_key] = value.get("value", None)
|
||||
else:
|
||||
raise ValueError(f"Invalid key {key} in {json_file}")
|
||||
return filt_json_data(data)
|
||||
return filt_json_data(data, threshold)
|
||||
|
||||
def load_data_from_dict(data_dict: List[dict]):
    """Load features from already-parsed JSON objects without threshold filtering.

    Args:
        data_dict: List of parsed JSON objects, one per sample. Each object
            maps a field name to an entry whose ``"value"`` item holds the
            feature value.

    Returns:
        Dict mapping the sample's position in ``data_dict`` to a dict of
        field name -> value. "Follow_up_Priority" is ignored; names that
        are neither valid nor repairable by ``try_match_error_key`` are
        skipped with a printed warning.
    """
    data = {}
    for idx, item in enumerate(data_dict):
        features = {}
        for key, value in item.items():
            if key in valid_keys:
                target_key = key
            elif key == "Follow_up_Priority":
                continue
            else:
                target_key = try_match_error_key(key)
                if not target_key:
                    print(f"Warning: Invalid key {key} in data dict, skipped.")
                    continue
            # A field that is not a JSON object (null, bare string, number)
            # has no "value" entry; store None instead of raising
            # AttributeError on ``.get``.
            # NOTE(review): assumes None is the right placeholder — confirm.
            features[target_key] = value.get("value", None) if isinstance(value, dict) else None
        data[idx] = features
    return data
|
||||
|
||||
if __name__=="__main__":
|
||||
deal_folder = "deal"
|
||||
not_deal_folder = "not_deal"
|
||||
|
||||
deal_json_files = extract_json_files(deal_folder)
|
||||
deal_data = extract_json_data(deal_json_files)
|
||||
deal_txt_files = extract_txt_files(deal_folder)
|
||||
|
||||
not_deal_json_files = extract_json_files(not_deal_folder)
|
||||
not_deal_data = extract_json_data(not_deal_json_files)
|
||||
not_deal_txt_files = extract_txt_files(not_deal_folder)
|
||||
|
||||
@@ -12,14 +12,14 @@ valid_keys = [
|
||||
"Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
|
||||
"Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
|
||||
"Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Last_Interaction", "Referral_Potential"
|
||||
]
|
||||
ch_valid_keys = [
|
||||
"核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
|
||||
"社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
|
||||
"沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
|
||||
"低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
|
||||
"竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
|
||||
"竞争者心态", "认知阶段", "最后互动时间", "推荐潜力"
|
||||
]
|
||||
all_keys = valid_keys + ["session_id", "label"]
|
||||
en2ch = {en:ch for en, ch in zip(valid_keys, ch_valid_keys)}
|
||||
@@ -27,7 +27,7 @@ d1_keys = valid_keys[:5]
|
||||
d2_keys = valid_keys[5:10]
|
||||
d3_keys = valid_keys[10:15]
|
||||
d4_keys = valid_keys[15:19]
|
||||
d5_keys = valid_keys[19:23]
|
||||
d5_keys = valid_keys[19:22]
|
||||
|
||||
class Formatter:
|
||||
def __init__(self, en2ch):
|
||||
@@ -39,41 +39,48 @@ class Formatter:
|
||||
sections.append("\n [痛感和焦虑等级]")
|
||||
for key in d1_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
|
||||
sections.append("\n [支付意愿与能力]")
|
||||
for key in d2_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
|
||||
sections.append("\n [成交阻力与防御机制]")
|
||||
for key in d3_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
|
||||
sections.append("\n [情绪钩子与成交切入点]")
|
||||
for key in d4_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
|
||||
sections.append("\n [客户生命周期状态]")
|
||||
for key in d5_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
return "\n".join(sections)
|
||||
|
||||
def get_llm_prompt(self, features):
|
||||
def get_llm_prompt(self, features: dict) -> list:
|
||||
user_profile = self._build_user_profile(features)
|
||||
|
||||
prompt = f"""
|
||||
你是一个销售心理学专家,请分析以下客户特征:
|
||||
prompt = prompt = f"""
|
||||
请分析以下客户特征,预测成交概率(0~1之间)。
|
||||
|
||||
{user_profile}
|
||||
|
||||
请提取客户的核心购买驱动力和主要障碍后分析该客户的成交概率。将成交概率以JSON格式输出:
|
||||
{{
|
||||
"conversion_probability": 0-1之间的数值
|
||||
}}
|
||||
成交概率:
|
||||
"""
|
||||
|
||||
messages = [
|
||||
@@ -82,9 +89,9 @@ class Formatter:
|
||||
return messages
|
||||
|
||||
class TransDataset(Dataset):
|
||||
def __init__(self, deal_data_folder, not_deal_data_folder):
|
||||
self.deal_data = extract_json_data(extract_json_files(deal_data_folder))
|
||||
self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder))
|
||||
def __init__(self, deal_data_folder, not_deal_data_folder, threshold: int = 10, balance: bool = True):
|
||||
self.deal_data = extract_json_data(extract_json_files(deal_data_folder), threshold)
|
||||
self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder), threshold)
|
||||
|
||||
self.formatter = Formatter(en2ch)
|
||||
|
||||
@@ -92,7 +99,7 @@ class TransDataset(Dataset):
|
||||
num_not_deal = len(self.not_deal_data)
|
||||
num_threshold = max(num_deal, num_not_deal) * 0.8
|
||||
|
||||
if not all([num_deal >= num_threshold, num_not_deal >= num_threshold]):
|
||||
if not all([num_deal >= num_threshold, num_not_deal >= num_threshold]) and balance:
|
||||
self._balance_samples()
|
||||
|
||||
self._build_samples()
|
||||
@@ -128,8 +135,30 @@ class TransDataset(Dataset):
|
||||
id, prompt, label = self.samples[idx]
|
||||
return id, prompt, label
|
||||
|
||||
def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size):
|
||||
dataset = TransDataset(deal_data_folder, not_deal_data_folder)
|
||||
class OfflineTransDataset(Dataset):
    """Evaluation dataset of ``(session id, chat messages, label)`` triples.

    Samples from the deal folder get label 1 and come first; samples from
    the not-deal folder get label 0. No class balancing is applied.
    """

    def __init__(self, deal_data_folder, not_deal_data_folder, threshold: int = 10):
        """Load both folders, filter by ``threshold`` and build prompts."""
        self.deal_data = extract_json_data(extract_json_files(deal_data_folder), threshold)
        self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder), threshold)

        self.formatter = Formatter(en2ch)
        self.samples = []

        # Deal sessions (label 1) first, then not-deal sessions (label 0).
        for label, sessions in ((1, self.deal_data), (0, self.not_deal_data)):
            for session_id, features in sessions.items():
                messages = self.formatter.get_llm_prompt(features)
                self.samples.append((session_id, messages, label))

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(session id, messages, label)`` triple at ``idx``."""
        session_id, messages, label = self.samples[idx]
        return session_id, messages, label
|
||||
|
||||
def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold: int = 10, balance: bool = True):
|
||||
dataset = TransDataset(deal_data_folder, not_deal_data_folder, threshold, balance)
|
||||
num_data = len(dataset)
|
||||
|
||||
train_size = int(0.8 * num_data)
|
||||
@@ -170,3 +199,20 @@ def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size):
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
return {"train": train_loader, "val": val_loader, "test": test_loader}
|
||||
|
||||
def build_offline_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold: int = 10):
    """Build one shuffled DataLoader over an ``OfflineTransDataset``.

    Each batch is a tuple ``(ids, texts, labels)``: ``ids`` and ``texts``
    are Python lists, ``labels`` is a ``torch.long`` tensor.
    """

    def collate_fn(batch):
        # Split the (id, messages, label) triples into three parallel sequences.
        ids, texts, raw_labels = [], [], []
        for sample in batch:
            ids.append(sample[0])
            texts.append(sample[1])
            raw_labels.append(sample[2])
        return ids, texts, torch.tensor(raw_labels, dtype=torch.long)

    dataset = OfflineTransDataset(deal_data_folder, not_deal_data_folder, threshold)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
|
||||
101
inference.py
101
inference.py
@@ -1,11 +1,10 @@
|
||||
from model import TransClassifier
|
||||
from transformers import AutoTokenizer
|
||||
from data_process import extract_json_data, Formatter
|
||||
from data_process import extract_json_data, Formatter, load_data_from_dict
|
||||
import torch
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
@@ -14,14 +13,14 @@ valid_keys = [
|
||||
"Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
|
||||
"Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
|
||||
"Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Last_Interaction", "Referral_Potential"
|
||||
]
|
||||
ch_valid_keys = [
|
||||
"核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
|
||||
"社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
|
||||
"沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
|
||||
"低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
|
||||
"竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
|
||||
"竞争者心态", "认知阶段", "最后互动时间", "推荐潜力"
|
||||
]
|
||||
all_keys = valid_keys + ["session_id", "label"]
|
||||
en2ch = {en:ch for en, ch in zip(valid_keys, ch_valid_keys)}
|
||||
@@ -29,7 +28,7 @@ d1_keys = valid_keys[:5]
|
||||
d2_keys = valid_keys[5:10]
|
||||
d3_keys = valid_keys[10:15]
|
||||
d4_keys = valid_keys[15:19]
|
||||
d5_keys = valid_keys[19:23]
|
||||
d5_keys = valid_keys[19:22]
|
||||
|
||||
class InferenceEngine:
|
||||
def __init__(self, backbone_dir: str, ckpt_path: str = "best_ckpt.pth", device: str = "cuda"):
|
||||
@@ -42,7 +41,7 @@ class InferenceEngine:
|
||||
print(f"Tokenizer loaded from {backbone_dir}")
|
||||
|
||||
# 加载模型
|
||||
self.model = TransClassifier(backbone_dir, device)
|
||||
self.model = TransClassifier(backbone_dir, 2, device)
|
||||
self.model.to(device)
|
||||
if self.ckpt_path:
|
||||
self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
|
||||
@@ -57,25 +56,17 @@ class InferenceEngine:
|
||||
def inference_batch(self, json_list: List[str]) -> dict:
|
||||
"""
|
||||
批量推理函数,输入为 JSON 字符串列表,输出为包含转换概率的字典列表。为防止OOM,列表最大长度为8。
|
||||
请注意Json文件中的词条数必须大于等于10.
|
||||
请注意Json文件中的词条数必须大于等于5.
|
||||
"""
|
||||
# print(111111)
|
||||
assert len(json_list) <= 10, "单次输入json文件数量不可超过8。"
|
||||
id2feature = extract_json_data(json_list)
|
||||
print(json.dumps(id2feature ,indent=2 ,ensure_ascii=False))
|
||||
# id2feature
|
||||
assert len(json_list) <= 8, "单次输入json文件数量不可超过8。"
|
||||
id2feature = extract_json_data(json_files=json_list, threshold=5)
|
||||
|
||||
message_list = []
|
||||
for id, feature in id2feature.items():
|
||||
messages = self.formatter.get_llm_prompt(feature)
|
||||
message_list.append(messages)
|
||||
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
message_list,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
inputs = self.tokenizer.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
||||
model_inputs = self.tokenizer(
|
||||
inputs,
|
||||
padding=True,
|
||||
@@ -87,21 +78,12 @@ class InferenceEngine:
|
||||
with torch.inference_mode():
|
||||
with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
|
||||
outputs = self.model(model_inputs)
|
||||
|
||||
# 1. 计算分类标签(argmax)
|
||||
preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
|
||||
|
||||
# 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题)
|
||||
outputs_float = outputs.float() # 转换为 float32 避免精度问题
|
||||
outputs_float = outputs.float()
|
||||
probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
|
||||
# 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率)
|
||||
probs = probs.cpu().numpy().tolist()
|
||||
probs = [p[1] for p in probs] # 只保留类别1的概率
|
||||
|
||||
# 3. 计算置信度
|
||||
confidence = [abs(p - 0.5) * 2 for p in probs]
|
||||
# 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表,confidence是每个样本的置信度列表
|
||||
return {"labels": preds, "probs": probs, "confidence": confidence}
|
||||
probs = [p[1] for p in probs]
|
||||
return {"labels": preds, "probs": probs}
|
||||
|
||||
def inference_sample(self, json_path: str) -> dict:
|
||||
"""
|
||||
@@ -110,23 +92,26 @@ class InferenceEngine:
|
||||
"""
|
||||
return self.inference_batch([json_path])
|
||||
|
||||
def inference(
|
||||
self,
|
||||
featurs : dict[str ,dict]
|
||||
):
|
||||
# assert len(featurs) <= 10, "单次输入json文件数量不可超过8。"
|
||||
# def inference(
|
||||
# self,
|
||||
# featurs : dict[str ,dict]
|
||||
# ):
|
||||
# # assert len(featurs) <= 10, "单次输入json文件数量不可超过8。"
|
||||
|
||||
def inference_batch_json_data(self, json_data: List[dict]) -> dict:
|
||||
"""
|
||||
批量推理函数,输入为 JSON 数据,输出为包含转换概率的字典列表。为防止OOM,列表最大长度为8。
|
||||
请注意Json文件中的词条数必须大于等于5. 但此处不进行过滤,请注意稍后对输出进行过滤。
|
||||
"""
|
||||
assert len(json_data) <= 8, "单次输入json数据数量不可超过8。"
|
||||
pseudo_id2feature = load_data_from_dict(json_data)
|
||||
|
||||
message_list = []
|
||||
for id, feature in featurs.items():
|
||||
for id, feature in pseudo_id2feature.items():
|
||||
messages = self.formatter.get_llm_prompt(feature)
|
||||
message_list.append(messages)
|
||||
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
message_list,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
|
||||
inputs = self.tokenizer.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
||||
model_inputs = self.tokenizer(
|
||||
inputs,
|
||||
padding=True,
|
||||
@@ -138,26 +123,30 @@ class InferenceEngine:
|
||||
with torch.inference_mode():
|
||||
with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
|
||||
outputs = self.model(model_inputs)
|
||||
|
||||
# 1. 计算分类标签(argmax)
|
||||
preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
|
||||
|
||||
# 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题)
|
||||
outputs_float = outputs.float() # 转换为 float32 避免精度问题
|
||||
outputs_float = outputs.float()
|
||||
probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
|
||||
# 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率)
|
||||
probs = probs.cpu().numpy().tolist()
|
||||
probs = [p[1] for p in probs] # 只保留类别1的概率
|
||||
|
||||
# 3. 计算置信度
|
||||
confidence = [abs(p - 0.5) * 2 for p in probs]
|
||||
# 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表,confidence是每个样本的置信度列表
|
||||
return {"labels": preds, "probs": probs, "confidence": confidence}
|
||||
probs = [p[1] for p in probs]
|
||||
return {"labels": preds, "probs": probs}
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置参数
|
||||
backbone_dir = "Qwen3-1.7B"
|
||||
ckpt_path = "best_ckpt.pth"
|
||||
device = "cuda"
|
||||
|
||||
engine = InferenceEngine(backbone_dir, ckpt_path, device)
|
||||
import glob
|
||||
deal_files = glob.glob(os.path.join("filtered_deal", "*.json"))
|
||||
test_deal_files = deal_files[:4]
|
||||
not_deal_files = glob.glob(os.path.join("filtered_not_deal", "*.json"))
|
||||
test_not_deal_files = not_deal_files[:4]
|
||||
|
||||
test_files = test_deal_files + test_not_deal_files
|
||||
test_dict = []
|
||||
for test_file in test_files:
|
||||
with open(test_file, "r", encoding="utf-8") as f:
|
||||
json_data = json.load(f)
|
||||
test_dict.append(json_data)
|
||||
results = engine.inference_batch_json_data(test_dict)
|
||||
print(results)
|
||||
@@ -1 +1,2 @@
|
||||
from .modelling import TransClassifier
|
||||
from .focal_loss import FocalLoss
|
||||
135
model/focal_loss.py
Normal file
135
model/focal_loss.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
    """Focal loss for binary, multi-class and multi-label classification."""

    def __init__(self, gamma=2, alpha=None, reduction='mean', task_type='binary', num_classes=None):
        """
        Unified Focal Loss class for binary, multi-class, and multi-label classification tasks.
        :param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma
        :param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used.
        :param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum'
        :param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label'
        :param num_classes: Number of classes. Required when a per-class ``alpha`` list/tensor is
            given for multi-class; otherwise inferred from the logits when omitted.
        """
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.task_type = task_type
        self.num_classes = num_classes

        # Handle alpha for class balancing in multi-class tasks
        if task_type == 'multi-class' and alpha is not None and isinstance(alpha, (list, torch.Tensor)):
            assert num_classes is not None, "num_classes must be specified for multi-class classification"
            if isinstance(alpha, list):
                self.alpha = torch.tensor(alpha, dtype=torch.float32)

    def forward(self, inputs, targets):
        """
        Forward pass to compute the Focal Loss based on the specified task type.
        :param inputs: Predictions (logits) from the model.
        :param targets: Ground truth labels. Class indices for multi-class; 0/1 values with
            the same shape as ``inputs`` for binary and multi-label.
        :raises ValueError: if ``task_type`` is not one of the supported values.
        """
        if self.task_type == 'binary':
            return self.binary_focal_loss(inputs, targets)
        elif self.task_type == 'multi-class':
            return self.multi_class_focal_loss(inputs, targets)
        elif self.task_type == 'multi-label':
            return self.multi_label_focal_loss(inputs, targets)
        else:
            raise ValueError(
                f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'.")

    def _reduce(self, loss):
        """Apply the configured reduction ('mean', 'sum', anything else = none)."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def _sigmoid_focal_loss(self, inputs, targets):
        """Element-wise sigmoid focal loss shared by the binary and multi-label paths."""
        probs = torch.sigmoid(inputs)

        # Compute binary cross entropy
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # Compute focal weight: p_t is the probability assigned to the true outcome
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided: alpha for positives, (1 - alpha) for negatives
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        return self._reduce(focal_weight * bce_loss)

    def binary_focal_loss(self, inputs, targets):
        """ Focal loss for binary classification. """
        return self._sigmoid_focal_loss(inputs, targets.float())

    def multi_class_focal_loss(self, inputs, targets):
        """ Focal loss for multi-class classification.

        Returns a (batch, num_classes) tensor before reduction, with non-target entries equal to
        zero, so 'mean' averages over batch * num_classes elements.
        """
        # Fall back to the logits' class dimension when num_classes was not given.
        num_classes = self.num_classes if self.num_classes is not None else inputs.shape[1]

        # log_softmax stays finite where softmax underflows to 0, so the masked
        # product below never evaluates 0 * -inf (NaN) on saturated logits.
        log_probs = F.log_softmax(inputs, dim=1)
        probs = log_probs.exp()

        # One-hot encode the targets
        targets_one_hot = F.one_hot(targets, num_classes=num_classes).float()

        # Compute cross-entropy for each class (non-zero only at the target class)
        ce_loss = -targets_one_hot * log_probs

        # Compute focal weight
        p_t = torch.sum(probs * targets_one_hot, dim=1)  # p_t for each sample
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided
        if self.alpha is not None:
            if isinstance(self.alpha, torch.Tensor):
                # Per-class weighting: pick each sample's target-class weight.
                alpha_t = self.alpha.to(inputs.device).gather(0, targets)
                ce_loss = alpha_t.unsqueeze(1) * ce_loss
            else:
                # Scalar alpha: one uniform weight for every class.
                ce_loss = self.alpha * ce_loss

        # Apply focal loss weight
        return self._reduce(focal_weight.unsqueeze(1) * ce_loss)

    def multi_label_focal_loss(self, inputs, targets):
        """ Focal loss for multi-label classification. """
        # binary_cross_entropy_with_logits needs floating-point targets.
        if not targets.is_floating_point():
            targets = targets.float()
        return self._sigmoid_focal_loss(inputs, targets)
|
||||
@@ -4,16 +4,18 @@ import torch.nn.functional as F
|
||||
from transformers import AutoModel
|
||||
|
||||
class TransClassifier(nn.Module):
|
||||
def __init__(self, model_dir: str, device: str="cuda"):
|
||||
def __init__(self, model_dir: str, output_classes: int, device: str="cuda"):
|
||||
super().__init__()
|
||||
self.backbone = AutoModel.from_pretrained(
|
||||
model_dir,
|
||||
dtype = "bfloat16"
|
||||
dtype = "bfloat16",
|
||||
attn_implementation="flash_attention_2"
|
||||
).to(device).eval()
|
||||
self.device = device
|
||||
self.torch_dtype = torch.bfloat16
|
||||
self.hidden_size = self.backbone.config.hidden_size
|
||||
|
||||
self.token_proj = nn.Linear(self.hidden_size, self.hidden_size).to(device=device, dtype=self.torch_dtype)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.LayerNorm(self.hidden_size),
|
||||
nn.Linear(self.hidden_size, self.hidden_size//2),
|
||||
@@ -22,7 +24,7 @@ class TransClassifier(nn.Module):
|
||||
nn.Linear(self.hidden_size//2, self.hidden_size//4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(self.hidden_size//4, 2)
|
||||
nn.Linear(self.hidden_size//4, output_classes)
|
||||
).to(device=device, dtype=self.torch_dtype)
|
||||
|
||||
for param in self.backbone.parameters():
|
||||
@@ -30,12 +32,15 @@ class TransClassifier(nn.Module):
|
||||
|
||||
def forward(self, model_inputs: dict):
|
||||
outputs = self.backbone(**model_inputs)
|
||||
proj_states = self.token_proj(outputs.last_hidden_state)
|
||||
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
# take last token hidden state
|
||||
cls_hidden_state = last_hidden_state[:, -1, :]
|
||||
attention_mask = model_inputs['attention_mask']
|
||||
mask_expanded = attention_mask.unsqueeze(-1).expand_as(proj_states).to(proj_states.dtype)
|
||||
sum_states = (proj_states * mask_expanded).sum(dim=1)
|
||||
valid_tokens = mask_expanded.sum(dim=1)
|
||||
pooled = sum_states / valid_tokens.clamp(min=1e-9)
|
||||
|
||||
logits = self.classifier(cls_hidden_state)
|
||||
logits = self.classifier(pooled)
|
||||
return logits
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
130
offline_test.py
Normal file
130
offline_test.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from data_process import build_offline_dataloader
|
||||
from model import TransClassifier
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
import gc
|
||||
from tqdm import tqdm
|
||||
import warnings
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
|
||||
from transformers import AutoTokenizer
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def offline_test(
    deal_data_folder, not_deal_data_folder, batch_size, threshold,
    backbone_dir, ckpt_path, device, filtered=False
):
    """Evaluate a checkpoint on held-out data and write metrics and predictions to disk.

    Args:
        deal_data_folder: Folder of JSON files labelled 1.
        not_deal_data_folder: Folder of JSON files labelled 0.
        batch_size: Evaluation batch size.
        threshold: Minimum field count passed to the data loader; also
            embedded in the output file names.
        backbone_dir: Directory of the pretrained backbone and tokenizer.
        ckpt_path: Path of the classifier checkpoint (state dict).
        device: Device string or ``torch.device`` to run on.
        filtered: When True, output file names get a ``filtered_`` infix.

    Writes ``offline_test_result_*.json`` (metrics) and
    ``offline_test_predictions_*.csv`` (per-sample predictions) in the
    current working directory. Returns nothing.
    """
    offline_loader = build_offline_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold)

    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, 2, device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    # autocast wants the device *type*; derive it from the device actually
    # used rather than hard-coding "cuda", since callers may pass a CPU device.
    device_type = torch.device(device).type

    all_ids = []
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.inference_mode():
        for ids, texts, labels in tqdm(offline_loader, desc="Testing"):
            all_ids.extend(ids)
            # Labels are only needed for the metrics, so they stay on the CPU.
            all_labels.extend(labels.tolist())

            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)

            with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
                outputs = model(inputs)

            all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            # Softmax in float32 to avoid bfloat16 precision loss; keep P(class 1).
            all_probs.extend(torch.softmax(outputs.float(), dim=1)[:, 1].cpu().tolist())

            # Free per-batch memory
            del texts, inputs, outputs
            if device_type == "cuda":
                torch.cuda.empty_cache()
            gc.collect()

    # Compute evaluation metrics. Passing labels=[0, 1] keeps the per-class
    # arrays and the confusion matrix two-wide even if one class is absent.
    class_labels = [0, 1]
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    # ROC AUC is undefined when only one class is present; report null then.
    auc = roc_auc_score(all_labels, all_probs) if len(set(all_labels)) > 1 else None
    cm = confusion_matrix(all_labels, all_preds, labels=class_labels)

    precision_per_class = precision_score(all_labels, all_preds, labels=class_labels, average=None)
    recall_per_class = recall_score(all_labels, all_preds, labels=class_labels, average=None)
    f1_per_class = f1_score(all_labels, all_preds, labels=class_labels, average=None)

    test_results = {
        "accuracy": accuracy,
        "precision_weighted": precision,
        "recall_weighted": recall,
        "f1_weighted": f1,
        "auc": auc,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(precision_per_class[0]),
        "class_0_recall": float(recall_per_class[0]),
        "class_0_f1": float(f1_per_class[0]),
        "class_1_precision": float(precision_per_class[1]),
        "class_1_recall": float(recall_per_class[1]),
        "class_1_f1": float(f1_per_class[1])
    }

    # Both outputs share the same name suffix.
    suffix = f"filtered_{threshold}" if filtered else f"{threshold}"

    with open(f"offline_test_result_{suffix}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)

    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels
    })
    pred_df.to_csv(f"offline_test_predictions_{suffix}.csv", index=False, encoding="utf-8")
|
||||
|
||||
if __name__ == "__main__":
    # Evaluation folders: the "_filtered" pair and the unfiltered pair.
    filtered_deal_folder = "not_trained_deal_filtered"
    filtered_not_deal_folder = "not_trained_not_deal_filtered"

    deal_folder = "not_trained_deal"
    not_deal_folder = "not_trained_not_deal"

    batch_size = 8
    backbone_dir = "Qwen3-1.7B"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Disabled threshold sweep, kept for reference:
    # for threshold in range(3, 9):
    #     ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
    #     offline_test(deal_folder, not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=False)
    #     offline_test(filtered_deal_folder, filtered_not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=True)

    threshold = 5
    ckpt_path = f"best_ckpt_threshold_{threshold}_1st.pth"
    offline_test(deal_folder, not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=False)
    offline_test(filtered_deal_folder, filtered_not_deal_folder, batch_size, threshold, backbone_dir, ckpt_path, device, filtered=True)
|
||||
154
test.py
154
test.py
@@ -1,154 +0,0 @@
|
||||
from data_process import build_dataloader
|
||||
from model import TransClassifier
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
import gc
|
||||
from tqdm import tqdm
|
||||
import warnings
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
def test(backbone_dir, deal_folder, not_deal_folder, batch_size, ckpt_path="best_ckpt.pth", device="cuda"):
    """
    Evaluate the classifier on the test split and persist the results.

    Per-sample predictions are written to ``test_predictions.csv`` and the
    aggregate metrics to ``test_results.json`` in the current working directory.

    Args:
        backbone_dir: Directory of the pretrained backbone (also used for the tokenizer).
        deal_folder: Folder with the "deal" samples.
        not_deal_folder: Folder with the "not deal" samples.
        batch_size: Batch size handed to ``build_dataloader``.
        ckpt_path: Path of the model checkpoint. If it does not exist, a warning
            is printed and evaluation continues without loading weights.
        device: Device the model and the batches are moved to.

    Returns:
        dict: The metrics that were written to ``test_results.json``.
    """
    # Always report both classes, even if one is absent from the test split;
    # otherwise the per-class arrays may have length 1 and indexing [1] fails.
    class_ids = [0, 1]

    # Load the test data.
    data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
    test_loader = data_dict["test"]
    print(f"Test data loaded successfully. Test samples: {len(test_loader.dataset)}")

    # Load the tokenizer and the model.
    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
    model = TransClassifier(backbone_dir, device)
    model.to(device)

    # Load the trained weights.
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")

    # Run the evaluation loop.
    model.eval()
    all_ids = []
    all_preds = []
    all_labels = []
    test_loss = 0.0

    loss_func = nn.CrossEntropyLoss()

    pbar = tqdm(test_loader, desc="Testing")
    with torch.inference_mode():
        for ids, texts, labels in pbar:
            all_ids.extend(ids)
            labels = labels.to(device)

            # Render the chat template, then tokenize the rendered text.
            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                loss = loss_func(outputs, labels)

            test_loss += loss.item()

            # Predicted class = index of the largest logit.
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

            # Release per-batch tensors before the next iteration.
            del texts, inputs, labels, outputs, loss
            torch.cuda.empty_cache()
            gc.collect()

    # Aggregate metrics.
    avg_loss = test_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    # Per-class metrics, computed once and reused for printing and saving.
    precision_per_class = precision_score(all_labels, all_preds, average=None, labels=class_ids, zero_division=0)
    recall_per_class = recall_score(all_labels, all_preds, average=None, labels=class_ids, zero_division=0)
    f1_per_class = f1_score(all_labels, all_preds, average=None, labels=class_ids, zero_division=0)

    # Print the evaluation report.
    print("\n=== Test Results ===")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print("\nConfusion Matrix:")
    print(cm)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f" Precision: {precision_per_class[0]:.4f}")
    print(f" Recall: {recall_per_class[0]:.4f}")
    print(f" F1 Score: {f1_per_class[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f" Precision: {precision_per_class[1]:.4f}")
    print(f" Recall: {recall_per_class[1]:.4f}")
    print(f" F1 Score: {f1_per_class[1]:.4f}")

    # Collect the metrics for serialization (plain floats for JSON).
    test_results = {
        "average_loss": avg_loss,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(precision_per_class[0]),
        "class_0_recall": float(recall_per_class[0]),
        "class_0_f1": float(f1_per_class[0]),
        "class_1_precision": float(precision_per_class[1]),
        "class_1_recall": float(recall_per_class[1]),
        "class_1_f1": float(f1_per_class[1]),
        "test_samples": len(all_labels),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    # Save the per-sample predictions.
    pred_results = {
        "ids": all_ids,
        "predictions": all_preds,
        "true_labels": all_labels
    }
    pred_df = pd.DataFrame(pred_results)
    pred_df.to_csv("test_predictions.csv", index=False, encoding="utf-8")

    # Save the metrics as JSON.
    with open("test_results.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=2)
    print("\nTest results saved to test_results.json")
    return test_results
|
||||
|
||||
if __name__ == "__main__":
    # Evaluation settings.
    backbone_dir = r"C:\Users\GA\Desktop\models\Qwen3-1.7B"
    deal_folder, not_deal_folder = "deal", "not_deal"
    batch_size = 8
    ckpt_path = "best_ckpt.pth"
    device = "cuda"

    # Run the evaluation with the settings above.
    test(
        backbone_dir=backbone_dir,
        deal_folder=deal_folder,
        not_deal_folder=not_deal_folder,
        batch_size=batch_size,
        ckpt_path=ckpt_path,
        device=device,
    )
|
||||
238
train.py
238
train.py
@@ -1,5 +1,5 @@
|
||||
from data_process import build_dataloader
|
||||
from model import TransClassifier
|
||||
from model import TransClassifier, FocalLoss
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -14,24 +14,25 @@ from datetime import datetime
|
||||
import gc
|
||||
from tqdm import tqdm
|
||||
import warnings
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
class EarlyStopping:
|
||||
def __init__(self, patience=5, delta=0, path='checkpoint.pt'):
|
||||
def __init__(self, patience=5, delta=0):
|
||||
self.patience = patience
|
||||
self.counter = 0
|
||||
self.best_score = None
|
||||
self.early_stop = False
|
||||
self.val_loss_min = np.inf
|
||||
self.delta = delta
|
||||
self.path = path
|
||||
|
||||
def __call__(self, val_loss, model):
|
||||
def __call__(self, val_loss, model, best_ckpt_path):
|
||||
score = -val_loss
|
||||
|
||||
if self.best_score is None:
|
||||
self.best_score = score
|
||||
self.save_checkpoint(val_loss, model)
|
||||
self.save_checkpoint(val_loss, model, best_ckpt_path)
|
||||
|
||||
elif score < self.best_score + self.delta:
|
||||
self.counter += 1
|
||||
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
|
||||
@@ -39,32 +40,49 @@ class EarlyStopping:
|
||||
self.early_stop = True
|
||||
else:
|
||||
self.best_score = score
|
||||
self.save_checkpoint(val_loss, model)
|
||||
self.save_checkpoint(val_loss, model, best_ckpt_path)
|
||||
self.counter = 0
|
||||
|
||||
def save_checkpoint(self, val_loss, model):
|
||||
def save_checkpoint(self, val_loss, model, best_ckpt_path):
|
||||
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
|
||||
torch.save(model.state_dict(), self.path)
|
||||
torch.save(model.state_dict(), best_ckpt_path)
|
||||
self.val_loss_min = val_loss
|
||||
|
||||
def train(backbone_dir, deal_folder, not_deal_folder,
|
||||
batch_size, initial_lr=1e-5, max_epochs=100,
|
||||
best_ckpt_path="best_ckpt.pth", final_ckpt_path="final_ckpt.pth", device="cuda"):
|
||||
|
||||
data_dict = build_dataloader(deal_folder, not_deal_folder, batch_size)
|
||||
train_loader = data_dict["train"]
|
||||
val_loader = data_dict["val"]
|
||||
batch_size, initial_lr=1e-5, max_epochs=100, threshold: int = 10,
|
||||
device="cuda", use_focal_loss=False, balance=True):
|
||||
best_ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
|
||||
model = TransClassifier(backbone_dir, device)
|
||||
if use_focal_loss:
|
||||
model = TransClassifier(backbone_dir, output_classes=1, device=device)
|
||||
if balance:
|
||||
loss_func = FocalLoss(
|
||||
gamma=2.0,
|
||||
alpha=0.5,
|
||||
reduction='mean',
|
||||
task_type='binary')
|
||||
else:
|
||||
loss_func = FocalLoss(
|
||||
gamma=2.0,
|
||||
alpha=0.8,
|
||||
reduction='mean',
|
||||
task_type='binary')
|
||||
else:
|
||||
model = TransClassifier(backbone_dir, output_classes=2, device=device)
|
||||
loss_func = nn.CrossEntropyLoss()
|
||||
assert balance == True, "When not using CE loss, balance must be True."
|
||||
model.to(device)
|
||||
|
||||
data_dict = build_dataloader(deal_data_folder=deal_folder, not_deal_data_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, balance=balance)
|
||||
train_loader = data_dict["train"]
|
||||
val_loader = data_dict["val"]
|
||||
test_loader = data_dict["test"]
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
|
||||
|
||||
loss_func = nn.CrossEntropyLoss()
|
||||
|
||||
early_stopping = EarlyStopping(path=best_ckpt_path)
|
||||
early_stopping = EarlyStopping(patience=10, delta=0)
|
||||
history = {"train_loss": [], "val_loss": [], "epoch": []}
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
@@ -83,8 +101,13 @@ def train(backbone_dir, deal_folder, not_deal_folder,
|
||||
|
||||
texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
||||
inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
|
||||
|
||||
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
outputs = model(inputs)
|
||||
if use_focal_loss:
|
||||
outputs = outputs.squeeze(1)
|
||||
loss = loss_func(outputs, labels.float())
|
||||
else:
|
||||
loss = loss_func(outputs, labels)
|
||||
|
||||
optimizer.zero_grad()
|
||||
@@ -103,26 +126,25 @@ def train(backbone_dir, deal_folder, not_deal_folder,
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
val_loss = val(val_loader, model, loss_func, tokenizer, device)
|
||||
history["train_loss"].append(total_loss / len(train_loader))
|
||||
val_loss = val(val_loader, model, loss_func, tokenizer, device, use_focal_loss)
|
||||
train_loss_epoch = total_loss / len(train_loader)
|
||||
history["train_loss"].append(train_loss_epoch)
|
||||
history["val_loss"].append(val_loss)
|
||||
history["epoch"].append(epoch+1)
|
||||
|
||||
print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {total_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}")
|
||||
|
||||
early_stopping(val_loss, model)
|
||||
early_stopping(val_loss, model, best_ckpt_path)
|
||||
if early_stopping.early_stop:
|
||||
print("Early stopping")
|
||||
break
|
||||
|
||||
torch.save(model.state_dict(), final_ckpt_path)
|
||||
print(f"Final model saved to {final_ckpt_path}")
|
||||
|
||||
history_df = pd.DataFrame(history)
|
||||
history_df.to_csv("training_history.csv", index=False)
|
||||
print("Training history saved to training_history.csv")
|
||||
history_df.to_csv(f"training_history_threshold_{threshold}.csv", index=False)
|
||||
print(f"Training history saved to training_history_threshold_{threshold}.csv")
|
||||
return test_loader
|
||||
|
||||
def val(val_loader, model, loss_func, tokenizer, device):
|
||||
def val(val_loader, model, loss_func, tokenizer, device, use_focal_loss=False):
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
@@ -133,17 +155,169 @@ def val(val_loader, model, loss_func, tokenizer, device):
|
||||
inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)
|
||||
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
outputs = model(inputs)
|
||||
if use_focal_loss:
|
||||
outputs = outputs.squeeze(1)
|
||||
loss = loss_func(outputs, labels.float())
|
||||
else:
|
||||
loss = loss_func(outputs, labels)
|
||||
|
||||
val_loss += loss.item()
|
||||
|
||||
del inputs, outputs, labels, loss
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return val_loss / len(val_loader)
|
||||
|
||||
def test(backbone_dir, test_loader, device, threshold, use_focal_loss=False, balance=True):
    """
    Evaluate the best checkpoint for ``threshold`` on ``test_loader``.

    Loads ``best_ckpt_threshold_{threshold}.pth`` (a warning is printed and
    evaluation continues without loading weights if the file is missing),
    runs inference, prints the metrics, and writes
    ``test_results_threshold_{threshold}.json`` and
    ``test_predictions_threshold_{threshold}.csv`` to the working directory.

    Args:
        backbone_dir: Directory of the pretrained backbone (also used for the tokenizer).
        test_loader: Iterable yielding ``(ids, texts, labels)`` batches.
        device: Device the model and the batches are moved to.
        threshold: Value used to build the checkpoint and output file names.
        use_focal_loss: When True the model has one output and the probability
            is ``sigmoid(logit)``; otherwise it has two outputs and the
            probability of class 1 is taken from a softmax.
        balance: Accepted for signature parity with ``train``; not used here.
    """
    # Always report both classes, even if one is absent from the test split;
    # otherwise the confusion matrix is not 2x2 and indexing [1] fails.
    class_ids = [0, 1]

    tokenizer = AutoTokenizer.from_pretrained(backbone_dir)

    if use_focal_loss:
        model = TransClassifier(backbone_dir, output_classes=1, device=device)
    else:
        model = TransClassifier(backbone_dir, output_classes=2, device=device)
    model.to(device)

    ckpt_path = f"best_ckpt_threshold_{threshold}.pth"
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded from {ckpt_path}")
    else:
        print(f"Warning: {ckpt_path} not found. Using untrained model.")

    model.eval()

    all_ids = []
    all_preds = []
    all_probs = []
    all_labels = []

    pbar = tqdm(test_loader, desc="Testing")
    with torch.inference_mode():
        for ids, texts, labels in pbar:
            all_ids.extend(ids)
            labels = labels.to(device)

            texts = tokenizer.apply_chat_template(texts, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            inputs = tokenizer(texts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(inputs)
                if use_focal_loss:
                    outputs = outputs.squeeze(-1)  # [B, 1] -> [B]

            # Cast to float32 before sigmoid/softmax to avoid precision issues.
            outputs_float = outputs.float()
            if use_focal_loss:
                probs = torch.sigmoid(outputs_float).cpu().numpy().tolist()  # [B]
                preds = [1 if p >= 0.5 else 0 for p in probs]
            else:
                preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
                # Keep only the probability of class 1 from the [B, 2] softmax.
                probs = torch.softmax(outputs_float, dim=1).cpu().numpy().tolist()
                probs = [p[1] for p in probs]

            all_preds.extend(preds)
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy())

            # Release per-batch tensors before the next iteration.
            del texts, labels, outputs
            torch.cuda.empty_cache()
            gc.collect()

    # Aggregate metrics.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted")
    recall = recall_score(all_labels, all_preds, average="weighted")
    f1 = f1_score(all_labels, all_preds, average="weighted")
    # ROC AUC is undefined when only one class occurs in the true labels.
    if len({int(label) for label in all_labels}) > 1:
        auc = roc_auc_score(all_labels, all_probs)
    else:
        auc = None
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    # Print the evaluation report.
    print("\n=== Test Results ===")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    if auc is not None:
        print(f"AUC: {auc:.4f}")
    else:
        print("AUC: N/A")

    cm_df = pd.DataFrame(cm,
                         index=['Actual Not Deal (0)', 'Actual Deal (1)'],
                         columns=['Predicted Not Deal (0)', 'Predicted Deal (1)'])
    print("\nConfusion Matrix:")
    print(cm_df)

    precision_per_class = precision_score(all_labels, all_preds, average=None, labels=class_ids, zero_division=0)
    recall_per_class = recall_score(all_labels, all_preds, average=None, labels=class_ids, zero_division=0)
    f1_per_class = f1_score(all_labels, all_preds, average=None, labels=class_ids, zero_division=0)
    print("\n=== Class-wise Metrics ===")
    print("Class 0 (Not Deal):")
    print(f" Precision: {precision_per_class[0]:.4f}")
    print(f" Recall: {recall_per_class[0]:.4f}")
    print(f" F1 Score: {f1_per_class[0]:.4f}")
    print("\nClass 1 (Deal):")
    print(f" Precision: {precision_per_class[1]:.4f}")
    print(f" Recall: {recall_per_class[1]:.4f}")
    print(f" F1 Score: {f1_per_class[1]:.4f}")

    test_results = {
        "accuracy": accuracy,
        "precision_weighted": precision,
        "recall_weighted": recall,
        "f1_weighted": f1,
        "auc": auc,
        "confusion_matrix": cm.tolist(),
        "class_0_precision": float(precision_per_class[0]),
        "class_0_recall": float(recall_per_class[0]),
        "class_0_f1": float(f1_per_class[0]),
        "class_1_precision": float(precision_per_class[1]),
        "class_1_recall": float(recall_per_class[1]),
        "class_1_f1": float(f1_per_class[1]),
        "test_samples": len(all_labels)
    }

    with open(f"test_results_threshold_{threshold}.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, ensure_ascii=False, indent=4)
    print(f"\nTest results saved to test_results_threshold_{threshold}.json")

    pred_df = pd.DataFrame({
        "ids": all_ids,
        "predictions": all_preds,
        "probability": all_probs,
        "true_labels": all_labels
    })
    pred_df.to_csv(f"test_predictions_threshold_{threshold}.csv", index=False, encoding="utf-8")
|
||||
|
||||
if __name__ == "__main__":
|
||||
backbone_dir = r"C:\Users\GA\Desktop\models\Qwen3-1.7B"
|
||||
deal_folder = "deal"
|
||||
not_deal_folder = "not_deal"
|
||||
batch_size = 8
|
||||
backbone_dir = "Qwen3-1.7B"
|
||||
deal_folder = "filtered_deal"
|
||||
not_deal_folder = "filtered_not_deal"
|
||||
batch_size = 4
|
||||
device = "cuda"
|
||||
|
||||
""" threshold = 10
|
||||
test_loader = train(backbone_dir=backbone_dir, deal_folder=deal_folder, not_deal_folder=not_deal_folder, batch_size=batch_size, threshold=threshold, device=device, use_focal_loss=False, balance=True)
|
||||
|
||||
train(backbone_dir, deal_folder, not_deal_folder, batch_size, device=device)
|
||||
test(
|
||||
backbone_dir=backbone_dir,
|
||||
test_loader=test_loader,
|
||||
device=device,
|
||||
threshold=threshold,
|
||||
use_focal_loss=False,
|
||||
balance=True
|
||||
) """
|
||||
|
||||
max_threshold = 10
|
||||
for i in range(3, 9):
|
||||
print(f"Training with threshold {i}...")
|
||||
test_loader = train(backbone_dir=backbone_dir, deal_folder=deal_folder, not_deal_folder=not_deal_folder, batch_size=batch_size, threshold=i, device=device, use_focal_loss=False, balance=True)
|
||||
|
||||
test(
|
||||
backbone_dir=backbone_dir,
|
||||
test_loader=test_loader,
|
||||
device=device,
|
||||
threshold=i,
|
||||
use_focal_loss=False,
|
||||
balance=True
|
||||
)
|
||||
29
visualize_training.py
Normal file
29
visualize_training.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def visualize_training_history(threshold):
    """Plot training and validation loss per epoch for one threshold.

    Reads ``training_history_threshold_{threshold}.csv`` (columns ``epoch``,
    ``train_loss`` and ``val_loss``), saves the chart to
    ``training_visualization_threshold_{threshold}.png`` and displays it.

    Args:
        threshold: Value substituted into the input and output file names.
    """
    csv_path = f'training_history_threshold_{threshold}.csv'

    df = pd.read_csv(csv_path)
    epochs = df['epoch']
    train_loss = df['train_loss']
    val_loss = df['val_loss']

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(epochs, train_loss, 'b-', label='Training Loss')
        plt.plot(epochs, val_loss, 'r-', label='Validation Loss')

        plt.title('Training and Validation Loss Over Epochs (CE)')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.savefig(f'training_visualization_threshold_{threshold}.png')
        plt.show()
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f'可视化完成,图表已保存为 training_visualization_threshold_{threshold}.png')
|
||||
|
||||
if __name__ == "__main__":
    # A history CSV may not exist for every threshold in this range (it depends
    # on which training runs were executed), so skip the missing ones instead
    # of aborting the whole loop on the first gap.
    for i in range(11):
        try:
            visualize_training_history(i)
        except FileNotFoundError:
            print(f"Skipping threshold {i}: training history file not found.")
|
||||
Reference in New Issue
Block a user