import json
import logging
from datetime import datetime, timezone

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from db.models import Intent, IntentStep, Message, Thread
from services import config_service, intent_step_service, thread_state_service
from services.llm_client import LLMClient, LLMUnavailableError
from services.router_client import RouterClient
from services.state_machine import parse_branch_response, validate_transition
from services.vectorstore import VectorStoreService

logger = logging.getLogger(__name__)

# Max number of earlier messages of the thread passed to the router/LLM as history.
HISTORY_LIMIT = 20
# Intent served when the requested intent has no active config.
FALLBACK_INTENT_CODE = "general_info"
# Max number of LLM-requested intent switches ([INTENT_CHANGE]) honoured per user message.
MAX_BOUNCES = 1


def _make_preview(text: str | None, limit: int) -> str:
    """Collapse *text* onto one line and cut it to *limit* chars, appending "…" when cut.

    ``None`` is treated as an empty string.
    """
    preview = (text or "").strip().replace("\n", " ")
    if len(preview) > limit:
        preview = preview[:limit].rstrip() + "…"
    return preview


def _auto_thread_name(first_user_text: str) -> str:
    """Build a default thread name: a 60-char preview of the first message plus a UTC timestamp."""
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
    return f"{_make_preview(first_user_text, 60)} · {stamp}"


def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
    """Convert vector-store hits into the ``sources`` dicts stored with a message and returned to clients.

    Chunk text is truncated to 500 chars and the relevance score is rounded to 3 decimals;
    missing metadata fields default to empty strings / 0.
    """
    sources = []
    for item in retrieved:
        meta = item.get("metadata", {})
        sources.append({
            "document_id": meta.get("document_id", ""),
            "document_name": meta.get("document_name", ""),
            "chunk_text": item["text"][:500],
            "section": meta.get("section", ""),
            "page": meta.get("page_number", 0),
            "relevance_score": round(item.get("relevance_score", 0), 3),
        })
    return sources


def _empty_snapshot(intent_code: str) -> dict:
    """Return a fresh thread-state snapshot for *intent_code*: step counter 0, no step, no slots."""
    return {
        "current_intent_code": intent_code,
        "current_step": 0,
        "current_step_code": None,
        "slots": {},
    }


def _format_state_context(
    snapshot: dict,
    current_step: IntentStep | None,
    router_hint: str | None = None,
) -> str:
    """Render the thread's current state as a text block appended to the system prompt.

    Includes the current step (code, name and allowed next steps) when a state-machine
    step is active, the slots as JSON, and, if given, the router hint section.
    """
    slots = snapshot.get("slots", {}) or {}
    slots_json = json.dumps(slots, ensure_ascii=False)
    lines = ["", "[ТЕКУЩЕЕ СОСТОЯНИЕ]"]
    if current_step is not None:
        allowed = intent_step_service.parse_allowed_next(current_step)
        lines.append(f"step_code: {current_step.code} ({current_step.name})")
        lines.append(f"allowed_next: {json.dumps(allowed, ensure_ascii=False)}")
    else:
        lines.append("step_code: —")
    lines.append(f"slots: {slots_json}")
    if router_hint:
        lines.append("")
        lines.append("[ПОДСКАЗКА РОУТЕРА]")
        lines.append(router_hint)
    return "\n" + "\n".join(lines)


async def _resolve_intent_with_fallback(
    session: AsyncSession, intent_code: str
) -> tuple[str, object, object]:
    """Return ``(served_code, intent, active_config)`` for *intent_code*.

    If *intent_code* has no active config, the fallback intent (FALLBACK_INTENT_CODE)
    is served instead and its code is returned as ``served_code``.

    Raises:
        RuntimeError: the fallback intent has no active config either.
    """
    pair = await config_service.get_active_config_by_intent_code(session, intent_code)
    if pair is None:
        logger.warning("Intent %r has no active config, falling back to %s", intent_code, FALLBACK_INTENT_CODE)
        pair = await config_service.get_active_config_by_intent_code(session, FALLBACK_INTENT_CODE)
        if pair is None:
            raise RuntimeError(f"No active config for fallback intent {FALLBACK_INTENT_CODE!r}")
        intent, cfg = pair
        return FALLBACK_INTENT_CODE, intent, cfg
    intent, cfg = pair
    return intent_code, intent, cfg


async def _resolve_current_step(
    session: AsyncSession,
    intent_id: int,
    intent_code: str,
    step_code: str | None,
) -> IntentStep | None:
    """Find the state-machine step matching the thread's current state.

    Returns None for intents without a state machine. If *step_code* is empty or
    not found for this intent, the branch's first step is returned.
    """
    if not intent_step_service.has_state_machine(intent_code):
        return None
    if step_code:
        step = await intent_step_service.get_step_by_code(session, intent_id, step_code)
        if step is not None:
            return step
        logger.warning("Step %r not found for intent %s, falling back to first step", step_code, intent_code)
    return await intent_step_service.get_first_step(session, intent_id)


async def send_message(
    session: AsyncSession,
    vectorstore: VectorStoreService,
    llm: LLMClient,
    router: RouterClient,
    text: str,
    thread_id: int | None = None,
    top_k: int = 5,
    temperature: float | None = None,
    max_tokens: int | None = None,
) -> dict:
    """Process one patient utterance: router → state machine → LLM → reply.

    The transaction is committed only at the very end. If the LLM call fails, the
    caller's rollback (presumably in the API router) discards the new thread and
    the user message, so "empty" dialogs with no assistant reply do not linger in
    the thread list.

    Raises:
        LookupError: *thread_id* was given but no such thread exists.
        RuntimeError: neither the resolved intent nor the fallback intent has an
            active config (from ``_resolve_intent_with_fallback``).
    """
    if thread_id is None:
        thread = Thread(name=_auto_thread_name(text))
        session.add(thread)
        await session.flush()
    else:
        thread = await session.get(Thread, thread_id)
        if thread is None:
            raise LookupError(f"Thread {thread_id} not found")

    user_msg = Message(thread_id=thread.id, role="user", text=text)
    session.add(user_msg)
    await session.flush()  # flush only, no commit — so an LLM failure can roll it back

    # Last HISTORY_LIMIT messages before the current one, oldest first.
    stmt = (
        select(Message)
        .where(Message.thread_id == thread.id, Message.id != user_msg.id)
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(HISTORY_LIMIT)
    )
    rows = (await session.execute(stmt)).scalars().all()
    history = [{"role": m.role, "content": m.text} for m in reversed(rows)]

    # 1. Router: which branch should handle this utterance.
    routing = await router.classify(session=session, history=history, text=text)
    router_code = routing["code"]
    router_version = routing.get("version")

    # 2. State snapshot. Key rule (sticky state machine, "mini-G" from Sprint 6b):
    # if the thread is already inside a state-machine branch and the router proposes a
    # different one, do NOT reset the state. Instead give the LLM a hint ("the router
    # thinks this") and let it decide: emit `[INTENT_CHANGE: ...]` or keep the scenario.
    # This keeps an incidental phrase (e.g. "my ear hurts") inside a booking flow from
    # wiping the collected slots.
    snapshot = await thread_state_service.load_snapshot(session, thread.id)
    prev_intent_code = snapshot["current_intent_code"]
    router_hint: str | None = None
    effective_code = router_code
    if prev_intent_code and prev_intent_code != router_code:
        if intent_step_service.has_state_machine(prev_intent_code):
            logger.info(
                "Router suggested %s but thread %d is in sm %s — sticky, hint only",
                router_code, thread.id, prev_intent_code,
            )
            router_hint = (
                f"Роутер на этой реплике счёл, что тема — `{router_code}`. "
                f"Ты сейчас ведёшь сценарий `{prev_intent_code}`. "
                f"Если пациент действительно сменил тему (перенос, цены, острое состояние) — "
                f"выдай `[INTENT_CHANGE: {router_code}]`. "
                f"Если реплика укладывается в сценарий (повод/жалоба/имя) — "
                f"зафиксируй её в соответствующий слот и продолжай по сценарию."
            )
            effective_code = prev_intent_code
        else:
            logger.info(
                "Router switched intent for thread %d: %s → %s (state reset)",
                thread.id, prev_intent_code, router_code,
            )
            snapshot = _empty_snapshot(router_code)

    # 3. Resolve the branch (with fallback) and the step.
    served_code, intent, active_cfg = await _resolve_intent_with_fallback(session, effective_code)
    if served_code != effective_code:
        snapshot = _empty_snapshot(served_code)
        router_hint = None

    # The snapshot must always name the branch actually being served. Previously, when the
    # stored snapshot had no intent yet (e.g. a new thread), intent_code=None was persisted
    # unless the LLM returned a valid STATE_JSON — even though a state-machine step had been
    # assigned below — so the sticky rule above (`if prev_intent_code and ...`) never engaged
    # on the next turn. Also normalise slots once: they are merged and len()-ed below.
    # Work on a copy so the object returned by load_snapshot is not mutated.
    snapshot = {
        **snapshot,
        "current_intent_code": served_code,
        "slots": snapshot.get("slots") or {},
    }

    retrieved = vectorstore.query(query_text=text, top_k=top_k)
    sources = _retrieved_to_sources(retrieved)

    bounce_log: list[dict] = []
    validation_events: list[dict] = []  # illegal transitions, for UI highlighting
    last_assembled_prompt = ""
    visible_text = ""
    parse_error: str | None = None

    # Attempt 0 is the normal call; each further attempt only happens after the LLM
    # requested an intent change ([INTENT_CHANGE]), up to MAX_BOUNCES times.
    for attempt in range(MAX_BOUNCES + 1):
        current_step = await _resolve_current_step(
            session, intent.id, served_code, snapshot.get("current_step_code"),
        )
        is_state_machine = current_step is not None
        if current_step is not None and snapshot.get("current_step_code") != current_step.code:
            snapshot["current_step_code"] = current_step.code

        base_prompt = config_service.compose_full_system_prompt(active_cfg)
        step_prompt = f"\n\n{current_step.system_prompt}" if current_step is not None else ""
        state_context = _format_state_context(snapshot, current_step, router_hint)
        system_prompt = base_prompt + step_prompt + state_context

        llm_result = await llm.chat(
            question=text,
            sources=retrieved,
            history=history,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        last_assembled_prompt = llm_result["assembled_prompt"]

        parsed = parse_branch_response(llm_result["text"])
        visible_text = parsed["visible_text"] or llm_result["text"]
        # A STATE_JSON block is expected only from state-machine branches. For the others
        # (general_info, price_question, ...) "no STATE_JSON" is normal, not an error.
        parse_error = parsed["parse_error"] if is_state_machine else None

        if parsed["intent_change"] and attempt < MAX_BOUNCES:
            new_code = parsed["intent_change"]
            bounce_log.append({
                "from": served_code,
                "to": new_code,
                "preface": parsed["visible_text"],
            })
            logger.info("Intent bounce in thread %d: %s → %s", thread.id, served_code, new_code)
            served_code, intent, active_cfg = await _resolve_intent_with_fallback(session, new_code)
            snapshot = _empty_snapshot(served_code)
            router_hint = None  # new branch — the router hint no longer applies
            continue

        if parsed["state_update"] is not None and current_step is not None:
            requested = parsed["state_update"]["state_after"]
            allowed = intent_step_service.parse_allowed_next(current_step)
            ok, reason = validate_transition(
                current_step=current_step.code,
                requested_step=requested,
                allowed_next=allowed,
            )
            # Defensive `or {}`: the shape of slots_updated comes from parse_branch_response.
            slots_updated = parsed["state_update"]["slots_updated"] or {}
            merged_slots = {**(snapshot.get("slots") or {}), **slots_updated}
            if ok:
                step_delta = 1 if requested != current_step.code else 0
                snapshot = {
                    "current_intent_code": served_code,
                    "current_step": snapshot["current_step"] + step_delta,
                    "current_step_code": requested,
                    "slots": merged_slots,
                }
            else:
                logger.warning(
                    "Illegal state_after in thread %d (%s): %s",
                    thread.id, served_code, reason,
                )
                validation_events.append({
                    "current_step": current_step.code,
                    "requested_step": requested,
                    "reason": reason,
                })
                # Merge the slots anyway (the information is useful) but do not move the step.
                snapshot = {
                    "current_intent_code": served_code,
                    "current_step": snapshot["current_step"],
                    "current_step_code": current_step.code,
                    "slots": merged_slots,
                }
        elif parsed["state_update"] is None and current_step is not None and parse_error:
            logger.warning(
                "State machine branch %s returned no STATE_JSON: %s",
                served_code, parse_error,
            )
        break

    # 4. Persist: thread_state is written AFTER all the logic; everything is committed
    # as a single transaction.
    await thread_state_service.upsert(
        session,
        thread.id,
        intent_code=snapshot["current_intent_code"],
        step=snapshot["current_step"],
        step_code=snapshot.get("current_step_code"),
        slots=snapshot["slots"],
    )

    user_msg.intent_id = intent.id
    if thread.agent_config_id is None:
        thread.agent_config_id = active_cfg.id

    assistant_msg = Message(
        thread_id=thread.id,
        role="assistant",
        text=visible_text,
        sources_json=json.dumps(sources, ensure_ascii=False),
        assembled_prompt=last_assembled_prompt,
        intent_id=intent.id,
    )
    session.add(assistant_msg)
    thread.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(assistant_msg)
    await session.refresh(thread)

    logger.info(
        "Chat: thread=%d, router=%s, served=%s (v%d), step=%s, slots=%d keys, bounces=%d, validation_events=%d",
        thread.id, router_code, served_code, active_cfg.version,
        snapshot.get("current_step_code") or "-", len(snapshot["slots"]),
        len(bounce_log), len(validation_events),
    )

    return {
        "thread_id": thread.id,
        "thread_name": thread.name,
        "message_id": assistant_msg.id,
        "intent_code": intent.code,
        "intent_name": intent.name,
        "router_intent_code": router_code,
        "config_version": active_cfg.version,
        "router_version": router_version,
        "answer": visible_text,
        "sources": sources,
        "model_used": llm.model,
        "assembled_prompt": last_assembled_prompt,
        "thread_state": {
            "current_intent_code": snapshot["current_intent_code"],
            "current_step": snapshot["current_step"],
            "current_step_code": snapshot.get("current_step_code"),
            "slots": snapshot["slots"],
        },
        "bounces": bounce_log,
        "validation_events": validation_events,
        "parse_error": parse_error,
    }


async def list_threads(session: AsyncSession) -> list[dict]:
    """List all threads, most recently updated first.

    Each item carries the total message count and a 120-char preview of the
    thread's first user message (empty if there is none).
    """
    count_subq = (
        select(Message.thread_id, func.count(Message.id).label("cnt"))
        .group_by(Message.thread_id)
        .subquery()
    )
    # Lowest user-message id per thread = the first user message.
    first_msg_subq = (
        select(Message.thread_id, func.min(Message.id).label("first_id"))
        .where(Message.role == "user")
        .group_by(Message.thread_id)
        .subquery()
    )
    stmt = (
        select(
            Thread,
            func.coalesce(count_subq.c.cnt, 0).label("messages_count"),
            Message.text.label("first_text"),
        )
        .outerjoin(count_subq, count_subq.c.thread_id == Thread.id)
        .outerjoin(first_msg_subq, first_msg_subq.c.thread_id == Thread.id)
        .outerjoin(Message, Message.id == first_msg_subq.c.first_id)
        # id tiebreak keeps the order deterministic when updated_at values coincide.
        .order_by(Thread.updated_at.desc(), Thread.id.desc())
    )
    rows = (await session.execute(stmt)).all()

    return [
        {
            "id": thread.id,
            "name": thread.name,
            "created_at": thread.created_at.isoformat(),
            "updated_at": thread.updated_at.isoformat(),
            "messages_count": messages_count,
            "first_message_preview": _make_preview(first_text, 120),
        }
        for thread, messages_count, first_text in rows
    ]


async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | None:
    """Return a thread with all its messages and its current state snapshot, or None if not found.

    Messages are ordered chronologically. Each message is annotated with its intent
    code and name, which are empty strings when it has no intent. A message whose
    stored sources JSON cannot be parsed gets an empty ``sources`` list and a
    warning is logged.
    """
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None

    stmt = (
        select(Message, Intent.code, Intent.name)
        .outerjoin(Intent, Intent.id == Message.intent_id)
        .where(Message.thread_id == thread_id)
        # created_at may tie (e.g. both messages of one send_message call are written in
        # the same transaction), so use id as a tiebreaker — same as the history query
        # in send_message — to keep user/assistant turns in insertion order.
        .order_by(Message.created_at, Message.id)
    )
    rows = (await session.execute(stmt)).all()

    messages = []
    for m, intent_code, intent_name in rows:
        sources = []
        if m.sources_json:
            try:
                sources = json.loads(m.sources_json)
            except json.JSONDecodeError:
                logger.warning("Bad sources_json for message %d", m.id)
        messages.append({
            "id": m.id,
            "role": m.role,
            "text": m.text,
            "created_at": m.created_at.isoformat(),
            "sources": sources,
            "assembled_prompt": m.assembled_prompt or "",
            "intent_code": intent_code or "",
            "intent_name": intent_name or "",
        })

    state = await thread_state_service.load_snapshot(session, thread_id)

    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages": messages,
        "thread_state": state,
    }


async def rename_thread(session: AsyncSession, thread_id: int, name: str) -> dict | None:
    """Rename a thread, bump its ``updated_at`` and commit; return None if it does not exist.

    The returned dict has the same keys as a ``list_threads`` item, but
    ``messages_count`` and ``first_message_preview`` are placeholders (0 / "")
    rather than real values.
    """
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None
    thread.name = name
    thread.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(thread)
    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages_count": 0,
        "first_message_preview": "",
    }


async def delete_thread(session: AsyncSession, thread_id: int) -> int | None:
    """Delete a thread together with all its messages and commit.

    Returns the number of messages that were deleted, or None if the thread does not exist.
    """
    # NOTE(review): thread_state rows (see thread_state_service) are not removed here —
    # confirm the DB cascades them on thread deletion, otherwise they are orphaned.
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None
    count_stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
    messages_count = (await session.execute(count_stmt)).scalar_one() or 0
    await session.execute(delete(Message).where(Message.thread_id == thread_id))
    await session.delete(thread)
    await session.commit()
    return int(messages_count)