feat(sprint5): state machine + bouncing — thread_state и служебные теги
Таблица thread_state (intent, step, slots) ведётся per-thread. В системный
промпт ветки дописывается текущее состояние, LLM возвращает служебный тег
[STATE: step=N; slots={...}] после основного ответа — парсер в chat_service
вырезает его и обновляет состояние. Если ветка решила, что тема ушла в другую,
она выдаёт [INTENT_CHANGE: code] — делаем один повторный вызов LLM с новой
веткой и сброшенным state (bouncing, MAX_BOUNCES=1). Если роутер сам выбрал
другую ветку, чем в thread_state, — state тоже сбрасывается. Промпт new_booking
переписан под 6-шаговый сценарий (имя → повод → специалист → время → подтверждение
→ запись), в «Песочнице» появился блок «Состояние треда» с intent/step/slots
и списком переходов.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+204
-38
@@ -1,12 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from db.models import Message, Thread
|
||||
from services import config_service, intent_service
|
||||
from services import config_service, thread_state_service
|
||||
from services.llm_client import LLMClient
|
||||
from services.router_client import RouterClient
|
||||
from services.vectorstore import VectorStoreService
|
||||
@@ -15,10 +16,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
HISTORY_LIMIT = 20 # последние N сообщений треда, которые улетают в LLM
|
||||
FALLBACK_INTENT_CODE = "general_info"
|
||||
MAX_BOUNCES = 1 # сколько раз за одну реплику ветка может передать управление другой
|
||||
|
||||
_INTENT_CHANGE_RE = re.compile(r"\[INTENT_CHANGE:\s*([a-z_][a-z0-9_]*)\s*\]")
|
||||
_STATE_PREFIX_RE = re.compile(r"\[STATE:\s*step=(\d+)\s*;?\s*slots\s*=\s*", re.IGNORECASE)
|
||||
|
||||
|
||||
def _auto_thread_name(first_user_text: str) -> str:
|
||||
"""Авто-имя треда: первые 60 символов первой реплики + дата."""
|
||||
preview = first_user_text.strip().replace("\n", " ")
|
||||
if len(preview) > 60:
|
||||
preview = preview[:60].rstrip() + "…"
|
||||
@@ -41,6 +45,111 @@ def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
|
||||
return sources
|
||||
|
||||
|
||||
def _parse_assistant_signals(text: str) -> dict:
|
||||
"""Вырезать служебные теги [INTENT_CHANGE: ...] / [STATE: ...] из ответа ассистента.
|
||||
|
||||
Возвращает:
|
||||
visible_text — текст без служебных тегов,
|
||||
intent_change — код ветки или None,
|
||||
state — {'step': int, 'slots': dict} или None.
|
||||
|
||||
Парсер толерантен к лишним пробелам; slots парсится с балансировкой фигурных скобок,
|
||||
чтобы не ломаться на значениях-списках типа "slots={\"a\": [1, 2]}".
|
||||
"""
|
||||
intent_match = _INTENT_CHANGE_RE.search(text)
|
||||
if intent_match:
|
||||
visible = text[:intent_match.start()].rstrip()
|
||||
return {"visible_text": visible, "intent_change": intent_match.group(1), "state": None}
|
||||
|
||||
state_match = _STATE_PREFIX_RE.search(text)
|
||||
if state_match:
|
||||
tail_start = state_match.end()
|
||||
slots_raw, after = _consume_json_object(text, tail_start)
|
||||
if slots_raw is not None:
|
||||
remainder = text[after:].lstrip()
|
||||
if remainder.startswith("]"):
|
||||
try:
|
||||
slots = json.loads(slots_raw)
|
||||
if not isinstance(slots, dict):
|
||||
slots = {}
|
||||
except json.JSONDecodeError:
|
||||
slots = {}
|
||||
step = int(state_match.group(1))
|
||||
visible = text[:state_match.start()].rstrip()
|
||||
return {
|
||||
"visible_text": visible,
|
||||
"intent_change": None,
|
||||
"state": {"step": step, "slots": slots},
|
||||
}
|
||||
|
||||
return {"visible_text": text, "intent_change": None, "state": None}
|
||||
|
||||
|
||||
def _consume_json_object(text: str, start: int) -> tuple[str | None, int]:
|
||||
"""Вытянуть сбалансированный JSON-объект, начиная с позиции start (ожидаем `{`).
|
||||
|
||||
Возвращает (json_string, position_after_object). При ошибке — (None, start).
|
||||
"""
|
||||
i = start
|
||||
n = len(text)
|
||||
while i < n and text[i].isspace():
|
||||
i += 1
|
||||
if i >= n or text[i] != "{":
|
||||
return None, start
|
||||
depth = 0
|
||||
in_str = False
|
||||
esc = False
|
||||
j = i
|
||||
while j < n:
|
||||
ch = text[j]
|
||||
if in_str:
|
||||
if esc:
|
||||
esc = False
|
||||
elif ch == "\\":
|
||||
esc = True
|
||||
elif ch == '"':
|
||||
in_str = False
|
||||
else:
|
||||
if ch == '"':
|
||||
in_str = True
|
||||
elif ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[i:j + 1], j + 1
|
||||
j += 1
|
||||
return None, start
|
||||
|
||||
|
||||
def _format_state_context(state_snapshot: dict) -> str:
|
||||
"""Блок с текущим состоянием треда для дописывания в конец системного промпта."""
|
||||
step = state_snapshot.get("current_step", 0) or 0
|
||||
slots = state_snapshot.get("slots", {}) or {}
|
||||
slots_json = json.dumps(slots, ensure_ascii=False)
|
||||
return (
|
||||
"\n\n[ТЕКУЩЕЕ СОСТОЯНИЕ]\n"
|
||||
f"step: {step}\n"
|
||||
f"slots: {slots_json}"
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_intent_with_fallback(
|
||||
session: AsyncSession, intent_code: str
|
||||
) -> tuple[str, object, object]:
|
||||
"""Вернуть (code, intent, active_cfg) — либо запрошенной ветки, либо fallback."""
|
||||
pair = await config_service.get_active_config_by_intent_code(session, intent_code)
|
||||
if pair is None:
|
||||
logger.warning("Intent %r has no active config, falling back to %s", intent_code, FALLBACK_INTENT_CODE)
|
||||
pair = await config_service.get_active_config_by_intent_code(session, FALLBACK_INTENT_CODE)
|
||||
if pair is None:
|
||||
raise RuntimeError(f"No active config for fallback intent {FALLBACK_INTENT_CODE!r}")
|
||||
intent, cfg = pair
|
||||
return FALLBACK_INTENT_CODE, intent, cfg
|
||||
intent, cfg = pair
|
||||
return intent_code, intent, cfg
|
||||
|
||||
|
||||
async def send_message(
|
||||
session: AsyncSession,
|
||||
vectorstore: VectorStoreService,
|
||||
@@ -52,7 +161,7 @@ async def send_message(
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> dict:
|
||||
"""Добавить реплику пациента в тред, прогнать через роутер, получить ответ ассистента."""
|
||||
"""Добавить реплику пациента в тред, прогнать через роутер + state machine, получить ответ."""
|
||||
if thread_id is None:
|
||||
thread = Thread(name=_auto_thread_name(text))
|
||||
session.add(thread)
|
||||
@@ -62,12 +171,10 @@ async def send_message(
|
||||
if thread is None:
|
||||
raise LookupError(f"Thread {thread_id} not found")
|
||||
|
||||
# Сохраняем реплику пациента до вызова LLM — чтобы она осталась в истории даже при ошибке.
|
||||
user_msg = Message(thread_id=thread.id, role="user", text=text)
|
||||
session.add(user_msg)
|
||||
await session.flush()
|
||||
|
||||
# История для классификации и для LLM: все сообщения треда до новой реплики.
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.thread_id == thread.id, Message.id != user_msg.id)
|
||||
@@ -77,47 +184,95 @@ async def send_message(
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
history = [{"role": m.role, "content": m.text} for m in reversed(rows)]
|
||||
|
||||
# 1. Роутер определяет ветку.
|
||||
# 1. Роутер — какая ветка отвечает.
|
||||
routing = await router.classify(session=session, history=history, text=text)
|
||||
intent_code = routing["code"]
|
||||
router_code = routing["code"]
|
||||
router_version = routing.get("version")
|
||||
pair = await config_service.get_active_config_by_intent_code(session, intent_code)
|
||||
if pair is None:
|
||||
# Ветка выключена или без активного конфига — подстраховываемся общей справкой.
|
||||
logger.warning("Intent %r has no active config, falling back to %s", intent_code, FALLBACK_INTENT_CODE)
|
||||
intent_code = FALLBACK_INTENT_CODE
|
||||
pair = await config_service.get_active_config_by_intent_code(session, intent_code)
|
||||
|
||||
if pair is None:
|
||||
# Даже fallback не нашёлся — критическая ошибка конфигурации.
|
||||
raise RuntimeError(f"No active config for fallback intent {FALLBACK_INTENT_CODE!r}")
|
||||
# 2. Снимок состояния треда. Если роутер ушёл в другую ветку — сбрасываем шаг и слоты.
|
||||
state_snapshot = await thread_state_service.load_snapshot(session, thread.id)
|
||||
prev_intent_code = state_snapshot["current_intent_code"]
|
||||
if prev_intent_code and prev_intent_code != router_code:
|
||||
logger.info(
|
||||
"Router switched intent for thread %d: %s → %s (state reset)",
|
||||
thread.id, prev_intent_code, router_code,
|
||||
)
|
||||
state_snapshot = {"current_intent_code": router_code, "current_step": 0, "slots": {}}
|
||||
|
||||
intent, active_cfg = pair
|
||||
system_prompt = config_service.compose_full_system_prompt(active_cfg)
|
||||
# 3. Получаем конфиг ветки (с fallback на general_info) и зовём LLM.
|
||||
served_code, intent, active_cfg = await _resolve_intent_with_fallback(session, router_code)
|
||||
if served_code != router_code:
|
||||
# Fallback: сбрасываем состояние на general_info.
|
||||
state_snapshot = {"current_intent_code": served_code, "current_step": 0, "slots": {}}
|
||||
|
||||
retrieved = vectorstore.query(query_text=text, top_k=top_k)
|
||||
sources = _retrieved_to_sources(retrieved)
|
||||
|
||||
bounce_log: list[dict] = []
|
||||
last_assembled_prompt = ""
|
||||
llm_text = ""
|
||||
|
||||
for attempt in range(MAX_BOUNCES + 1):
|
||||
base_prompt = config_service.compose_full_system_prompt(active_cfg)
|
||||
system_prompt = base_prompt + _format_state_context(state_snapshot)
|
||||
|
||||
llm_result = await llm.chat(
|
||||
question=text,
|
||||
sources=retrieved,
|
||||
history=history,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
last_assembled_prompt = llm_result["assembled_prompt"]
|
||||
llm_text = llm_result["text"]
|
||||
parsed = _parse_assistant_signals(llm_text)
|
||||
|
||||
if parsed["intent_change"] and attempt < MAX_BOUNCES:
|
||||
new_code = parsed["intent_change"]
|
||||
bounce_log.append({
|
||||
"from": served_code,
|
||||
"to": new_code,
|
||||
"preface": parsed["visible_text"],
|
||||
})
|
||||
logger.info(
|
||||
"Intent bounce in thread %d: %s → %s", thread.id, served_code, new_code,
|
||||
)
|
||||
served_code, intent, active_cfg = await _resolve_intent_with_fallback(session, new_code)
|
||||
state_snapshot = {"current_intent_code": served_code, "current_step": 0, "slots": {}}
|
||||
continue
|
||||
break
|
||||
|
||||
# 4. Обновляем thread_state и сохраняем сообщения.
|
||||
visible_text = parsed["visible_text"] or llm_text
|
||||
if parsed["state"] is not None:
|
||||
new_step = parsed["state"]["step"]
|
||||
merged_slots = {**state_snapshot.get("slots", {}), **parsed["state"]["slots"]}
|
||||
state_snapshot = {
|
||||
"current_intent_code": served_code,
|
||||
"current_step": new_step,
|
||||
"slots": merged_slots,
|
||||
}
|
||||
# Если ответ пришёл с INTENT_CHANGE на последней итерации (превысили MAX_BOUNCES) —
|
||||
# служебный тег мы из visible_text уже вырезали, но состояние переключать не будем.
|
||||
|
||||
await thread_state_service.upsert(
|
||||
session, thread.id,
|
||||
intent_code=state_snapshot["current_intent_code"],
|
||||
step=state_snapshot["current_step"],
|
||||
slots=state_snapshot["slots"],
|
||||
)
|
||||
|
||||
user_msg.intent_id = intent.id
|
||||
if thread.agent_config_id is None:
|
||||
thread.agent_config_id = active_cfg.id
|
||||
|
||||
# 2. Retrieval + запрос к ветке.
|
||||
retrieved = vectorstore.query(query_text=text, top_k=top_k)
|
||||
sources = _retrieved_to_sources(retrieved)
|
||||
|
||||
llm_result = await llm.chat(
|
||||
question=text,
|
||||
sources=retrieved,
|
||||
history=history,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
assistant_msg = Message(
|
||||
thread_id=thread.id,
|
||||
role="assistant",
|
||||
text=llm_result["text"],
|
||||
text=visible_text,
|
||||
sources_json=json.dumps(sources, ensure_ascii=False),
|
||||
assembled_prompt=llm_result["assembled_prompt"],
|
||||
assembled_prompt=last_assembled_prompt,
|
||||
intent_id=intent.id,
|
||||
)
|
||||
session.add(assistant_msg)
|
||||
@@ -129,8 +284,10 @@ async def send_message(
|
||||
await session.refresh(thread)
|
||||
|
||||
logger.info(
|
||||
"Chat: thread=%d, intent=%s (v%d), user_msg=%d, assistant_msg=%d, sources=%d",
|
||||
thread.id, intent.code, active_cfg.version, user_msg.id, assistant_msg.id, len(sources),
|
||||
"Chat: thread=%d, router=%s, served=%s (v%d), step=%d, slots=%d keys, user_msg=%d, assistant_msg=%d, bounces=%d",
|
||||
thread.id, router_code, served_code, active_cfg.version,
|
||||
state_snapshot["current_step"], len(state_snapshot["slots"]),
|
||||
user_msg.id, assistant_msg.id, len(bounce_log),
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -139,17 +296,23 @@ async def send_message(
|
||||
"message_id": assistant_msg.id,
|
||||
"intent_code": intent.code,
|
||||
"intent_name": intent.name,
|
||||
"router_intent_code": router_code,
|
||||
"config_version": active_cfg.version,
|
||||
"router_version": router_version,
|
||||
"answer": llm_result["text"],
|
||||
"answer": visible_text,
|
||||
"sources": sources,
|
||||
"model_used": llm.model,
|
||||
"assembled_prompt": llm_result["assembled_prompt"],
|
||||
"assembled_prompt": last_assembled_prompt,
|
||||
"thread_state": {
|
||||
"current_intent_code": state_snapshot["current_intent_code"],
|
||||
"current_step": state_snapshot["current_step"],
|
||||
"slots": state_snapshot["slots"],
|
||||
},
|
||||
"bounces": bounce_log,
|
||||
}
|
||||
|
||||
|
||||
async def list_threads(session: AsyncSession) -> list[dict]:
|
||||
"""Список всех тредов с превью первой реплики и количеством сообщений."""
|
||||
count_subq = (
|
||||
select(Message.thread_id, func.count(Message.id).label("cnt"))
|
||||
.group_by(Message.thread_id)
|
||||
@@ -224,12 +387,15 @@ async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | Non
|
||||
"intent_code": intent_code or "",
|
||||
"intent_name": intent_name or "",
|
||||
})
|
||||
|
||||
state = await thread_state_service.load_snapshot(session, thread_id)
|
||||
return {
|
||||
"id": thread.id,
|
||||
"name": thread.name,
|
||||
"created_at": thread.created_at.isoformat(),
|
||||
"updated_at": thread.updated_at.isoformat(),
|
||||
"messages": messages,
|
||||
"thread_state": state,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
"""State machine треда: текущая ветка, шаг внутри ветки, собранные слоты.
|
||||
|
||||
Используется в chat_service для ведения многошаговых сценариев (Спринт 5).
|
||||
Слоты — произвольный JSON-словарь, конкретные ключи определяются веткой.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from db.models import ThreadState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_state(session: AsyncSession, thread_id: int) -> ThreadState | None:
|
||||
return await session.get(ThreadState, thread_id)
|
||||
|
||||
|
||||
def _parse_slots(raw: str) -> dict:
|
||||
if not raw:
|
||||
return {}
|
||||
try:
|
||||
value = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Bad slots_json for thread_state, resetting to {}")
|
||||
return {}
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
async def load_snapshot(session: AsyncSession, thread_id: int) -> dict:
|
||||
"""Удобный снимок состояния для чтения (intent, step, slots)."""
|
||||
state = await get_state(session, thread_id)
|
||||
if state is None:
|
||||
return {"current_intent_code": None, "current_step": 0, "slots": {}}
|
||||
return {
|
||||
"current_intent_code": state.current_intent_code,
|
||||
"current_step": state.current_step,
|
||||
"slots": _parse_slots(state.slots_json),
|
||||
}
|
||||
|
||||
|
||||
async def upsert(
|
||||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
*,
|
||||
intent_code: str | None,
|
||||
step: int,
|
||||
slots: dict,
|
||||
) -> ThreadState:
|
||||
"""Создать или обновить состояние треда. Коммит — на совести вызывающего."""
|
||||
state = await get_state(session, thread_id)
|
||||
now = datetime.now(timezone.utc)
|
||||
slots_raw = json.dumps(slots or {}, ensure_ascii=False)
|
||||
if state is None:
|
||||
state = ThreadState(
|
||||
thread_id=thread_id,
|
||||
current_intent_code=intent_code,
|
||||
current_step=step,
|
||||
slots_json=slots_raw,
|
||||
updated_at=now,
|
||||
)
|
||||
session.add(state)
|
||||
else:
|
||||
state.current_intent_code = intent_code
|
||||
state.current_step = step
|
||||
state.slots_json = slots_raw
|
||||
state.updated_at = now
|
||||
return state
|
||||
|
||||
|
||||
async def reset(session: AsyncSession, thread_id: int, *, new_intent_code: str | None) -> ThreadState:
|
||||
"""Сбросить шаг и слоты треда, выставить новую ветку (при смене intent)."""
|
||||
return await upsert(session, thread_id, intent_code=new_intent_code, step=0, slots={})
|
||||
Reference in New Issue
Block a user