import json
import logging
import re
from datetime import datetime, timezone

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from db.models import Message, Thread
from services import config_service, thread_state_service
from services.llm_client import LLMClient
from services.router_client import RouterClient
from services.vectorstore import VectorStoreService

logger = logging.getLogger(__name__)

HISTORY_LIMIT = 20  # the last N thread messages that are sent to the LLM
FALLBACK_INTENT_CODE = "general_info"
MAX_BOUNCES = 1  # how many times per user turn a branch may hand control to another one

_INTENT_CHANGE_RE = re.compile(r"\[INTENT_CHANGE:\s*([a-z_][a-z0-9_]*)\s*\]")
_STATE_PREFIX_RE = re.compile(r"\[STATE:\s*step=(\d+)\s*;?\s*slots\s*=\s*", re.IGNORECASE)


def _auto_thread_name(first_user_text: str) -> str:
    """Build a default thread name from the first user message.

    The name is a single-line preview of the text (cut to 60 characters with
    an ellipsis when longer) followed by the current UTC timestamp.
    """
    preview = first_user_text.strip().replace("\n", " ")
    if len(preview) > 60:
        preview = preview[:60].rstrip() + "…"
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
    return f"{preview} · {stamp}"


def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
    """Convert vector-store hits into the ``sources`` payload stored with a reply.

    Each hit must carry a ``"text"`` key; ``"metadata"`` and
    ``"relevance_score"`` are optional. The chunk text is cut to 500
    characters and the score is rounded to 3 decimal places.
    """
    sources = []
    for item in retrieved:
        # ``or {}`` also covers a hit whose metadata is present but null.
        meta = item.get("metadata") or {}
        sources.append({
            "document_id": meta.get("document_id", ""),
            "document_name": meta.get("document_name", ""),
            "chunk_text": item["text"][:500],
            "section": meta.get("section", ""),
            "page": meta.get("page_number", 0),
            "relevance_score": round(item.get("relevance_score", 0), 3),
        })
    return sources


def _parse_assistant_signals(text: str) -> dict:
    """Strip the service tags [INTENT_CHANGE: ...] / [STATE: ...] from an assistant reply.

    Returns a dict with:
        visible_text  - the text preceding the service tag (the whole text
                        when no well-formed tag is found),
        intent_change - the branch code, or None,
        state         - {'step': int, 'slots': dict}, or None.

    An INTENT_CHANGE tag takes precedence over a STATE tag. The parser
    tolerates extra whitespace; slots are extracted with curly-brace
    balancing so that values such as slots={"a": [1, 2]} do not break it.
    If the balanced slots object is not valid JSON, the state is still
    returned, with empty slots.
    """
    intent_match = _INTENT_CHANGE_RE.search(text)
    if intent_match:
        visible = text[:intent_match.start()].rstrip()
        return {
            "visible_text": visible,
            "intent_change": intent_match.group(1),
            "state": None,
        }

    state_match = _STATE_PREFIX_RE.search(text)
    if state_match:
        tail_start = state_match.end()
        slots_raw, after = _consume_json_object(text, tail_start)
        if slots_raw is not None:
            remainder = text[after:].lstrip()
            # The tag only counts when the slots object is followed by "]".
            if remainder.startswith("]"):
                try:
                    slots = json.loads(slots_raw)
                    if not isinstance(slots, dict):
                        slots = {}
                except json.JSONDecodeError:
                    slots = {}
                step = int(state_match.group(1))
                visible = text[:state_match.start()].rstrip()
                return {
                    "visible_text": visible,
                    "intent_change": None,
                    "state": {"step": step, "slots": slots},
                }

    return {"visible_text": text, "intent_change": None, "state": None}


def _consume_json_object(text: str, start: int) -> tuple[str | None, int]:
    """Extract a brace-balanced JSON object starting at ``start`` (a ``{`` is expected).

    Leading whitespace is skipped. Braces inside double-quoted strings are
    ignored, and backslash escapes inside strings are honoured.

    Returns (json_string, position_after_object); on failure (no opening
    brace, or the object is never closed) returns (None, start).
    """
    i = start
    n = len(text)
    while i < n and text[i].isspace():
        i += 1
    if i >= n or text[i] != "{":
        return None, start

    depth = 0
    in_str = False
    esc = False
    j = i
    while j < n:
        ch = text[j]
        if in_str:
            if esc:
                # The character after a backslash is consumed verbatim.
                esc = False
            elif ch == "\\":
                esc = True
            elif ch == '"':
                in_str = False
        else:
            if ch == '"':
                in_str = True
            elif ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[i:j + 1], j + 1
        j += 1
    return None, start


def _format_state_context(state_snapshot: dict) -> str:
    """Render the current thread state as a block appended to the end of the system prompt.

    A missing or null step is rendered as 0 and missing or null slots as an
    empty JSON object.
    """
    step = state_snapshot.get("current_step", 0) or 0
    slots = state_snapshot.get("slots", {}) or {}
    slots_json = json.dumps(slots, ensure_ascii=False)
    return (
        "\n\n[ТЕКУЩЕЕ СОСТОЯНИЕ]\n"
        f"step: {step}\n"
        f"slots: {slots_json}"
    )


async def _resolve_intent_with_fallback(
    session: AsyncSession, intent_code: str
) -> tuple[str, object, object]:
    """Return (code, intent, active_cfg) for the requested branch, or for the fallback branch.

    Raises:
        RuntimeError: if neither the requested intent nor FALLBACK_INTENT_CODE
            has an active config.
    """
    pair = await config_service.get_active_config_by_intent_code(session, intent_code)
    if pair is None:
        logger.warning(
            "Intent %r has no active config, falling back to %s",
            intent_code, FALLBACK_INTENT_CODE,
        )
        pair = await config_service.get_active_config_by_intent_code(session, FALLBACK_INTENT_CODE)
        if pair is None:
            raise RuntimeError(f"No active config for fallback intent {FALLBACK_INTENT_CODE!r}")
        intent, cfg = pair
        return FALLBACK_INTENT_CODE, intent, cfg
    intent, cfg = pair
    return intent_code, intent, cfg


async def send_message(
    session: AsyncSession,
    vectorstore: VectorStoreService,
    llm: LLMClient,
    router: RouterClient,
    text: str,
    thread_id: int | None = None,
    top_k: int = 5,
    temperature: float | None = None,
    max_tokens: int | None = None,
) -> dict:
    """Append the user's message to a thread, run it through the router and state machine, and reply.

    When ``thread_id`` is None a new thread is created and auto-named from
    ``text``. The user message, the assistant message and the updated thread
    state are committed in a single transaction.

    Returns:
        A dict describing the reply: thread and message ids, the served and
        router-selected intent codes, config/router versions, the visible
        answer, sources, the assembled prompt, the resulting thread state and
        the log of intent bounces.

    Raises:
        LookupError: if ``thread_id`` is given but no such thread exists.
        RuntimeError: if neither the selected nor the fallback intent has an
            active config.
    """
    if thread_id is None:
        thread = Thread(name=_auto_thread_name(text))
        session.add(thread)
        await session.flush()
    else:
        thread = await session.get(Thread, thread_id)
        if thread is None:
            raise LookupError(f"Thread {thread_id} not found")

    user_msg = Message(thread_id=thread.id, role="user", text=text)
    session.add(user_msg)
    await session.flush()

    # Newest-first with a limit, then reversed: the last HISTORY_LIMIT messages
    # in chronological order, excluding the message just added.
    stmt = (
        select(Message)
        .where(Message.thread_id == thread.id, Message.id != user_msg.id)
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(HISTORY_LIMIT)
    )
    rows = (await session.execute(stmt)).scalars().all()
    history = [{"role": m.role, "content": m.text} for m in reversed(rows)]

    # 1. Router: which branch answers.
    routing = await router.classify(session=session, history=history, text=text)
    router_code = routing["code"]
    router_version = routing.get("version")

    # 2. Thread state snapshot. If the router moved to another branch, reset step and slots.
    state_snapshot = await thread_state_service.load_snapshot(session, thread.id)
    prev_intent_code = state_snapshot["current_intent_code"]
    if prev_intent_code and prev_intent_code != router_code:
        logger.info(
            "Router switched intent for thread %d: %s → %s (state reset)",
            thread.id, prev_intent_code, router_code,
        )
        state_snapshot = {"current_intent_code": router_code, "current_step": 0, "slots": {}}

    # 3. Resolve the branch config (falling back to general_info) and call the LLM.
    served_code, intent, active_cfg = await _resolve_intent_with_fallback(session, router_code)
    if served_code != router_code:
        # Fallback: reset the state to general_info.
        state_snapshot = {"current_intent_code": served_code, "current_step": 0, "slots": {}}

    retrieved = vectorstore.query(query_text=text, top_k=top_k)
    sources = _retrieved_to_sources(retrieved)

    bounce_log: list[dict] = []
    last_assembled_prompt = ""
    llm_text = ""
    for attempt in range(MAX_BOUNCES + 1):
        base_prompt = config_service.compose_full_system_prompt(active_cfg)
        system_prompt = base_prompt + _format_state_context(state_snapshot)
        llm_result = await llm.chat(
            question=text,
            sources=retrieved,
            history=history,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        last_assembled_prompt = llm_result["assembled_prompt"]
        llm_text = llm_result["text"]
        parsed = _parse_assistant_signals(llm_text)

        if parsed["intent_change"] and attempt < MAX_BOUNCES:
            new_code = parsed["intent_change"]
            bounce_log.append({
                "from": served_code,
                "to": new_code,
                "preface": parsed["visible_text"],
            })
            logger.info(
                "Intent bounce in thread %d: %s → %s",
                thread.id, served_code, new_code,
            )
            served_code, intent, active_cfg = await _resolve_intent_with_fallback(session, new_code)
            state_snapshot = {"current_intent_code": served_code, "current_step": 0, "slots": {}}
            continue
        break

    # 4. Update thread_state and store the messages.
    # NOTE(review): when the reply consists solely of a service tag,
    # parsed["visible_text"] is empty and this falls back to the raw llm_text,
    # tag included. Confirm that showing the raw reply is preferred over an
    # empty answer.
    visible_text = parsed["visible_text"] or llm_text
    if parsed["state"] is not None:
        new_step = parsed["state"]["step"]
        # A loaded snapshot may carry null slots (the prompt formatter already
        # tolerates that), so guard before unpacking.
        current_slots = state_snapshot.get("slots") or {}
        merged_slots = {**current_slots, **parsed["state"]["slots"]}
        state_snapshot = {
            "current_intent_code": served_code,
            "current_step": new_step,
            "slots": merged_slots,
        }
    else:
        # Always record the branch that actually served this turn. Without
        # this, a thread that had no stored intent yet would persist the
        # snapshot's empty intent code whenever the reply carries no STATE tag.
        state_snapshot = {**state_snapshot, "current_intent_code": served_code}
    # If the reply carried INTENT_CHANGE on the last iteration (MAX_BOUNCES exceeded),
    # the service tag is cut out of the parsed visible text, and the state is not switched.

    await thread_state_service.upsert(
        session,
        thread.id,
        intent_code=state_snapshot["current_intent_code"],
        step=state_snapshot["current_step"],
        slots=state_snapshot["slots"],
    )

    user_msg.intent_id = intent.id
    if thread.agent_config_id is None:
        thread.agent_config_id = active_cfg.id

    assistant_msg = Message(
        thread_id=thread.id,
        role="assistant",
        text=visible_text,
        sources_json=json.dumps(sources, ensure_ascii=False),
        assembled_prompt=last_assembled_prompt,
        intent_id=intent.id,
    )
    session.add(assistant_msg)
    thread.updated_at = datetime.now(timezone.utc)

    await session.commit()
    await session.refresh(assistant_msg)
    await session.refresh(thread)

    logger.info(
        "Chat: thread=%d, router=%s, served=%s (v%d), step=%d, slots=%d keys, "
        "user_msg=%d, assistant_msg=%d, bounces=%d",
        thread.id,
        router_code,
        served_code,
        active_cfg.version,
        state_snapshot["current_step"],
        len(state_snapshot["slots"] or {}),
        user_msg.id,
        assistant_msg.id,
        len(bounce_log),
    )

    return {
        "thread_id": thread.id,
        "thread_name": thread.name,
        "message_id": assistant_msg.id,
        "intent_code": intent.code,
        "intent_name": intent.name,
        "router_intent_code": router_code,
        "config_version": active_cfg.version,
        "router_version": router_version,
        "answer": visible_text,
        "sources": sources,
        "model_used": llm.model,
        "assembled_prompt": last_assembled_prompt,
        "thread_state": {
            "current_intent_code": state_snapshot["current_intent_code"],
            "current_step": state_snapshot["current_step"],
            "slots": state_snapshot["slots"],
        },
        "bounces": bounce_log,
    }


async def list_threads(session: AsyncSession) -> list[dict]:
    """List all threads, most recently updated first.

    Each entry carries the thread's message count and a preview (up to 120
    characters) of its user message with the lowest id. Threads without
    messages get a count of 0 and an empty preview.
    """
    count_subq = (
        select(Message.thread_id, func.count(Message.id).label("cnt"))
        .group_by(Message.thread_id)
        .subquery()
    )
    first_msg_subq = (
        select(Message.thread_id, func.min(Message.id).label("first_id"))
        .where(Message.role == "user")
        .group_by(Message.thread_id)
        .subquery()
    )
    stmt = (
        select(
            Thread,
            func.coalesce(count_subq.c.cnt, 0).label("messages_count"),
            Message.text.label("first_text"),
        )
        .outerjoin(count_subq, count_subq.c.thread_id == Thread.id)
        .outerjoin(first_msg_subq, first_msg_subq.c.thread_id == Thread.id)
        .outerjoin(Message, Message.id == first_msg_subq.c.first_id)
        .order_by(Thread.updated_at.desc())
    )
    rows = (await session.execute(stmt)).all()

    result = []
    for thread, messages_count, first_text in rows:
        preview = (first_text or "").strip().replace("\n", " ")
        if len(preview) > 120:
            preview = preview[:120].rstrip() + "…"
        result.append({
            "id": thread.id,
            "name": thread.name,
            "created_at": thread.created_at.isoformat(),
            "updated_at": thread.updated_at.isoformat(),
            "messages_count": messages_count,
            "first_message_preview": preview,
        })
    return result


async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | None:
    """Return a thread with all its messages and current state, or None if it does not exist.

    Messages are returned in chronological order. A message whose
    ``sources_json`` cannot be decoded is returned with empty sources and a
    warning is logged.
    """
    from db.models import Intent

    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None

    # Message.id breaks ties between messages with the same created_at, which
    # matches the (created_at, id) ordering send_message uses for history.
    stmt = (
        select(Message, Intent.code, Intent.name)
        .outerjoin(Intent, Intent.id == Message.intent_id)
        .where(Message.thread_id == thread_id)
        .order_by(Message.created_at, Message.id)
    )
    rows = (await session.execute(stmt)).all()

    messages = []
    for m, intent_code, intent_name in rows:
        sources = []
        if m.sources_json:
            try:
                sources = json.loads(m.sources_json)
            except json.JSONDecodeError:
                logger.warning("Bad sources_json for message %d", m.id)
        messages.append({
            "id": m.id,
            "role": m.role,
            "text": m.text,
            "created_at": m.created_at.isoformat(),
            "sources": sources,
            "assembled_prompt": m.assembled_prompt or "",
            "intent_code": intent_code or "",
            "intent_name": intent_name or "",
        })

    state = await thread_state_service.load_snapshot(session, thread_id)
    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages": messages,
        "thread_state": state,
    }


async def rename_thread(session: AsyncSession, thread_id: int, name: str) -> dict | None:
    """Rename a thread and bump its ``updated_at``; return None if the thread does not exist."""
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None
    thread.name = name
    thread.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(thread)
    # NOTE(review): messages_count and first_message_preview are fixed
    # placeholders here, not computed from the thread's messages. Confirm that
    # callers do not rely on them.
    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages_count": 0,
        "first_message_preview": "",
    }


async def delete_thread(session: AsyncSession, thread_id: int) -> int | None:
    """Delete a thread and all of its messages.

    Returns the number of deleted messages, or None if the thread does not exist.
    """
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None
    count_stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
    messages_count = (await session.execute(count_stmt)).scalar_one() or 0
    await session.execute(delete(Message).where(Message.thread_id == thread_id))
    await session.delete(thread)
    await session.commit()
    return int(messages_count)