From f3959031648720894e92daa8e82df3b2b8fbbd76 Mon Sep 17 00:00:00 2001
From: WangZiFan
Date: Fri, 27 Feb 2026 11:36:56 +0800
Subject: [PATCH] Update data_process/process/preprocess.py

---
 data_process/process/preprocess.py | 390 ++++++++++++++++-------------
 1 file changed, 218 insertions(+), 172 deletions(-)

diff --git a/data_process/process/preprocess.py b/data_process/process/preprocess.py
index aeb4eed..e55d65c 100644
--- a/data_process/process/preprocess.py
+++ b/data_process/process/preprocess.py
@@ -1,172 +1,218 @@
-import random
-import numpy as np
-
-from torch.utils.data import Dataset, DataLoader, random_split
-import torch
-from transformers import AutoTokenizer
-
-from .content_extract import extract_json_files, extract_json_data
-
-valid_keys = [
-    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
-    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
-    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
-    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
-    "Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
-]
-ch_valid_keys = [
-    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
-    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
-    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
-    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
-    "竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
-]
-all_keys = valid_keys + ["session_id", "label"]
-en2ch = {en:ch for en, ch in zip(valid_keys, ch_valid_keys)}
-d1_keys = valid_keys[:5]
-d2_keys = valid_keys[5:10]
-d3_keys = valid_keys[10:15]
-d4_keys = valid_keys[15:19]
-d5_keys = valid_keys[19:23]
-
-class Formatter:
-    def __init__(self, en2ch):
-        self.en2ch = en2ch
-
-    def _build_user_profile(self, profile: dict) -> str:
-        sections = []
-        sections.append("[客户画像]")
-        sections.append("\n [痛感和焦虑等级]")
-        for key in d1_keys:
-            if key in profile:
-                sections.append(f"{self.en2ch[key]}: {profile[key]}")
-
-        sections.append("\n [支付意愿与能力]")
-        for key in d2_keys:
-            if key in profile:
-                sections.append(f"{self.en2ch[key]}: {profile[key]}")
-
-        sections.append("\n [成交阻力与防御机制]")
-        for key in d3_keys:
-            if key in profile:
-                sections.append(f"{self.en2ch[key]}: {profile[key]}")
-
-        sections.append("\n [情绪钩子与成交切入点]")
-        for key in d4_keys:
-            if key in profile:
-                sections.append(f"{self.en2ch[key]}: {profile[key]}")
-
-        sections.append("\n [客户生命周期状态]")
-        for key in d5_keys:
-            if key in profile:
-                sections.append(f"{self.en2ch[key]}: {profile[key]}")
-        return "\n".join(sections)
-
-    def get_llm_prompt(self, features):
-        user_profile = self._build_user_profile(features)
-
-        prompt = f"""
-        你是一个销售心理学专家,请分析以下客户特征:
-
-        {user_profile}
-
-        请提取客户的核心购买驱动力和主要障碍后分析该客户的成交概率。将成交概率以JSON格式输出:
-        {{
-            "conversion_probability": 0-1之间的数值
-        }}
-        """
-
-        messages = [
-            {"role": "user", "content": prompt}
-        ]
-        return messages
-
-class TransDataset(Dataset):
-    def __init__(self, deal_data_folder, not_deal_data_folder):
-        self.deal_data = extract_json_data(extract_json_files(deal_data_folder))
-        self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder))
-
-        self.formatter = Formatter(en2ch)
-
-        num_deal = len(self.deal_data)
-        num_not_deal = len(self.not_deal_data)
-        num_threshold = max(num_deal, num_not_deal) * 0.8
-
-        if not all([num_deal >= num_threshold, num_not_deal >= num_threshold]):
-            self._balance_samples()
-
-        self._build_samples()
-
-    def _build_samples(self):
-        self.samples = []
-
-        for id, features in self.deal_data.items():
-            messages = self.formatter.get_llm_prompt(features)
-            self.samples.append((id, messages, 1))
-        for id, features in self.not_deal_data.items():
-            messages = self.formatter.get_llm_prompt(features)
-            self.samples.append((id, messages, 0))
-
-        random.shuffle(self.samples)
-        print(f"total samples num: {len(self.samples)}, deal num: {len(self.deal_data)}, not deal num: {len(self.not_deal_data)}")
-
-    def _balance_samples(self):
-        random.seed(42)
-        np.random.seed(42)
-
-        not_deal_ids = list(self.not_deal_data.keys())
-        target_size = len(self.deal_data)
-
-        if len(not_deal_ids) > target_size:
-            selected_not_deal_ids = random.sample(not_deal_ids, target_size)
-            self.not_deal_data = {sid: self.not_deal_data[sid] for sid in selected_not_deal_ids}
-
-    def __len__(self):
-        return len(self.samples)
-
-    def __getitem__(self, idx):
-        id, prompt, label = self.samples[idx]
-        return id, prompt, label
-
-def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size):
-    dataset = TransDataset(deal_data_folder, not_deal_data_folder)
-    num_data = len(dataset)
-
-    train_size = int(0.8 * num_data)
-    val_size = int(0.1 * num_data)
-    test_size = num_data - train_size - val_size
-
-    print(f"train size: {train_size}")
-    print(f"val size: {val_size}")
-    print(f"test size: {test_size}")
-    train_dataset, val_dataset, test_dataset = random_split(
-        dataset,
-        [train_size, val_size, test_size],
-        generator=torch.Generator().manual_seed(42)
-    )
-
-    def collate_fn(batch):
-        ids = [item[0] for item in batch]
-        texts = [item[1] for item in batch]
-        labels = torch.tensor([item[2] for item in batch], dtype=torch.long)
-        return ids, texts, labels
-
-    train_loader = DataLoader(
-        train_dataset,
-        batch_size=batch_size,
-        shuffle=True,
-        collate_fn=collate_fn
-    )
-    val_loader = DataLoader(
-        val_dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        collate_fn=collate_fn
-    )
-    test_loader = DataLoader(
-        test_dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        collate_fn=collate_fn
-    )
-    return {"train": train_loader, "val": val_loader, "test": test_loader}
+import random
+import numpy as np
+
+from torch.utils.data import Dataset, DataLoader, random_split
+import torch
+from transformers import AutoTokenizer
+
+from .content_extract import extract_json_files, extract_json_data
+
+valid_keys = [
+    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
+    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
+    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
+    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
+    "Competitor_Mindset", "Cognitive_Stage", "Last_Interaction", "Referral_Potential"
+]
+ch_valid_keys = [
+    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
+    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
+    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
+    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
+    "竞争者心态", "认知阶段", "最后互动时间", "推荐潜力"
+]
+all_keys = valid_keys + ["session_id", "label"]
+en2ch = {en: ch for en, ch in zip(valid_keys, ch_valid_keys)}
+# Feature groups: pain/anxiety level, willingness and ability to pay,
+# closing resistance, emotional hooks, and customer lifecycle state.
+d1_keys = valid_keys[:5]
+d2_keys = valid_keys[5:10]
+d3_keys = valid_keys[10:15]
+d4_keys = valid_keys[15:19]
+d5_keys = valid_keys[19:22]
+
+class Formatter:
+    def __init__(self, en2ch):
+        self.en2ch = en2ch
+
+    def _build_user_profile(self, profile: dict) -> str:
+        # Render each feature group under its section header, skipping
+        # features that are missing from the profile or explicitly null.
+        section_defs = [
+            ("痛感和焦虑等级", d1_keys),
+            ("支付意愿与能力", d2_keys),
+            ("成交阻力与防御机制", d3_keys),
+            ("情绪钩子与成交切入点", d4_keys),
+            ("客户生命周期状态", d5_keys),
+        ]
+        sections = ["[客户画像]"]
+        for title, keys in section_defs:
+            sections.append(f"\n [{title}]")
+            for key in keys:
+                if profile.get(key) is None:
+                    continue
+                sections.append(f"{self.en2ch[key]}: {profile[key]}")
+        return "\n".join(sections)
+
+    def get_llm_prompt(self, features: dict) -> list:
+        user_profile = self._build_user_profile(features)
+
+        prompt = f"""
+        请分析以下客户特征,预测成交概率(0~1之间)。
+
+        {user_profile}
+
+        成交概率:
+        """
+
+        messages = [
+            {"role": "user", "content": prompt}
+        ]
+        return messages
+
+class TransDataset(Dataset):
+    def __init__(self, deal_data_folder, not_deal_data_folder, threshold: int = 10, balance: bool = True):
+        self.deal_data = extract_json_data(extract_json_files(deal_data_folder), threshold)
+        self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder), threshold)
+
+        self.formatter = Formatter(en2ch)
+
+        num_deal = len(self.deal_data)
+        num_not_deal = len(self.not_deal_data)
+        num_threshold = max(num_deal, num_not_deal) * 0.8
+
+        # Downsample the majority class only when the smaller class falls
+        # below 80% of the larger one.
+        if balance and not all([num_deal >= num_threshold, num_not_deal >= num_threshold]):
+            self._balance_samples()
+
+        self._build_samples()
+
+    def _build_samples(self):
+        self.samples = []
+
+        for id, features in self.deal_data.items():
+            messages = self.formatter.get_llm_prompt(features)
+            self.samples.append((id, messages, 1))
+        for id, features in self.not_deal_data.items():
+            messages = self.formatter.get_llm_prompt(features)
+            self.samples.append((id, messages, 0))
+
+        random.shuffle(self.samples)
+        print(f"total samples num: {len(self.samples)}, deal num: {len(self.deal_data)}, not deal num: {len(self.not_deal_data)}")
+
+    def _balance_samples(self):
+        random.seed(42)
+        np.random.seed(42)
+
+        not_deal_ids = list(self.not_deal_data.keys())
+        target_size = len(self.deal_data)
+
+        if len(not_deal_ids) > target_size:
+            selected_not_deal_ids = random.sample(not_deal_ids, target_size)
+            self.not_deal_data = {sid: self.not_deal_data[sid] for sid in selected_not_deal_ids}
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        id, prompt, label = self.samples[idx]
+        return id, prompt, label
+
+class OfflineTransDataset(Dataset):
+    # Variant of TransDataset that skips class balancing and shuffling.
+    def __init__(self, deal_data_folder, not_deal_data_folder, threshold: int = 10):
+        self.deal_data = extract_json_data(extract_json_files(deal_data_folder), threshold)
+        self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder), threshold)
+
+        self.formatter = Formatter(en2ch)
+        self.samples = []
+
+        for id, features in self.deal_data.items():
+            messages = self.formatter.get_llm_prompt(features)
+            self.samples.append((id, messages, 1))
+        for id, features in self.not_deal_data.items():
+            messages = self.formatter.get_llm_prompt(features)
+            self.samples.append((id, messages, 0))
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        id, prompt, label = self.samples[idx]
+        return id, prompt, label
+
+def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold: int = 10, balance: bool = True):
+    dataset = TransDataset(deal_data_folder, not_deal_data_folder, threshold, balance)
+    num_data = len(dataset)
+
+    train_size = int(0.8 * num_data)
+    val_size = int(0.1 * num_data)
+    test_size = num_data - train_size - val_size
+
+    print(f"train size: {train_size}")
+    print(f"val size: {val_size}")
+    print(f"test size: {test_size}")
+    train_dataset, val_dataset, test_dataset = random_split(
+        dataset,
+        [train_size, val_size, test_size],
+        generator=torch.Generator().manual_seed(42)
+    )
+
+    def collate_fn(batch):
+        ids = [item[0] for item in batch]
+        texts = [item[1] for item in batch]
+        labels = torch.tensor([item[2] for item in batch], dtype=torch.long)
+        return ids, texts, labels
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=collate_fn
+    )
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=collate_fn
+    )
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=collate_fn
+    )
+    return {"train": train_loader, "val": val_loader, "test": test_loader}
+
+def build_offline_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold: int = 10):
+    dataset = OfflineTransDataset(deal_data_folder, not_deal_data_folder, threshold)
+
+    def collate_fn(batch):
+        ids = [item[0] for item in batch]
+        texts = [item[1] for item in batch]
+        labels = torch.tensor([item[2] for item in batch], dtype=torch.long)
+        return ids, texts, labels
+
+    offline_loader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=collate_fn
+    )
+    return offline_loader
\ No newline at end of file
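
Usage sketch (editor's note, not part of the patch): a minimal way to exercise the two loaders this patch introduces, assuming the repo root is on PYTHONPATH. The folder paths, batch size, and threshold below are hypothetical placeholder values; only the function names, parameters, and batch shapes come from the code above.

    from data_process.process.preprocess import build_dataloader, build_offline_dataloader

    # Hypothetical data folders of extracted session JSON files.
    loaders = build_dataloader("data/deal", "data/not_deal",
                               batch_size=8, threshold=10, balance=True)
    for ids, texts, labels in loaders["train"]:
        # ids: session ids; texts: chat-style message lists from
        # Formatter.get_llm_prompt; labels: LongTensor of 0/1.
        print(ids[0], labels)
        break

    offline_loader = build_offline_dataloader("data/deal", "data/not_deal", batch_size=8)
    print(f"offline batches: {len(offline_loader)}")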