Files
deal-classification/data_process/process/preprocess.py

218 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from transformers import AutoTokenizer
from .content_extract import extract_json_files, extract_json_data
# English feature keys for a customer profile. Order matters: the d1..d5
# slices below partition this list into the five display sections.
valid_keys = [
    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
    "Competitor_Mindset", "Cognitive_Stage", "Last_Interaction", "Referral_Potential"
]
# Chinese display names, positionally aligned with valid_keys.
ch_valid_keys = [
    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
    "竞争者心态", "认知阶段", "最后互动时间", "推荐潜力"
]
# Full column set for raw records: features plus identifier and target.
all_keys = valid_keys + ["session_id", "label"]
# English -> Chinese lookup used when rendering profiles for the LLM prompt.
en2ch = dict(zip(valid_keys, ch_valid_keys))
# Feature groups (5 + 5 + 5 + 4 + 3 = 22), one per profile section.
d1_keys = valid_keys[:5]
d2_keys = valid_keys[5:10]
d3_keys = valid_keys[10:15]
d4_keys = valid_keys[15:19]
d5_keys = valid_keys[19:22]
class Formatter:
    """Renders a customer feature dict as Chinese profile text and an LLM prompt.

    ``en2ch`` maps English feature keys to their Chinese display names
    (see the module-level ``en2ch`` constant).
    """

    def __init__(self, en2ch):
        self.en2ch = en2ch

    def _build_user_profile(self, profile: dict) -> str:
        """Render ``profile`` as "[客户画像]" followed by five titled sections.

        Features that are absent from ``profile`` or mapped to None are
        skipped; any other value (including 0 or "") is rendered.
        """
        # Section headers paired with the feature keys they display, in
        # render order. This replaces five copy-pasted loops.
        section_specs = [
            ("\n [痛感和焦虑等级]", d1_keys),
            ("\n [支付意愿与能力]", d2_keys),
            ("\n [成交阻力与防御机制]", d3_keys),
            ("\n [情绪钩子与成交切入点]", d4_keys),
            ("\n [客户生命周期状态]", d5_keys),
        ]
        sections = ["[客户画像]"]
        for title, keys in section_specs:
            sections.append(title)
            for key in keys:
                # Missing key and explicit None are skipped identically.
                if profile.get(key) is None:
                    continue
                sections.append(f"{self.en2ch[key]}: {profile[key]}")
        return "\n".join(sections)

    def get_llm_prompt(self, features: dict) -> list:
        """Return a single-turn chat message list asking the LLM to predict
        a deal probability in [0, 1] for the rendered profile."""
        user_profile = self._build_user_profile(features)
        # Fixed the original duplicated assignment (`prompt = prompt = ...`).
        # NOTE(review): the template's surrounding newlines are part of the
        # prompt text sent to the model; do not reformat.
        prompt = f"""
请分析以下客户特征预测成交概率0~1之间
{user_profile}
成交概率:
"""
        messages = [
            {"role": "user", "content": prompt}
        ]
        return messages
class TransDataset(Dataset):
    """Dataset of (session_id, chat messages, label) triples for training.

    Label 1 = deal data, label 0 = not-deal data. When ``balance`` is True and
    the smaller class has fewer than 80% of the larger class's samples, the
    not-deal class is randomly downsampled to the deal-class size.
    """

    def __init__(self, deal_data_folder, not_deal_data_folder, threshold: int = 10, balance: bool = True):
        # `threshold` is forwarded to extract_json_data; presumably a minimum
        # record-count filter — TODO confirm against content_extract.
        self.deal_data = extract_json_data(extract_json_files(deal_data_folder), threshold)
        self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder), threshold)
        self.formatter = Formatter(en2ch)
        num_deal = len(self.deal_data)
        num_not_deal = len(self.not_deal_data)
        # Classes count as imbalanced when the minority is below 80% of the
        # majority (equivalent to the original `not all([...])` check).
        num_threshold = max(num_deal, num_not_deal) * 0.8
        if balance and min(num_deal, num_not_deal) < num_threshold:
            self._balance_samples()
        self._build_samples()

    def _build_samples(self):
        """Materialize all (session_id, messages, label) triples and shuffle them."""
        self.samples = []
        for session_id, features in self.deal_data.items():
            self.samples.append((session_id, self.formatter.get_llm_prompt(features), 1))
        for session_id, features in self.not_deal_data.items():
            self.samples.append((session_id, self.formatter.get_llm_prompt(features), 0))
        random.shuffle(self.samples)
        print(f"total samples num: {len(self.samples)}, deal num: {len(self.deal_data)}, not deal num: {len(self.not_deal_data)}")

    def _balance_samples(self):
        """Downsample the not-deal class to the deal-class size.

        Seeds the global RNGs so the subsample (and the subsequent shuffle in
        _build_samples) is reproducible. Only downsamples: if not-deal is
        already the minority class, it is left untouched.
        """
        random.seed(42)
        np.random.seed(42)
        not_deal_ids = list(self.not_deal_data.keys())
        target_size = len(self.deal_data)
        if len(not_deal_ids) > target_size:
            selected_ids = random.sample(not_deal_ids, target_size)
            self.not_deal_data = {sid: self.not_deal_data[sid] for sid in selected_ids}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Each sample is already the (session_id, messages, label) triple.
        return self.samples[idx]
class OfflineTransDataset(Dataset):
    """Like TransDataset but with no class balancing and no shuffling.

    Intended for offline/evaluation use: sample order is deterministic —
    all deal samples (label 1) first, then all not-deal samples (label 0).
    """

    def __init__(self, deal_data_folder, not_deal_data_folder, threshold: int = 10):
        self.deal_data = extract_json_data(extract_json_files(deal_data_folder), threshold)
        self.not_deal_data = extract_json_data(extract_json_files(not_deal_data_folder), threshold)
        self.formatter = Formatter(en2ch)
        # Avoids shadowing the builtin `id` the original loops used.
        self.samples = [
            (session_id, self.formatter.get_llm_prompt(features), 1)
            for session_id, features in self.deal_data.items()
        ] + [
            (session_id, self.formatter.get_llm_prompt(features), 0)
            for session_id, features in self.not_deal_data.items()
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Each sample is already the (session_id, messages, label) triple.
        return self.samples[idx]
def build_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold: int = 10, balance: bool = True):
    """Build train/val/test DataLoaders over a TransDataset.

    Splits 80% / 10% / 10% (test absorbs rounding remainder) with a fixed
    generator seed (42) so splits are reproducible across runs. Only the
    training loader shuffles.

    Returns:
        dict with keys "train", "val", "test" mapping to DataLoaders that
        yield (ids, texts, labels) where labels is a long tensor.
    """
    dataset = TransDataset(deal_data_folder, not_deal_data_folder, threshold, balance)
    num_data = len(dataset)
    train_size = int(0.8 * num_data)
    val_size = int(0.1 * num_data)
    test_size = num_data - train_size - val_size  # remainder absorbs rounding
    print(f"train size: {train_size}")
    print(f"val size: {val_size}")
    print(f"test size: {test_size}")
    train_dataset, val_dataset, test_dataset = random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    def collate_fn(batch):
        # ids and prompt message-lists stay plain Python lists; only the
        # labels are collated into a tensor.
        ids = [item[0] for item in batch]
        texts = [item[1] for item in batch]
        labels = torch.tensor([item[2] for item in batch], dtype=torch.long)
        return ids, texts, labels

    # One data-driven construction instead of three copy-pasted DataLoader calls.
    split_specs = [
        ("train", train_dataset, True),
        ("val", val_dataset, False),
        ("test", test_dataset, False),
    ]
    return {
        name: DataLoader(split, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
        for name, split, shuffle in split_specs
    }
def build_offline_dataloader(deal_data_folder, not_deal_data_folder, batch_size, threshold: int = 10):
    """Return one shuffled DataLoader over an OfflineTransDataset.

    Batches are (ids, texts, labels) with labels collated into a long tensor.
    """
    dataset = OfflineTransDataset(deal_data_folder, not_deal_data_folder, threshold)

    def collate_fn(batch):
        # Single pass over the (id, prompt, label) triples instead of three
        # separate comprehensions.
        session_ids, prompts, raw_labels = [], [], []
        for sid, prompt, label in batch:
            session_ids.append(sid)
            prompts.append(prompt)
            raw_labels.append(label)
        return session_ids, prompts, torch.tensor(raw_labels, dtype=torch.long)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )