Files
RAG_helper/services/chat_service.py
T
AR 15 M4 cac3d29273 feat(sprint5): state machine + bouncing — thread_state и служебные теги
Таблица thread_state (intent, step, slots) ведётся per-thread. В системный
промпт ветки дописывается текущее состояние, LLM возвращает служебный тег
[STATE: step=N; slots={...}] после основного ответа — парсер в chat_service
вырезает его и обновляет состояние. Если ветка решила, что тема ушла в другую,
она выдаёт [INTENT_CHANGE: code] — делаем один повторный вызов LLM с новой
веткой и сброшенным state (bouncing, MAX_BOUNCES=1). Если роутер сам выбрал
другую ветку, чем в thread_state, — state тоже сбрасывается. Промпт new_booking
переписан под 6-шаговый сценарий (имя → повод → специалист → время → подтверждение
→ запись), в «Песочнице» появился блок «Состояние треда» с intent/step/slots
и списком переходов.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-24 12:12:36 +05:00

432 lines
16 KiB
Python

import json
import logging
import re
from datetime import datetime, timezone
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from db.models import Message, Thread
from services import config_service, thread_state_service
from services.llm_client import LLMClient
from services.router_client import RouterClient
from services.vectorstore import VectorStoreService
logger = logging.getLogger(__name__)
HISTORY_LIMIT = 20  # the last N messages of the thread that are sent to the LLM
# Intent code used by _resolve_intent_with_fallback when the requested intent has no active config.
FALLBACK_INTENT_CODE = "general_info"
MAX_BOUNCES = 1  # how many times within one user turn a branch may hand control over to another branch
# Matches the service tag "[INTENT_CHANGE: code]"; group 1 is the lowercase intent code.
_INTENT_CHANGE_RE = re.compile(r"\[INTENT_CHANGE:\s*([a-z_][a-z0-9_]*)\s*\]")
# Matches only the prefix "[STATE: step=N; slots=" (case-insensitive, ";" optional); group 1 is the step.
# The slots JSON object and the closing "]" that follow are consumed by the parser, not by this regex.
_STATE_PREFIX_RE = re.compile(r"\[STATE:\s*step=(\d+)\s*;?\s*slots\s*=\s*", re.IGNORECASE)
def _auto_thread_name(first_user_text: str) -> str:
    """Build a default thread name from the first user message.

    The message is stripped and its newlines are replaced with spaces. Text
    longer than 60 characters is cut to 60 and marked with an ellipsis so that
    a truncated name is distinguishable from a short one. The current UTC time
    (``YYYY-MM-DD HH:MM``) is appended after a `` · `` separator.

    Args:
        first_user_text: Text of the first user message in the thread.

    Returns:
        A name of the form ``"<preview> · <UTC timestamp>"``.
    """
    preview = first_user_text.strip().replace("\n", " ")
    if len(preview) > 60:
        # Previously an empty string was appended here, so the cut was invisible.
        preview = preview[:60].rstrip() + "…"
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
    return f"{preview} · {stamp}"
def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
    """Convert vector-store hits into the source records attached to a reply.

    Each hit must carry a ``"text"`` entry; ``"metadata"`` and
    ``"relevance_score"`` are optional. The chunk text is capped at 500
    characters and the score is rounded to three decimal places.
    """

    def as_source(hit: dict) -> dict:
        # Missing metadata fields fall back to empty strings / page 0.
        info = hit.get("metadata", {})
        score = hit.get("relevance_score", 0)
        return {
            "document_id": info.get("document_id", ""),
            "document_name": info.get("document_name", ""),
            "chunk_text": hit["text"][:500],
            "section": info.get("section", ""),
            "page": info.get("page_number", 0),
            "relevance_score": round(score, 3),
        }

    return [as_source(hit) for hit in retrieved]
def _parse_assistant_signals(text: str) -> dict:
    """Strip the service tags [INTENT_CHANGE: ...] / [STATE: ...] from an assistant reply.

    Returns a dict with:
        visible_text  — the reply with everything from the tag onwards removed,
        intent_change — the branch code, or None,
        state         — {'step': int, 'slots': dict}, or None.

    An INTENT_CHANGE tag takes precedence over a STATE tag. The STATE tag is
    accepted only when a brace-balanced slots object is followed by a closing
    "]"; otherwise the text is returned untouched. Slots that fail to decode
    as a JSON object are replaced with an empty dict.
    """
    bounce = _INTENT_CHANGE_RE.search(text)
    if bounce is not None:
        return {
            "visible_text": text[:bounce.start()].rstrip(),
            "intent_change": bounce.group(1),
            "state": None,
        }

    untouched = {"visible_text": text, "intent_change": None, "state": None}

    prefix = _STATE_PREFIX_RE.search(text)
    if prefix is None:
        return untouched

    # The regex covers only "[STATE: step=N; slots="; the object itself is brace-scanned.
    raw_slots, end_pos = _consume_json_object(text, prefix.end())
    if raw_slots is None:
        return untouched
    if not text[end_pos:].lstrip().startswith("]"):
        return untouched

    try:
        decoded = json.loads(raw_slots)
    except json.JSONDecodeError:
        decoded = {}
    slots = decoded if isinstance(decoded, dict) else {}

    return {
        "visible_text": text[:prefix.start()].rstrip(),
        "intent_change": None,
        "state": {"step": int(prefix.group(1)), "slots": slots},
    }
def _consume_json_object(text: str, start: int) -> tuple[str | None, int]:
"""Вытянуть сбалансированный JSON-объект, начиная с позиции start (ожидаем `{`).
Возвращает (json_string, position_after_object). При ошибке — (None, start).
"""
i = start
n = len(text)
while i < n and text[i].isspace():
i += 1
if i >= n or text[i] != "{":
return None, start
depth = 0
in_str = False
esc = False
j = i
while j < n:
ch = text[j]
if in_str:
if esc:
esc = False
elif ch == "\\":
esc = True
elif ch == '"':
in_str = False
else:
if ch == '"':
in_str = True
elif ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[i:j + 1], j + 1
j += 1
return None, start
def _format_state_context(state_snapshot: dict) -> str:
    """Render the thread state as a block to append to the end of the system prompt.

    A missing or falsy ``current_step`` is rendered as 0 and missing or falsy
    ``slots`` as an empty object; slots are serialized as JSON without ASCII
    escaping.
    """
    header = "\n\n[ТЕКУЩЕЕ СОСТОЯНИЕ]\n"
    current_step = state_snapshot.get("current_step", 0) or 0
    current_slots = state_snapshot.get("slots", {}) or {}
    body_lines = (
        f"step: {current_step}",
        f"slots: {json.dumps(current_slots, ensure_ascii=False)}",
    )
    return header + "\n".join(body_lines)
async def _resolve_intent_with_fallback(
    session: AsyncSession, intent_code: str
) -> tuple[str, object, object]:
    """Return ``(code, intent, active_cfg)`` for the requested branch or for the fallback.

    When the requested intent has no active config, a warning is logged and
    the FALLBACK_INTENT_CODE branch is looked up instead; the returned code
    then equals FALLBACK_INTENT_CODE.

    Raises:
        RuntimeError: if the fallback intent has no active config either.
    """
    requested = await config_service.get_active_config_by_intent_code(session, intent_code)
    if requested is not None:
        intent, cfg = requested
        return intent_code, intent, cfg

    logger.warning("Intent %r has no active config, falling back to %s", intent_code, FALLBACK_INTENT_CODE)
    fallback = await config_service.get_active_config_by_intent_code(session, FALLBACK_INTENT_CODE)
    if fallback is None:
        raise RuntimeError(f"No active config for fallback intent {FALLBACK_INTENT_CODE!r}")
    intent, cfg = fallback
    return FALLBACK_INTENT_CODE, intent, cfg
async def send_message(
    session: AsyncSession,
    vectorstore: VectorStoreService,
    llm: LLMClient,
    router: RouterClient,
    text: str,
    thread_id: int | None = None,
    top_k: int = 5,
    temperature: float | None = None,
    max_tokens: int | None = None,
) -> dict:
    """Append a user message to a thread, run it through router + state machine, return the reply.

    Flow:
        1. Create the thread (auto-named from ``text``) or load the existing one,
           store the user message and collect up to HISTORY_LIMIT earlier messages.
        2. Ask the router which intent branch should answer.
        3. Load the thread state; reset step/slots when the router (or the
           config fallback) selects a different branch than the stored one.
        4. Call the LLM with the branch prompt plus the current state block.
           If the reply carries ``[INTENT_CHANGE: code]`` and the bounce budget
           (MAX_BOUNCES) is not exhausted, switch to that branch with a fresh
           state and call the LLM again.
        5. Apply a ``[STATE: ...]`` tag to the thread state, persist the state
           and the assistant message, and commit.

    Args:
        session: Database session; committed at the end of the call.
        vectorstore: Store queried with ``text`` for ``top_k`` chunks.
        llm: Chat client that produces the reply.
        router: Intent classifier.
        text: The user's message.
        thread_id: Existing thread id, or None to start a new thread.
        top_k: Number of chunks requested from the vector store.
        temperature: Passed through to the LLM call.
        max_tokens: Passed through to the LLM call.

    Returns:
        A dict with thread/message ids, served and router intent codes, config
        and router versions, the visible answer, sources, model name, the
        assembled prompt, the resulting thread state and the bounce log.

    Raises:
        LookupError: ``thread_id`` was given but no such thread exists.
        RuntimeError: neither the selected nor the fallback intent has an
            active config (raised by ``_resolve_intent_with_fallback``).
    """
    if thread_id is None:
        thread = Thread(name=_auto_thread_name(text))
        session.add(thread)
        await session.flush()
    else:
        thread = await session.get(Thread, thread_id)
        if thread is None:
            raise LookupError(f"Thread {thread_id} not found")

    user_msg = Message(thread_id=thread.id, role="user", text=text)
    session.add(user_msg)
    await session.flush()

    # History = the newest HISTORY_LIMIT messages excluding the one just added,
    # fetched newest-first and then reversed into chronological order.
    stmt = (
        select(Message)
        .where(Message.thread_id == thread.id, Message.id != user_msg.id)
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(HISTORY_LIMIT)
    )
    rows = (await session.execute(stmt)).scalars().all()
    history = [{"role": m.role, "content": m.text} for m in reversed(rows)]

    # 1. Router: which branch answers.
    routing = await router.classify(session=session, history=history, text=text)
    router_code = routing["code"]
    router_version = routing.get("version")

    # 2. Thread state snapshot. If the router moved to another branch, reset step and slots.
    state_snapshot = await thread_state_service.load_snapshot(session, thread.id)
    prev_intent_code = state_snapshot["current_intent_code"]
    if prev_intent_code and prev_intent_code != router_code:
        logger.info(
            "Router switched intent for thread %d: %s -> %s (state reset)",
            thread.id, prev_intent_code, router_code,
        )
        state_snapshot = {"current_intent_code": router_code, "current_step": 0, "slots": {}}

    # 3. Resolve the branch config (falling back to general_info) and call the LLM.
    served_code, intent, active_cfg = await _resolve_intent_with_fallback(session, router_code)
    if served_code != router_code:
        # Fallback was used: the state is reset onto the fallback branch.
        state_snapshot = {"current_intent_code": served_code, "current_step": 0, "slots": {}}

    # Retrieval happens once; a bounced branch reuses the same chunks.
    retrieved = vectorstore.query(query_text=text, top_k=top_k)
    sources = _retrieved_to_sources(retrieved)

    bounce_log: list[dict] = []
    last_assembled_prompt = ""
    llm_text = ""
    for attempt in range(MAX_BOUNCES + 1):
        base_prompt = config_service.compose_full_system_prompt(active_cfg)
        system_prompt = base_prompt + _format_state_context(state_snapshot)
        llm_result = await llm.chat(
            question=text,
            sources=retrieved,
            history=history,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        last_assembled_prompt = llm_result["assembled_prompt"]
        llm_text = llm_result["text"]
        parsed = _parse_assistant_signals(llm_text)
        if parsed["intent_change"] and attempt < MAX_BOUNCES:
            new_code = parsed["intent_change"]
            bounce_log.append({
                "from": served_code,
                "to": new_code,
                "preface": parsed["visible_text"],
            })
            logger.info(
                "Intent bounce in thread %d: %s -> %s", thread.id, served_code, new_code,
            )
            served_code, intent, active_cfg = await _resolve_intent_with_fallback(session, new_code)
            state_snapshot = {"current_intent_code": served_code, "current_step": 0, "slots": {}}
            continue
        break

    # 4. Update thread_state and store the messages.
    # NOTE(review): when the reply consists of a service tag only, visible_text is
    # empty and this falls back to the raw LLM text, tag included — confirm intended.
    visible_text = parsed["visible_text"] or llm_text
    if parsed["state"] is not None:
        new_step = parsed["state"]["step"]
        # The stored slots may be None (see _format_state_context), so guard before unpacking.
        previous_slots = state_snapshot.get("slots") or {}
        merged_slots = {**previous_slots, **parsed["state"]["slots"]}
        state_snapshot = {
            "current_intent_code": served_code,
            "current_step": new_step,
            "slots": merged_slots,
        }
    # If the reply carried INTENT_CHANGE on the last iteration (MAX_BOUNCES exceeded),
    # the tag has been cut from the parsed visible text, but the state is not switched.
    await thread_state_service.upsert(
        session, thread.id,
        intent_code=state_snapshot["current_intent_code"],
        step=state_snapshot["current_step"],
        slots=state_snapshot["slots"],
    )

    user_msg.intent_id = intent.id
    if thread.agent_config_id is None:
        thread.agent_config_id = active_cfg.id
    assistant_msg = Message(
        thread_id=thread.id,
        role="assistant",
        text=visible_text,
        sources_json=json.dumps(sources, ensure_ascii=False),
        assembled_prompt=last_assembled_prompt,
        intent_id=intent.id,
    )
    session.add(assistant_msg)
    thread.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(assistant_msg)
    await session.refresh(thread)

    logger.info(
        "Chat: thread=%d, router=%s, served=%s (v%d), step=%d, slots=%d keys, user_msg=%d, assistant_msg=%d, bounces=%d",
        thread.id, router_code, served_code, active_cfg.version,
        # None-safe so that a sparse snapshot cannot fail the request after the commit.
        state_snapshot["current_step"] or 0, len(state_snapshot["slots"] or {}),
        user_msg.id, assistant_msg.id, len(bounce_log),
    )
    return {
        "thread_id": thread.id,
        "thread_name": thread.name,
        "message_id": assistant_msg.id,
        "intent_code": intent.code,
        "intent_name": intent.name,
        "router_intent_code": router_code,
        "config_version": active_cfg.version,
        "router_version": router_version,
        "answer": visible_text,
        "sources": sources,
        "model_used": llm.model,
        "assembled_prompt": last_assembled_prompt,
        "thread_state": {
            "current_intent_code": state_snapshot["current_intent_code"],
            "current_step": state_snapshot["current_step"],
            "slots": state_snapshot["slots"],
        },
        "bounces": bounce_log,
    }
async def list_threads(session: AsyncSession) -> list[dict]:
    """List all threads, most recently updated first, with a message count and preview.

    The preview is the text of the thread's first user message (lowest message
    id with role "user"), stripped, with newlines replaced by spaces. Previews
    longer than 120 characters are cut to 120 and marked with an ellipsis.

    Returns:
        One dict per thread with ``id``, ``name``, ISO-formatted ``created_at``
        and ``updated_at``, ``messages_count`` and ``first_message_preview``.
    """
    # Messages per thread.
    count_subq = (
        select(Message.thread_id, func.count(Message.id).label("cnt"))
        .group_by(Message.thread_id)
        .subquery()
    )
    # Id of the first user message per thread.
    first_msg_subq = (
        select(Message.thread_id, func.min(Message.id).label("first_id"))
        .where(Message.role == "user")
        .group_by(Message.thread_id)
        .subquery()
    )
    # Outer joins keep threads that have no messages at all.
    stmt = (
        select(
            Thread,
            func.coalesce(count_subq.c.cnt, 0).label("messages_count"),
            Message.text.label("first_text"),
        )
        .outerjoin(count_subq, count_subq.c.thread_id == Thread.id)
        .outerjoin(first_msg_subq, first_msg_subq.c.thread_id == Thread.id)
        .outerjoin(Message, Message.id == first_msg_subq.c.first_id)
        .order_by(Thread.updated_at.desc())
    )
    rows = (await session.execute(stmt)).all()

    result = []
    for thread, messages_count, first_text in rows:
        preview = (first_text or "").strip().replace("\n", " ")
        if len(preview) > 120:
            # Previously an empty string was appended here, so the cut was invisible.
            preview = preview[:120].rstrip() + "…"
        result.append({
            "id": thread.id,
            "name": thread.name,
            "created_at": thread.created_at.isoformat(),
            "updated_at": thread.updated_at.isoformat(),
            "messages_count": messages_count,
            "first_message_preview": preview,
        })
    return result
async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | None:
    """Return a thread with all its messages and its current state, or None if it does not exist.

    Messages are ordered by ``created_at`` with ``id`` as a tie-breaker — the
    same ordering keys the history query in ``send_message`` uses — so messages
    that share a timestamp come back in a stable order.

    Returns:
        A dict with ``id``, ``name``, ISO-formatted ``created_at`` and
        ``updated_at``, ``messages`` (each with its sources, assembled prompt
        and intent code/name) and ``thread_state``; None when the thread is
        missing.
    """
    from db.models import Intent

    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None

    # Outer join: messages without an intent are still returned.
    stmt = (
        select(Message, Intent.code, Intent.name)
        .outerjoin(Intent, Intent.id == Message.intent_id)
        .where(Message.thread_id == thread_id)
        .order_by(Message.created_at, Message.id)
    )
    rows = (await session.execute(stmt)).all()

    messages = []
    for m, intent_code, intent_name in rows:
        sources = []
        if m.sources_json:
            try:
                sources = json.loads(m.sources_json)
            except json.JSONDecodeError:
                # Best effort: a corrupt payload is logged and the message is shown without sources.
                logger.warning("Bad sources_json for message %d", m.id)
        messages.append({
            "id": m.id,
            "role": m.role,
            "text": m.text,
            "created_at": m.created_at.isoformat(),
            "sources": sources,
            "assembled_prompt": m.assembled_prompt or "",
            "intent_code": intent_code or "",
            "intent_name": intent_name or "",
        })

    state = await thread_state_service.load_snapshot(session, thread_id)
    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages": messages,
        "thread_state": state,
    }
async def rename_thread(session: AsyncSession, thread_id: int, name: str) -> dict | None:
    """Rename a thread and bump its ``updated_at``; return None if it does not exist.

    The returned summary has the same keys as a thread list entry, but
    ``messages_count`` and ``first_message_preview`` are fixed placeholders
    (0 and "") — they are not recomputed here.
    """
    target = await session.get(Thread, thread_id)
    if target is None:
        return None

    target.name = name
    target.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(target)

    summary = dict(
        id=target.id,
        name=target.name,
        created_at=target.created_at.isoformat(),
        updated_at=target.updated_at.isoformat(),
        messages_count=0,
        first_message_preview="",
    )
    return summary
async def delete_thread(session: AsyncSession, thread_id: int) -> int | None:
    """Delete a thread together with all of its messages.

    Returns the number of messages the thread had before deletion, or None
    when the thread does not exist.
    """
    doomed = await session.get(Thread, thread_id)
    if doomed is None:
        return None

    # Count first, so the number can be reported after the rows are gone.
    counted = await session.execute(
        select(func.count(Message.id)).where(Message.thread_id == thread_id)
    )
    removed = counted.scalar_one() or 0

    purge = delete(Message).where(Message.thread_id == thread_id)
    await session.execute(purge)
    await session.delete(doomed)
    await session.commit()
    return int(removed)