nahj_rag/query_pipeline.py
2026-02-17 16:52:37 +00:00

80 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# rag_query_analyzer/pipeline.py
import json
from query_analyzer import build_messages
class QueryAnalysisPipeline:
def __init__(self, llm_client, model_name):
"""
llm_client → شیء کلاینت مدل زبانی شما (هر چیزی باشد)
model_name → برای شمارش توکن‌ها و ارسال
"""
self.llm = llm_client
self.model_name = model_name
async def analyze(self, user_query: str):
"""
مرحله ۱: تحلیل کوئری (سلام + زیربخش‌ها)
خروجی JSON تحلیل شده را برمی‌گرداند
"""
messages = build_messages(
user_query=user_query,
max_tokens=1024
)
result = await self.llm(messages)
# پردازش JSON خروجی
try:
parsed = json.loads(result)
except Exception:
# تلاش دوم → استخراج JSON از متن
try:
import re
json_text = re.search(r"\{.*\}", result, re.S).group(0)
parsed = json.loads(json_text)
except Exception as e:
parsed = {
"greeting_reply": "",
"sub_questions": [],
"final_answer_instruction": ""
}
print(f'final exception error: {e}')
return parsed
async def expand_sub_questions(self, sub_questions):
"""
مرحله ۲: برای هر زیربخش، پاسخ جدا تولید می‌شود
"""
answers = []
for q in sub_questions:
messages = [
{"role": "user", "content": q}
]
resp = await self.llm(messages)
answers.append({"question": q, "answer": resp})
return answers
async def final_answer(self, main_question, parts):
"""
مرحله ۳: تولید پاسخ نهایی با استفاده از:
- سوال اصلی
- پاسخ‌های جزئی
"""
merged = f"""
سوال اصلی:
{main_question}
پاسخ‌های جزئی:
{json.dumps(parts, ensure_ascii=False, indent=2)}
لطفاً پاسخ نهایی را بر اساس این بخش‌ها تولید کن.
"""
messages = [
{"role": "user", "content": merged}
]
return await self.llm(messages)