deal-classification/data_process/process/preprocess.py
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split

from .content_extract import extract_json_data, extract_json_files
# English feature keys, grouped below into five profile dimensions.
valid_keys = [
    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
    "Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
]
# Chinese display names, aligned 1:1 with valid_keys; the LLM prompt is
# rendered in Chinese, so these are the labels that actually appear in it.
ch_valid_keys = [
    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
    "竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
]
all_keys = valid_keys + ["session_id", "label"]
en2ch = {en: ch for en, ch in zip(valid_keys, ch_valid_keys)}
# The 23 keys split into five profile dimensions (5 + 5 + 5 + 4 + 4):
d1_keys = valid_keys[:5]     # pain and anxiety level
d2_keys = valid_keys[5:10]   # willingness and ability to pay
d3_keys = valid_keys[10:15]  # deal resistance and defense mechanisms
d4_keys = valid_keys[15:19]  # emotional hooks and closing entry points
d5_keys = valid_keys[19:23]  # customer lifecycle state
class Formatter:
    """Renders a customer-feature dict as a Chinese LLM prompt."""

    def __init__(self, en2ch):
        self.en2ch = en2ch

    def _build_user_profile(self, profile: dict) -> str:
        sections = ["[客户画像]"]  # "[Customer Profile]"
        dimensions = [
            ("\n [痛感和焦虑等级]", d1_keys),        # pain and anxiety level
            ("\n [支付意愿与能力]", d2_keys),        # willingness and ability to pay
            ("\n [成交阻力与防御机制]", d3_keys),    # deal resistance and defense mechanisms
            ("\n [情绪钩子与成交切入点]", d4_keys),  # emotional hooks and closing entry points
            ("\n [客户生命周期状态]", d5_keys),      # customer lifecycle state
        ]
        # Only keys present in the profile are rendered, so partial profiles
        # produce shorter sections rather than empty placeholder lines.
        for header, keys in dimensions:
            sections.append(header)
            for key in keys:
                if key in profile:
                    sections.append(f"{self.en2ch[key]}: {profile[key]}")
        return "\n".join(sections)

    def get_llm_prompt(self, features):
        user_profile = self._build_user_profile(features)
        # Prompt (Chinese): "You are a sales-psychology expert. Analyse the
        # customer features below, extract the core purchase drivers and the
        # main obstacles, then estimate the conversion probability and output
        # it as JSON: {"conversion_probability": a value between 0 and 1}."
        prompt = f"""
你是一个销售心理学专家,请分析以下客户特征:
{user_profile}
请提取客户的核心购买驱动力和主要障碍后分析该客户的成交概率。将成交概率以JSON格式输出
{{
"conversion_probability": 0-1之间的数值
}}
"""
        messages = [
            {"role": "user", "content": prompt}
        ]
        return messages
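# A minimal usage sketch (the profile values here are made up for
# illustration; real profiles come from the extracted JSON files):
#   formatter = Formatter(en2ch)
#   messages = formatter.get_llm_prompt({"Core_Fear_Source": "健康恶化", "Price_Sensitivity": "高"})
#   messages[0]["content"]  # -> Chinese prompt containing "[客户画像]" and the two rendered fields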
class TransDataset(Dataset):
    """Binary deal / no-deal dataset built from two folders of JSON session exports."""

    def __init__(self, deal_data_folder, not_deal_data_folder):
        self.deal_data = extract_json_data(extract_json_files(deal_data_folder))
        self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder))
        self.formatter = Formatter(en2ch)
        # Downsample only when one class falls below 80% of the other's size.
        num_deal = len(self.deal_data)
        num_not_deal = len(self.not_deal_data)
        num_threshold = max(num_deal, num_not_deal) * 0.8
        if not all([num_deal >= num_threshold, num_not_deal >= num_threshold]):
            self._balance_samples()
        self._build_samples()

    def _build_samples(self):
        self.samples = []
        for sid, features in self.deal_data.items():
            messages = self.formatter.get_llm_prompt(features)
            self.samples.append((sid, messages, 1))  # label 1 = deal closed
        for sid, features in self.not_deal_data.items():
            messages = self.formatter.get_llm_prompt(features)
            self.samples.append((sid, messages, 0))  # label 0 = no deal
        random.shuffle(self.samples)
        print(f"total samples num: {len(self.samples)}, deal num: {len(self.deal_data)}, not deal num: {len(self.not_deal_data)}")

    def _balance_samples(self):
        # Seed both RNGs so the downsampled subset is reproducible.
        random.seed(42)
        np.random.seed(42)
        not_deal_ids = list(self.not_deal_data.keys())
        target_size = len(self.deal_data)
        # Only the majority "not deal" class is downsampled; an excess of
        # "deal" samples is left untouched.
        if len(not_deal_ids) > target_size:
            selected_not_deal_ids = random.sample(not_deal_ids, target_size)
            self.not_deal_data = {sid: self.not_deal_data[sid] for sid in selected_not_deal_ids}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]  # (session_id, messages, label)
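# Each item is a (session_id, messages, label) triple: `messages` is the
# chat-format prompt from Formatter.get_llm_prompt, and the label is 1 for a
# closed deal, 0 otherwise. Tokenization is deliberately deferred to the
# consumer of the DataLoader, which is why items carry raw chat messages.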
def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size):
    dataset = TransDataset(deal_data_folder, not_deal_data_folder)
    # 80/10/10 train/val/test split; the test split absorbs rounding remainders.
    num_data = len(dataset)
    train_size = int(0.8 * num_data)
    val_size = int(0.1 * num_data)
    test_size = num_data - train_size - val_size
    print(f"train size: {train_size}")
    print(f"val size: {val_size}")
    print(f"test size: {test_size}")
    train_dataset, val_dataset, test_dataset = random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    def collate_fn(batch):
        # Keep ids and chat-format prompts as plain Python lists; only the
        # labels become a tensor (tokenization happens downstream).
        ids = [item[0] for item in batch]
        texts = [item[1] for item in batch]
        labels = torch.tensor([item[2] for item in batch], dtype=torch.long)
        return ids, texts, labels

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )
    return {"train": train_loader, "val": val_loader, "test": test_loader}