Update data_process/process/preprocess.py
This commit is contained in:
@@ -12,14 +12,14 @@ valid_keys = [
|
||||
"Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
|
||||
"Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
|
||||
"Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Last_Interaction", "Referral_Potential"
|
||||
]
|
||||
# Chinese display names for valid_keys. The two lists are index-aligned:
# they must stay the same length and order for the en2ch mapping below to
# pair labels correctly.
ch_valid_keys = [
    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
    "竞争者心态", "认知阶段", "最后互动时间", "推荐潜力",
]

# Full per-row schema: profile feature keys plus bookkeeping columns.
all_keys = valid_keys + ["session_id", "label"]

# English -> Chinese label lookup used when rendering customer profiles.
en2ch = dict(zip(valid_keys, ch_valid_keys))
|
||||
@@ -27,7 +27,7 @@ d1_keys = valid_keys[:5]
|
||||
# Slices of valid_keys grouped by profile dimension (d1_keys — the first
# five — is defined just above). Section names match the headings used when
# building the user-profile text.
# NOTE(review): d5 covers indices 19-21, i.e. assumes exactly 22 keys —
# re-check these bounds whenever valid_keys gains or loses an entry.
d2_keys = valid_keys[5:10]   # 支付意愿与能力 (payment willingness & ability)
d3_keys = valid_keys[10:15]  # 成交阻力与防御机制 (deal resistance & defenses)
d4_keys = valid_keys[15:19]  # 情绪钩子与成交切入点 (emotional hooks)
d5_keys = valid_keys[19:22]  # 客户生命周期状态 (customer lifecycle state)
|
||||
|
||||
class Formatter:
|
||||
def __init__(self, en2ch):
|
||||
@@ -39,42 +39,49 @@ class Formatter:
|
||||
sections.append("\n [痛感和焦虑等级]")
|
||||
for key in d1_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
|
||||
sections.append("\n [支付意愿与能力]")
|
||||
for key in d2_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
|
||||
sections.append("\n [成交阻力与防御机制]")
|
||||
for key in d3_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
|
||||
sections.append("\n [情绪钩子与成交切入点]")
|
||||
for key in d4_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
|
||||
sections.append("\n [客户生命周期状态]")
|
||||
for key in d5_keys:
|
||||
if key in profile:
|
||||
if profile[key] is None:
|
||||
continue
|
||||
sections.append(f"{self.en2ch[key]}: {profile[key]}")
|
||||
return "\n".join(sections)
|
||||
|
||||
def get_llm_prompt(self, features: dict) -> list:
    """Build a chat-style message list asking an LLM to score deal probability.

    Args:
        features: raw feature dict for one customer session; rendered into a
            Chinese profile text via ``self._build_user_profile``.

    Returns:
        A single-message list of ``{"role", "content"}`` dicts suitable for a
        chat-completion API.
    """
    user_profile = self._build_user_profile(features)

    # The diff rendering showed a fused duplicate assignment
    # ("prompt = prompt = f...") — reconstructed here as one assignment of the
    # new-version prompt. The text is runtime-facing Chinese; downstream
    # parsing expects the exact JSON shape requested below, so do not reword.
    prompt = f"""
请分析以下客户特征,预测成交概率(0~1之间)。

{user_profile}

请提取客户的核心购买驱动力和主要障碍后分析该客户的成交概率。将成交概率以JSON格式输出:
{{
"conversion_probability": 0-1之间的数值
}}
"""

    messages = [
        {"role": "user", "content": prompt}
    ]
    return messages
|
||||
|
||||
class TransDataset(Dataset):
|
||||
def __init__(self, deal_data_folder, not_deal_data_folder):
|
||||
self.deal_data = extract_json_data(extract_json_files(deal_data_folder))
|
||||
self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder))
|
||||
def __init__(self, deal_data_folder, not_deal_data_folder, threshold: int = 10, balance: bool = True):
|
||||
self.deal_data = extract_json_data(extract_json_files(deal_data_folder), threshold)
|
||||
self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder), threshold)
|
||||
|
||||
self.formatter = Formatter(en2ch)
|
||||
|
||||
@@ -92,7 +99,7 @@ class TransDataset(Dataset):
|
||||
num_not_deal = len(self.not_deal_data)
|
||||
num_threshold = max(num_deal, num_not_deal) * 0.8
|
||||
|
||||
if not all([num_deal >= num_threshold, num_not_deal >= num_threshold]):
|
||||
if not all([num_deal >= num_threshold, num_not_deal >= num_threshold]) and balance:
|
||||
self._balance_samples()
|
||||
|
||||
self._build_samples()
|
||||
@@ -128,8 +135,30 @@ class TransDataset(Dataset):
|
||||
id, prompt, label = self.samples[idx]
|
||||
return id, prompt, label
|
||||
|
||||
def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size):
|
||||
dataset = TransDataset(deal_data_folder, not_deal_data_folder)
|
||||
class OfflineTransDataset(Dataset):
    """Eager map-style dataset over deal / not-deal customer sessions.

    Unlike TransDataset, this variant performs no class balancing or
    train/val/test splitting: every session from both folders becomes one
    sample ``(session_id, messages, label)``, with label 1 for deal and 0
    for not-deal.
    """

    def __init__(self, deal_data_folder, not_deal_data_folder, threshold: int = 10):
        # NOTE(review): `threshold` is forwarded to extract_json_data —
        # presumably a minimum-size/feature filter; confirm at its definition.
        self.deal_data = extract_json_data(extract_json_files(deal_data_folder), threshold)
        self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder), threshold)

        self.formatter = Formatter(en2ch)
        self.samples = []

        # Single loop replaces the original copy-pasted pair of loops; also
        # renames the loop variable so it no longer shadows builtin `id`.
        for data, label in ((self.deal_data, 1), (self.not_deal_data, 0)):
            for session_id, features in data.items():
                messages = self.formatter.get_llm_prompt(features)
                self.samples.append((session_id, messages, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Returns the stored (session_id, messages, label) triple unchanged.
        session_id, prompt, label = self.samples[idx]
        return session_id, prompt, label
|
||||
|
||||
def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold: int = 10, balance: bool = True):
|
||||
dataset = TransDataset(deal_data_folder, not_deal_data_folder, threshold, balance)
|
||||
num_data = len(dataset)
|
||||
|
||||
train_size = int(0.8 * num_data)
|
||||
@@ -170,3 +199,20 @@ def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size):
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
return {"train": train_loader, "val": val_loader, "test": test_loader}
|
||||
|
||||
def build_offline_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold: int = 10):
    """Build one shuffled DataLoader over all deal / not-deal samples.

    No train/val/test split is performed — every sample produced by
    OfflineTransDataset goes into a single loader. Each batch is
    ``(ids, texts, labels)`` where ids and texts are plain lists and labels
    is a long tensor.
    """
    dataset = OfflineTransDataset(deal_data_folder, not_deal_data_folder, threshold)

    def collate_fn(batch):
        # batch items are (session_id, messages, label) triples;
        # unzip in one pass instead of three list comprehensions.
        session_ids, texts, labels = zip(*batch)
        label_tensor = torch.tensor(labels, dtype=torch.long)
        return list(session_ids), list(texts), label_tensor

    offline_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    return offline_loader
|
||||
Reference in New Issue
Block a user