Compare commits
7 commits: ea34a09c8f...main

| SHA1 |
|---|
| fed16f04e1 |
| 06bcaad8d4 |
| 0b168057ca |
| 1bd4547b99 |
| 3d78b88d47 |
| 3dd90cc50e |
| 7ea59f740f |
.gitignore (vendored, new file, 6 lines)
@@ -0,0 +1,6 @@
Qwen3-1.7B/**
**/__pycache__/**

.venv
.env
best_ckpt.pth
.python-version (new file, 1 line)
@@ -0,0 +1 @@
3.13
feature_extraction.py (new file, 452 lines)
@@ -0,0 +1,452 @@
'''
批量提取特征方法
很花时间
'''


from openai import OpenAI
import json
import os
from dotenv import load_dotenv
import re

# 加载环境变量
load_dotenv(dotenv_path=".env")

# 读取环境变量
API_KEY = os.getenv("QWEN_API_KEY")
MODEL = os.getenv("MODEL_NAME", "qwen3-next-80b-a3b-thinking")
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.1))  # 降低随机性,提升格式稳定性
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "time_12_1/data_ch_1")


client = OpenAI(
    api_key=API_KEY,
    # base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
    # https://sg1.proxy.yinlihupo.cc/proxy/
    base_url="https://openrouter.ai/api/v1"
)


def build_extraction_prompt(dialogue_text):
    prompt_template = """
# Role
你是一名拥有15年经验的“家庭教育销售通话审计专家”。你的核心能力是透过家长杂乱的表述,精准捕捉深层心理动机、家庭权力结构、隐形财富信号以及高危销售线索。

# Core Protocols (核心审计协议 - 最高优先级)

## 1. 宁缺毋滥原则 (The Principle of Precision)
* **存在即输出,无证即沉默**:忽略任何关于“字段数量”的限制。如果原文中有20个维度的有效证据,就输出20个;如果只有3个,就输出3个。
* **严禁凑数**:如果原文未提及某维度,或者证据模糊两可,**绝对不要**输出该 Key。不要为了追求“信息丰富”而强行填空。

## 2. 证据阵列法则 (The Law of Evidence Arrays)
* **数据结构变更**:`evidence` 字段必须是 **字符串数组 (List<String>)**,严禁使用单一字符串。
* **颗粒度控制 (Granularity Control)**:
    * 数组中的元素必须是 **具有独立语义的完整原句** 或 **包含主谓宾的完整意群**。
    * **禁止碎片**:严禁提取如 "不合适"、"太贵"、"焦虑" 这样缺乏上下文的短语。
* **主体过滤**:**仅提取家长(客户)表达的原话**,严禁提取销售人员的引导语、复述语或共情语。
* **纯净引用**:每一个元素必须是原文的 100% 完美复制。
    * **严禁拼接**:严禁使用“+”、“和”、“以及”将两句不连贯的话拼在同一个字符串里。
    * **严禁篡改**:禁止总结、禁止润色、禁止“原文+分析”。你的分析只能体现在 `value` 字段中。

## 3. 结论极简法则 (The Law of Concise Conclusion)
* **强制必输字段**:`Follow_up_Priority` 是 **核心必选字段**,无论任何情况都必须输出,不允许缺失。
* **Follow_up_Priority 兜底规则**:
    - 若文本完全无痛点/无财力/无意识,`value` 填“C级 (无痛点/无意识)”,`evidence` 填 ["文本未提及任何痛点、财力或意向相关内容"]
    - 若仅部分信息缺失,按规则评级并在 `evidence` 中列出已有有效原句。
* **Value 约束**:`value` 字段必须是 **客观、简练的定性结论**(必须限制在 **20个汉字以内**)。
    * *正确示例*: "A级 (高痛点+强财力)"
    * *错误示例*: "家长表现出对价格的犹豫,虽然她很有钱,但是因为..." (禁止小作文)

## 4. 身份与财富的高敏嗅觉
* 对于**高价值信号**(职业/多孩/私立学校/房产)和**生命红线**(自杀/不想活了/抑郁症确诊)保持极度敏感,一旦出现必须提取。

# Task
阅读提供的销售通话录音文本,从以下 23 个预设维度中筛选出**有效信息**,生成一份高精度的客户画像 JSON。

# Field Definitions (字段定义与提取逻辑)
### [第一组:心理动力与危机]
1. **Core_Fear_Source** (深层恐惧)
    * *逻辑*: 驱动家长寻求帮助的终极噩梦。是怕孩子死(生命安全)?怕孩子阶级跌落?还是怕自己面子挂不住?
    * *注意*: 必须提取具体的后果描述。
2. **Pain_Threshold** (痛苦阈值)
    * *逻辑*: 家长当前的情绪状态。是“崩溃急救”(无法忍受,必须马上解决),还是“隐隐作痛”(还能凑合)?
3. **Time_Window_Pressure** (时间压力)
    * *逻辑*: 客观的截止日期。如:距离中高考仅剩X月、休学复课最后期限、学校劝退通牒。
4. **Helplessness_Index** (无助指数)
    * *逻辑*: 家长是否已经尝试过多种方法均失败(习得性无助),还是盲目自信觉得还能管。
5. **Social_Shame** (社交耻感)
    * *逻辑*: 孩子问题是否影响了家长的社会形象(怕老师找、怕亲戚问、不敢出门)。
6. **Ultimatum_Event** (爆发事件)
    * *逻辑*: 迫使家长此时此刻咨询的导火索。如:昨日发生的激烈争吵、离家出走、打架、学校停课通知。
7. **Emotional_Trigger** (情绪扳机)
    * *逻辑*: 沟通中家长情绪最激动的点(哭泣、愤怒、颤抖)。

### [第二组:阻力与障碍]
8. **Secret_Resistance** (隐性抗拒)
    * *逻辑*: **阻碍成交**的心理障碍。特指:怕被家人知道买课、怕孩子知道家长在咨询、觉得课程是骗局。
    * *排除*: 孩子的生活秘密(如抽烟/早恋)不属于此字段。
9. **Trust_Deficit** (信任赤字)
    * *逻辑*: 对机构/销售/网课模式的直接质疑。如:“你们正规吗?”“之前被骗过”。
10. **Family_Sabotage** (家庭阻力)
    * *逻辑*: 家庭中明确的反对者或捣乱者(拆台的配偶、干涉的长辈、发病的家属)。
    * *排除*: 客观的不幸(如家人生病/车祸)不属于此字段,除非该事件直接阻碍了家长听课。
11. **Low_Self_Efficacy** (效能感低)
    * *逻辑*: 家长担心**自己**学不会、坚持不下来、没时间听课。
12. **Attribution_Barrier** (归因偏差)
    * *逻辑*: 家长认为错在谁?(全是学校的错 / 全是手机的错 / 全是遗传的错 / 承认自己有错)。

### [第三组:资源与决策]
13. **Payer_Decision_Maker** (决策权)
    * *逻辑*: 谁掌握财权?谁有一票否决权?是“妈妈独裁”还是“需商量”?
14. **Hidden_Wealth_Proof** (隐形财力)
    * *逻辑*: 寻找高消费证据。如:私立学校、出国计划、高昂学费、住别墅、高知职业(教授/医生)。
15. **Price_Sensitivity** (价格敏感度)
    * *逻辑*: 对价格的反应。是“只看效果不差钱”,还是“犹豫比价”、“哭穷”。
16. **Sunk_Cost** (沉没成本)
    * *逻辑*: 过往已投入的无效成本。如:之前报过xx辅导班、做过xx次心理咨询、花了xx万没效果。
17. **Compensatory_Spending** (补偿心理)
    * *逻辑*: 是否因亏欠感而通过花钱(买东西/报课)来弥补孩子。

### [第四组:销售价值判断]
18. **Expectation_Bonus** (期望范围)
    * *逻辑*: 家长的底线(只要活着/不退学)与理想(考大学/变优秀)。
19. **Competitor_Mindset** (竞品思维)
    * *逻辑*: 家长是否在对比其他**解决方案**。如:特训学校、心理医生(针对孩子)、线下辅导班。
    * *排除*: 家属的就医经历不属于此字段。
20. **Cognitive_Stage** (认知阶段)
    * *逻辑*: 愚昧期(修孩子) -> 觉醒期(修自己/找方法)。
21. **Referral_Potential** (转介绍潜力)
    * *逻辑*: 基于身份判断。重点捕捉:多孩家庭、教师/医生/教授/公务员身份(KOL潜质)、家长委员会成员。
22. **Last_Interaction** (互动状态)
    * *逻辑*: 通话结束时的温度。秒回/挂断/索要案例/已读不回。
23. **Follow_up_Priority** (跟进优先级) - [重点监控字段]
    * *逻辑*: 综合评级(S/A/B/C)。
    * **Extraction Rule (必须使用数组逻辑)**:
        * 如果评级为 **S/A**(通常需要痛点+财力/意向双重支撑),必须在 `evidence` 数组中分别列出这两方面(甚至三方面)的原话。
    * **S级**: 涉及生命安全 OR (极高痛点 + 强支付能力 + 强意向)。
    * **A级**: 有痛点 + 有支付能力。
    * **B级**: 有痛点 + 无支付能力/犹豫。
    * **C级**: 无痛点/无意识。

# Output Format (输出格式指令)
**强制要求1**:JSON 中必须包含 `Follow_up_Priority` 字段,否则视为无效输出。
**强制要求2**:JSON 格式必须严格合法(逗号分隔、引号成对、括号匹配),可直接被JSON解析工具识别。
**强制要求3**:直接输出 JSON 对象,无需 Markdown 代码块、解释性文字或额外内容。
**强制要求4**:若文中未提及除了Follow_up_Priority字段的22个字段内容,就不要输出。
**强制要求5**:不准将示例模板内容作为证据输出,必须从原文中找证据。

JSON 结构要求:
1. **Key**: 仅使用上述定义中出现的英文 Key。
2. **Value**: 必须是 **<20字** 的短语结论。
3. **Evidence**: 必须是 **List<String>** (字符串数组)。
4. **Strict Validation (自我审查)**:
    * 检查 `evidence` 是否包含“家长说”、“意思就是”? -> 若有,**改为纯引用**。
    * 检查 `evidence` 是否包含销售说的话? -> 若有,**删除该元素**。
    * 检查 JSON 语法是否正确? -> 确保逗号不遗漏、括号成对。
    * 不准将示例模板内容作为证据输出,必须从原文中找证据。

**Example Output:**
{
  "Follow_up_Priority": {
    "value": "A级 (痛点强+财力足)",
    "evidence": [
      "我是今年才确诊,他是焦虑的",  // 证据1:完整原句支撑痛点
      "孩子现在在西工大附中上学",  // 证据2:完整原句支撑隐形财力
      "留学基金我们已经准备好了"  // 证据3:完整原句支撑支付能力
    ]
  },
  "Pain_Threshold": {
    "value": "崩溃急救状态",
    "evidence": [
      "我不知道怎么来处理",
      "一看见难了就崩溃啊就崩溃"
    ]
  },
  "Trust_Deficit": {
    "value": "质疑课程通用性",
    "evidence": [
      "你这些课不都是通用的吗?"
    ]
  }
}"""
    full_prompt = f"{prompt_template}\n\n### 原始通话文本\n{dialogue_text}\n\n### 请严格按照上述要求输出JSON(仅JSON,无其他内容)"
    return full_prompt


def clean_and_fix_json(json_str):
    """清洗JSON格式,不补充兜底内容"""
    try:
        # 移除转义符、控制字符和多余空格
        json_str = json_str.replace('\\"', '"').replace("\\'", "'")
        json_str = re.sub(r'[\n\r\t\f\v]', '', json_str)
        json_str = re.sub(r'\s+', ' ', json_str).strip()
        # 修复末尾多余逗号
        json_str = re.sub(r",\s*}", "}", json_str)
        json_str = re.sub(r",\s*]", "]", json_str)
        return json_str
    except Exception as e:
        raise RuntimeError(f"JSON清洗失败: {str(e)}") from e


def extract_features_with_qwen(dialogue_text, file_name, output_dir="qwen_new_123"):
    """
    调用API提取特征并保存为JSON文件
    :param dialogue_text: 预处理后的对话文本(字符串)
    :param file_name: 原文件名称(用于生成输出文件名)
    :param output_dir: 结果保存目录
    :return: 提取的特征字典(失败则返回None)
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    prompt = build_extraction_prompt(dialogue_text)

    try:
        response = client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=TEMPERATURE,
            max_tokens=8000
        )

        feature_json_str = response.choices[0].message.content.strip()
        # 提取JSON片段
        json_match = re.search(r"\{[\s\S]*\}", feature_json_str)
        if json_match:
            feature_json_str = json_match.group()
        else:
            raise ValueError("返回内容中未找到有效JSON数据")

        # 移除代码块标记
        if feature_json_str.startswith("```json"):
            feature_json_str = feature_json_str[7:-3].strip()
        elif feature_json_str.startswith("```"):
            feature_json_str = feature_json_str[3:-3].strip()

        feature_dict = json.loads(feature_json_str)

        # 验证核心字段
        if "Follow_up_Priority" not in feature_dict:
            raise ValueError("返回结果缺失核心必选字段:Follow_up_Priority")

        # 生成输出文件名
        file_base = os.path.splitext(file_name)[0]
        json_filename = f"{file_base}.json"
        output_path = os.path.join(output_dir, json_filename)

        # 保存文件
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(feature_dict, f, ensure_ascii=False, indent=2)

        print(f"处理完成:{file_name} -> {json_filename}")
        return feature_dict

    except json.JSONDecodeError as e:
        print(f"JSON解析失败 {file_name}:{str(e)} | 原始内容:{feature_json_str[:200]}...")
        return None
    except ValueError as e:
        print(f"数据验证失败 {file_name}:{str(e)}")
        return None
    except Exception as e:
        print(f"特征提取失败 {file_name}:{str(e)}")
        return None


# 批量处理函数
def batch_process_first_200_txt(folder_path, output_dir):
    """
    仅处理指定文件夹下的前200个txt文件
    :param folder_path: 待处理文件夹路径
    :param output_dir: 结果输出目录
    """
    # 检查文件夹是否存在
    if not os.path.isdir(folder_path):
        print(f"文件夹不存在:{folder_path}")
        return

    # 筛选出文件夹中的txt文件并按名称排序(保证处理顺序稳定)
    txt_file_list = [
        f for f in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, f)) and f.lower().endswith(".txt")
    ]
    # 按文件名排序(可选,保证每次处理顺序一致)
    txt_file_list.sort()

    # 取前200个txt文件
    target_files = txt_file_list[:200]

    if not target_files:
        print(f"文件夹 {folder_path} 中无txt文件可处理")
        return

    print(f"开始处理前 {len(target_files)} 个txt文件")

    processed_count = 0
    failed_count = 0

    for file_name in target_files:
        file_path = os.path.join(folder_path, file_name)

        # 读取文件内容
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                dialogue_content = f.read().strip()
            if not dialogue_content:
                print(f"文件内容为空,跳过:{file_name}")
                failed_count += 1
                continue
        except Exception as e:
            print(f"读取文件失败 {file_name}:{str(e)}")
            failed_count += 1
            continue

        # 调用特征提取函数
        result = extract_features_with_qwen(dialogue_content, file_name, output_dir)
        if result:
            processed_count += 1
        else:
            failed_count += 1

    # 输出批量处理统计结果
    print("\n批量处理完成")
    print(f"成功处理:{processed_count} 个文件")
    print(f"处理失败:{failed_count} 个文件")
    print(f"结果保存至:{os.path.abspath(output_dir)}")


def process_single_txt(file_path, output_dir=OUTPUT_DIR):
    """
    处理单个TXT文件,提取特征并保存JSON
    :param file_path: 单个TXT文件的完整路径
    :param output_dir: JSON结果保存目录
    """
    # 1. 验证文件是否存在且是TXT文件
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在:{file_path}")
    if not file_path.lower().endswith(".txt"):
        raise ValueError(f"不是TXT文件:{file_path}")
    if not os.path.isfile(file_path):
        raise IsADirectoryError(f"这是文件夹,不是文件:{file_path}")

    # 2. 读取TXT文件内容
    print(f"正在读取文件:{file_path}")
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            dialogue_content = f.read().strip()
        if not dialogue_content:
            raise ValueError("文件内容为空")
        print(f"成功读取文件(字符数:{len(dialogue_content)})")
    except Exception as e:
        raise RuntimeError(f"读取文件失败:{str(e)}") from e

    # 3. 构建提示词并调用API
    prompt = build_extraction_prompt(dialogue_content)
    try:
        print("正在调用API提取特征...")
        response = client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=TEMPERATURE,
            max_tokens=8000,
            timeout=30
        )
    except Exception as e:
        raise RuntimeError(f"API调用失败:{str(e)}") from e

    # 4. 提取并清洗JSON
    feature_json_str = response.choices[0].message.content.strip()
    json_match = re.search(r"\{[\s\S]*\}", feature_json_str)
    if not json_match:
        raise RuntimeError(f"API返回无有效JSON:{feature_json_str[:200]}...")
    cleaned_json = clean_and_fix_json(json_match.group())

    # 5. 解析并验证JSON
    try:
        parsed_dict = json.loads(cleaned_json)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"JSON解析失败:{str(e)} | 清洗后内容:{cleaned_json[:500]}") from e

    # 验证核心字段
    if "Follow_up_Priority" not in parsed_dict:
        raise RuntimeError("核心字段Follow_up_Priority缺失")
    fu_prio = parsed_dict["Follow_up_Priority"]
    if not isinstance(fu_prio, dict) or "value" not in fu_prio or "evidence" not in fu_prio:
        raise RuntimeError("Follow_up_Priority格式错误(需包含value和evidence)")
    if not isinstance(fu_prio["evidence"], list):
        raise RuntimeError("evidence必须是数组类型")
    if len(str(fu_prio["value"])) >= 20:
        raise RuntimeError(f"value超20字限制:{fu_prio['value']}")

    # 6. 保存JSON结果
    os.makedirs(output_dir, exist_ok=True)
    file_name = os.path.basename(file_path)
    json_file_name = f"{os.path.splitext(file_name)[0]}.json"
    json_save_path = os.path.join(output_dir, json_file_name)

    try:
        with open(json_save_path, "w", encoding="utf-8") as f:
            json.dump(parsed_dict, f, ensure_ascii=False, indent=2)
        print(f"处理完成!JSON保存至:{json_save_path}")
        return parsed_dict
    except Exception as e:
        raise RuntimeError(f"保存JSON失败:{str(e)}") from e


async def process_single(content: str):
    # 3. 构建提示词并调用API
    prompt = build_extraction_prompt(content)
    try:
        print("正在调用API提取特征...")
        response = client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=TEMPERATURE,
            max_tokens=8000,
            timeout=30
        )
    except Exception as e:
        raise RuntimeError(f"API调用失败:{str(e)}") from e

    # 4. 提取并清洗JSON
    feature_json_str = response.choices[0].message.content.strip()
    json_match = re.search(r"\{[\s\S]*\}", feature_json_str)
    if not json_match:
        raise RuntimeError(f"API返回无有效JSON:{feature_json_str[:200]}...")
    cleaned_json = clean_and_fix_json(json_match.group())

    # 5. 解析并验证JSON
    try:
        parsed_dict = json.loads(cleaned_json)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"JSON解析失败:{str(e)} | 清洗后内容:{cleaned_json[:500]}") from e

    # 验证核心字段
    if "Follow_up_Priority" not in parsed_dict:
        raise RuntimeError("核心字段Follow_up_Priority缺失")
    fu_prio = parsed_dict["Follow_up_Priority"]
    if not isinstance(fu_prio, dict) or "value" not in fu_prio or "evidence" not in fu_prio:
        raise RuntimeError("Follow_up_Priority格式错误(需包含value和evidence)")
    if not isinstance(fu_prio["evidence"], list):
        raise RuntimeError("evidence必须是数组类型")
    if len(str(fu_prio["value"])) >= 20:
        raise RuntimeError(f"value超20字限制:{fu_prio['value']}")

    return parsed_dict


if __name__ == "__main__":
    # 要处理的源文件夹路径
    target_folder = "./time_12_1/data_200"

    # 执行批量处理:仅处理前200个txt文件,输出到 OUTPUT_DIR
    batch_process_first_200_txt(
        folder_path=target_folder,
        output_dir=OUTPUT_DIR
    )
    # SINGLE_TXT_PATH = "./qwen/cdb7d561-975a-431e-86d4-9b3ddc714f73.txt"

    # try:
    #     # 执行单个文件处理
    #     process_single_txt(file_path=SINGLE_TXT_PATH)
    # except Exception as e:
    #     print(f"\n处理失败:{str(e)}")
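For reference, feature_extraction.py pulls its configuration from a local .env file via python-dotenv, and that file is excluded by the new .gitignore. A minimal sketch of the expected entries follows; the variable names come from the module above, while every value shown here is a placeholder rather than a real key or path:

# .env (illustrative values only)
QWEN_API_KEY=sk-your-key-here            # key for the OpenAI-compatible endpoint configured in `client`
MODEL_NAME=qwen3-next-80b-a3b-thinking   # default used by the module if unset
TEMPERATURE=0.1
OUTPUT_DIR=time_12_1/data_ch_1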
inference.py (364 changed lines)
@@ -1,201 +1,163 @@
from model import TransClassifier
from transformers import AutoTokenizer
from data_process import extract_json_data, Formatter
import torch
import json
from typing import Dict, List, Optional
import os
import random
import warnings
warnings.filterwarnings("ignore")

valid_keys = [
    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
    "Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
]
ch_valid_keys = [
    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
    "竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
]
all_keys = valid_keys + ["session_id", "label"]
en2ch = {en: ch for en, ch in zip(valid_keys, ch_valid_keys)}
d1_keys = valid_keys[:5]
d2_keys = valid_keys[5:10]
d3_keys = valid_keys[10:15]
d4_keys = valid_keys[15:19]
d5_keys = valid_keys[19:23]

class InferenceEngine:
    def __init__(self, backbone_dir: str, ckpt_path: str = "best_ckpt.pth", device: str = "cuda"):
        self.backbone_dir = backbone_dir
        self.ckpt_path = ckpt_path
        self.device = device

        # 加载 tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
        print(f"Tokenizer loaded from {backbone_dir}")

        # 加载模型
        self.model = TransClassifier(backbone_dir, device)
        self.model.to(device)
        if self.ckpt_path:
            self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
            print(f"Model loaded from {ckpt_path}")
        else:
            print("Warning: No checkpoint path provided. Using untrained model.")
        self.model.eval()
        print("Inference engine initialized successfully.")

        self.formatter = Formatter(en2ch)

    def inference_batch(self, json_list: List[str]) -> dict:
        """
        批量推理函数,输入为 JSON 字符串列表,输出为包含转换概率的字典列表。为防止OOM,列表最大长度为8。
        请注意Json文件中的词条数必须大于等于10.
        """
-        assert len(json_list) <= 8, "单次输入json文件数量不可超过8。"
-        id2feature = extract_json_data(json_list) # id2feature
-        message_list = []
-        for id, feature in id2feature.items():
-            messages = self.formatter.get_llm_prompt(feature)
-            message_list.append(messages)
-
-        inputs = self.tokenizer.apply_chat_template(
-            message_list,
-            tokenize=False,
-            add_generation_prompt=True,
-            enable_thinking=False
-        )
-        model_inputs = self.tokenizer(
-            inputs,
-            padding=True,
-            truncation=True,
-            max_length=2048,
-            return_tensors="pt"
-        ).to(self.device)
-
-        with torch.inference_mode():
-            with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
-                outputs = self.model(model_inputs)
-
-        # 1. 计算分类标签(argmax)
-        preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
-
-        # 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题)
-        outputs_float = outputs.float() # 转换为 float32 避免精度问题
-        probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
-        # 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率)
-        probs = probs.cpu().numpy().tolist()
-
-        # 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表
-        return {"labels": preds, "probs": probs}
-
-    def inference_sample(self, json_path: str) -> dict:
-        """
-        单样本推理函数,输入为 JSON 字符串路径,输出为包含转换概率的字典。
-        请注意Json文件中的词条数必须大于等于10.
-        """
-        return self.inference_batch([json_path])
-
-
-if __name__ == "__main__":
-    # 配置参数
-    backbone_dir = "Qwen3-1.7B"
-    ckpt_path = "best_ckpt.pth"
-    device = "cuda"
-
-    engine = InferenceEngine(backbone_dir, ckpt_path, device)
-
-    from data_process import extract_json_files
-    import random
-
-    # 获取成交和未成交的json文件路径
-    deal_files = extract_json_files("deal")
-    not_deal_files = extract_json_files("not_deal")
-
-    def filter_json_files_by_key_count(files: List[str], min_keys: int = 10) -> List[str]:
-        """
-        过滤出JSON文件中字典键数量大于等于指定数量的文件
-
-        Args:
-            files: JSON文件路径列表
-            min_keys: 最小键数量要求,默认为10
-
-        Returns:
-            符合条件的文件路径列表
-        """
-        valid_files = []
-
-        for file_path in files:
-            try:
-                with open(file_path, 'r', encoding='utf-8') as f:
-                    data = json.load(f)
-
-                # 检查是否为字典且键数量是否符合要求
-                if isinstance(data, dict) and len(data) >= min_keys:
-                    valid_files.append(file_path)
-                else:
-                    print(f"跳过文件 {os.path.basename(file_path)}: 键数量不足 ({len(data)} < {min_keys})")
-            except Exception as e:
-                print(f"读取文件 {file_path} 时出错: {e}")
-
-        return valid_files
-
-    deal_files_filtered = filter_json_files_by_key_count(deal_files, min_keys=10)
-    not_deal_files_filtered = filter_json_files_by_key_count(not_deal_files, min_keys=10)
-
-    num_samples = 8
-
-    # 计算每类需要选取的数量
-    num_deal_needed = min(4, len(deal_files_filtered)) # 最多选4个成交文件
-    num_not_deal_needed = min(4, len(not_deal_files_filtered)) # 最多选4个未成交文件
-
-    # 如果某类文件不足,从另一类补足
-    if num_deal_needed + num_not_deal_needed < num_samples:
-        if len(deal_files_filtered) > num_deal_needed:
-            num_deal_needed = min(num_samples, len(deal_files_filtered))
-        elif len(not_deal_files_filtered) > num_not_deal_needed:
-            num_not_deal_needed = min(num_samples, len(not_deal_files_filtered))
-
-    # 随机选取文件
-    selected_deal_files = random.sample(deal_files_filtered, min(num_deal_needed, len(deal_files_filtered))) if deal_files_filtered else []
-    selected_not_deal_files = random.sample(not_deal_files_filtered, min(num_not_deal_needed, len(not_deal_files_filtered))) if not_deal_files_filtered else []
-
-    # 合并选中的文件
-    selected_files = selected_deal_files + selected_not_deal_files
-
-    # 如果总数不足8个,尝试从原始文件中随机选取补足
-    if len(selected_files) < num_samples:
-        all_files = deal_files + not_deal_files
-        # 排除已选的文件
-        remaining_files = [f for f in all_files if f not in selected_files]
-        additional_needed = num_samples - len(selected_files)
-        if remaining_files:
-            additional_files = random.sample(remaining_files, min(additional_needed, len(remaining_files)))
-            selected_files.extend(additional_files)
-
-    true_labels = []
-    for i, file_path in enumerate(selected_files):
-        folder_type = "未成交" if "not_deal" in file_path else "成交"
-        true_labels.append(folder_type)
-
-    # 使用inference_batch接口进行批量推理
-    if selected_files:
-        print("\n开始批量推理...")
-        try:
-            batch_result = engine.inference_batch(selected_files)
-            print(batch_result)
-            print(true_labels)
-
-        except Exception as e:
-            print(f"推理过程中出错: {e}")
-    else:
-        print("未找到符合条件的文件进行推理")
-
-    print("\n推理端口测试完成!")
+        # print(111111)
+        assert len(json_list) <= 10, "单次输入json文件数量不可超过8。"
+        id2feature = extract_json_data(json_list)
+        print(json.dumps(id2feature, indent=2, ensure_ascii=False))
+        # id2feature
+
+        message_list = []
+        for id, feature in id2feature.items():
+            messages = self.formatter.get_llm_prompt(feature)
+            message_list.append(messages)
+
+        inputs = self.tokenizer.apply_chat_template(
+            message_list,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=False
+        )
+        model_inputs = self.tokenizer(
+            inputs,
+            padding=True,
+            truncation=True,
+            max_length=2048,
+            return_tensors="pt"
+        ).to(self.device)
+
+        with torch.inference_mode():
+            with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
+                outputs = self.model(model_inputs)
+
+        # 1. 计算分类标签(argmax)
+        preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
+
+        # 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题)
+        outputs_float = outputs.float() # 转换为 float32 避免精度问题
+        probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
+        # 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率)
+        probs = probs.cpu().numpy().tolist()
+        probs = [p[1] for p in probs] # 只保留类别1的概率
+
+        # 3. 计算置信度
+        confidence = [abs(p - 0.5) * 2 for p in probs]
+        # 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表,confidence是每个样本的置信度列表
+        return {"labels": preds, "probs": probs, "confidence": confidence}
+
+    def inference_sample(self, json_path: str) -> dict:
+        """
+        单样本推理函数,输入为 JSON 字符串路径,输出为包含转换概率的字典。
+        请注意Json文件中的词条数必须大于等于10.
+        """
+        return self.inference_batch([json_path])
+
+    def inference(
+        self,
+        featurs: dict[str, dict]
+    ):
+        assert len(featurs) <= 10, "单次输入json文件数量不可超过8。"
+        message_list = []
+        for id, feature in featurs.items():
+            messages = self.formatter.get_llm_prompt(feature)
+            message_list.append(messages)
+
+        inputs = self.tokenizer.apply_chat_template(
+            message_list,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=False
+        )
+
+        model_inputs = self.tokenizer(
+            inputs,
+            padding=True,
+            truncation=True,
+            max_length=2048,
+            return_tensors="pt"
+        ).to(self.device)
+
+        with torch.inference_mode():
+            with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
+                outputs = self.model(model_inputs)
+
+        # 1. 计算分类标签(argmax)
+        preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
+
+        # 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题)
+        outputs_float = outputs.float() # 转换为 float32 避免精度问题
+        probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
+        # 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率)
+        probs = probs.cpu().numpy().tolist()
+        probs = [p[1] for p in probs] # 只保留类别1的概率
+
+        # 3. 计算置信度
+        confidence = [abs(p - 0.5) * 2 for p in probs]
+        # 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表,confidence是每个样本的置信度列表
+        return {"labels": preds, "probs": probs, "confidence": confidence}
+
+
+if __name__ == "__main__":
+    # 配置参数
+    backbone_dir = "Qwen3-1.7B"
+    ckpt_path = "best_ckpt.pth"
+    device = "cuda"
+
+    engine = InferenceEngine(backbone_dir, ckpt_path, device)
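The new inference() entry point takes an in-memory mapping of session id to extracted feature values instead of JSON file paths. A minimal usage sketch, mirroring how main.py (below) drives it; the feature values are invented placeholders, and it assumes the Qwen3-1.7B backbone directory and best_ckpt.pth checkpoint referenced above are available locally:

import uuid
from inference import InferenceEngine

# Assumes the local backbone dir and checkpoint named in the __main__ block above.
engine = InferenceEngine(backbone_dir="Qwen3-1.7B", ckpt_path="best_ckpt.pth", device="cuda")

# One entry per transcript: {session_id: {feature_key: short value string}}.
features = {
    uuid.uuid4().hex: {
        "Pain_Threshold": "崩溃急救状态",             # placeholder value
        "Hidden_Wealth_Proof": "就读私立学校",         # placeholder value
        "Follow_up_Priority": "A级 (痛点强+财力足)",   # placeholder value
    }
}

result = engine.inference(features)
# labels: argmax class per sample; probs: probability of class 1; confidence: |p - 0.5| * 2
print(result["labels"], result["probs"], result["confidence"])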
main.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from feature_extraction import process_single
from inference import InferenceEngine
from services.mongo import voice_collection
import json, uuid
from fastapi import FastAPI
from pydantic import BaseModel

backbone_dir = "Qwen3-1.7B"
ckpt_path = "best_ckpt.pth"
device = "cuda"

engine = InferenceEngine(backbone_dir, ckpt_path, device)


async def get_customer_record():
    cursor = voice_collection.find({
        "tag": "20分钟通话",
        "matched_contacts": {
            "$elemMatch": {
                "wecom_id": {"$exists": True, "$ne": ""}
            }
        }
    }).sort([('_id', -1)]).limit(24)
    return await cursor.to_list(length=1)


async def main():
    records = await get_customer_record()
    for record in records:
        # print(len(record["text_content"]))
        data = await process_single(record["text_content"][:10000])
        # print(json.dumps(data, indent=2, ensure_ascii=False))
        temp = {}
        res = {}
        for key, value in data.items():
            temp[key] = value.get("value") or ""
        res[uuid.uuid4().hex] = temp
        print(engine.inference(res))


app = FastAPI()


class Predictbody(BaseModel):
    content: str


@app.post("/predict")
async def endpoint(body: Predictbody):
    data = await process_single(body.content[:10000])
    temp = {}
    res = {}
    for key, value in data.items():
        temp[key] = value.get("value") or ""
    res[uuid.uuid4().hex] = temp
    return {
        "feature": data,
        "predict": engine.inference(res)
    }


if __name__ == "__main__":
    import asyncio
    asyncio.run(main())
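Once the app is running (for example via `uvicorn main:app`, uvicorn being one of the declared dependencies), the /predict route accepts a JSON body with a single content field and returns the extracted feature profile plus the deal prediction. A minimal client sketch using only the standard library; the transcript text is an invented placeholder and the host/port are uvicorn defaults:

import json
import urllib.request

payload = json.dumps(
    {"content": "家长:孩子已经休学两个月了,我实在不知道该怎么办"}  # invented sample transcript
).encode("utf-8")

req = urllib.request.Request(
    "http://127.0.0.1:8000/predict",   # default uvicorn host/port; adjust to your deployment
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

# result["feature"] holds the extracted profile, result["predict"] the labels/probs/confidence.
print(result["predict"])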
pyproject.toml (new file, 17 lines)
@@ -0,0 +1,17 @@
[project]
name = "deal-classification"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "fastapi>=0.128.0",
    "matplotlib>=3.10.8",
    "motor>=3.7.1",
    "openai>=2.16.0",
    "pandas>=3.0.0",
    "python-dotenv>=1.2.1",
    "torch>=2.10.0",
    "transformers>=5.0.0",
    "uvicorn>=0.40.0",
]
services/mongo.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import datetime
import os
from typing import Dict, List, Optional

from bson import ObjectId
from motor.motor_asyncio import (
    AsyncIOMotorClient,
    AsyncIOMotorCollection,
    AsyncIOMotorDatabase,
)

MONGO_URI = os.getenv("MONGO_URI")

MongoQueryParam = int | str | None | bool | datetime.datetime

MongoQueryListParam = (
    List[int] | List[str] | List[datetime.datetime] | List[bool] | List[None]
)

MongoQueryDictParam = Dict[
    str, MongoQueryListParam | MongoQueryParam | "MongoQueryDictParam"
]

MongoQuery = Dict[
    str, MongoQueryDictParam | MongoQueryParam | MongoQueryListParam | "MongoQuery"
]

# 明确的类型注解
client: AsyncIOMotorClient = AsyncIOMotorClient(host=MONGO_URI)
db: AsyncIOMotorDatabase = client["TriCore"]
voice_collection: AsyncIOMotorCollection = db["voice_insight"]
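The module above only wires up the async client and the voice_insight collection; the MongoQuery aliases exist to type ad-hoc query dicts. A small, purely illustrative sketch of how they compose, reusing the same filter that main.py sends to the collection:

from services.mongo import MongoQuery, voice_collection

# The nested $elemMatch dict is covered by the recursive MongoQueryDictParam alias.
query: MongoQuery = {
    "tag": "20分钟通话",
    "matched_contacts": {"$elemMatch": {"wecom_id": {"$exists": True, "$ne": ""}}},
}

async def latest_voice_records(limit: int = 24):
    # motor cursors are async: build the cursor, then await to_list().
    cursor = voice_collection.find(query).sort([("_id", -1)]).limit(limit)
    return await cursor.to_list(length=limit)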