"""Chat thread service: routes a user message to an intent branch, answers it
with retrieval + LLM, and provides list/detail/rename/delete for threads."""

import json
import logging
from datetime import datetime, timezone

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from db.models import Intent, Message, Thread
from services import config_service, intent_service
from services.llm_client import LLMClient
from services.router_client import RouterClient
from services.vectorstore import VectorStoreService

logger = logging.getLogger(__name__)

HISTORY_LIMIT = 20  # most recent N messages of the thread that are sent to the LLM
FALLBACK_INTENT_CODE = "general_info"


def _preview(text: str | None, limit: int) -> str:
    """Flatten *text* to one line and cut it to at most *limit* characters.

    Surrounding whitespace is stripped and newlines become spaces. When the
    text is longer than *limit* it is cut, right-stripped and suffixed with
    an ellipsis. ``None`` is treated as an empty string.
    """
    flat = (text or "").strip().replace("\n", " ")
    if len(flat) > limit:
        flat = flat[:limit].rstrip() + "…"
    return flat


def _auto_thread_name(first_user_text: str) -> str:
    """Auto-generated thread name: first 60 characters of the first message plus a UTC timestamp."""
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
    return f"{_preview(first_user_text, 60)} · {stamp}"


def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
    """Convert raw retrieval hits into the source dicts stored with and returned for an answer.

    Each hit must have a ``"text"`` key (cut to 500 characters here). Missing
    or ``None`` ``"metadata"`` / ``"relevance_score"`` fall back to defaults
    instead of raising.
    """
    sources = []
    for item in retrieved:
        # `or {}` also covers an explicit None, which `.get(key, {})` would let through.
        meta = item.get("metadata") or {}
        sources.append({
            "document_id": meta.get("document_id", ""),
            "document_name": meta.get("document_name", ""),
            "chunk_text": item["text"][:500],
            "section": meta.get("section", ""),
            "page": meta.get("page_number", 0),
            "relevance_score": round(item.get("relevance_score") or 0, 3),
        })
    return sources


def _thread_summary(thread: Thread, messages_count: int, first_text: str | None) -> dict:
    """Build the list-item representation of a thread."""
    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages_count": messages_count,
        "first_message_preview": _preview(first_text, 120),
    }


async def _count_messages(session: AsyncSession, thread_id: int) -> int:
    """Return the number of messages stored for the thread."""
    stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
    return int((await session.execute(stmt)).scalar_one() or 0)


async def _load_history(
    session: AsyncSession, thread_id: int, exclude_message_id: int
) -> list[dict]:
    """Return up to HISTORY_LIMIT latest messages of the thread, oldest first.

    The message with ``exclude_message_id`` (the reply being processed) is left out.
    """
    stmt = (
        select(Message)
        .where(Message.thread_id == thread_id, Message.id != exclude_message_id)
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(HISTORY_LIMIT)
    )
    rows = (await session.execute(stmt)).scalars().all()
    # The query takes the newest rows first; reverse to chronological order.
    return [{"role": m.role, "content": m.text} for m in reversed(rows)]


async def _resolve_active_config(session: AsyncSession, intent_code: str):
    """Return the ``(intent, active_config)`` pair for *intent_code*, falling back to the general branch.

    Raises:
        RuntimeError: if even FALLBACK_INTENT_CODE has no active config.
    """
    pair = await config_service.get_active_config_by_intent_code(session, intent_code)
    if pair is not None:
        return pair

    # The branch is disabled or has no active config: fall back to general info.
    logger.warning(
        "Intent %r has no active config, falling back to %s",
        intent_code,
        FALLBACK_INTENT_CODE,
    )
    pair = await config_service.get_active_config_by_intent_code(session, FALLBACK_INTENT_CODE)
    if pair is None:
        # Even the fallback is missing: critical configuration error.
        raise RuntimeError(f"No active config for fallback intent {FALLBACK_INTENT_CODE!r}")
    return pair


async def send_message(
    session: AsyncSession,
    vectorstore: VectorStoreService,
    llm: LLMClient,
    router: RouterClient,
    text: str,
    thread_id: int | None = None,
    top_k: int = 5,
    temperature: float | None = None,
    max_tokens: int | None = None,
) -> dict:
    """Add the user's message to a thread, route it, and obtain the assistant's answer.

    A new thread (auto-named from *text*) is created when *thread_id* is None.

    Returns:
        A dict with thread id/name, the assistant message id, intent code/name,
        config and router versions, the answer text, sources, the model name
        and the assembled prompt.

    Raises:
        LookupError: *thread_id* was given but no such thread exists.
        RuntimeError: neither the routed intent nor the fallback intent has an
            active config.
    """
    if thread_id is None:
        thread = Thread(name=_auto_thread_name(text))
        session.add(thread)
        await session.flush()
    else:
        thread = await session.get(Thread, thread_id)
        if thread is None:
            raise LookupError(f"Thread {thread_id} not found")

    # Store the user's message before calling the LLM so it is part of the history.
    # NOTE(review): this is only flushed, not committed; whether it survives an
    # error further down depends on how the caller handles the session - confirm.
    user_msg = Message(thread_id=thread.id, role="user", text=text)
    session.add(user_msg)
    await session.flush()
    user_msg_id = user_msg.id

    # History for classification and for the LLM: thread messages before the new one.
    history = await _load_history(session, thread.id, user_msg_id)

    # 1. The router picks the branch.
    routing = await router.classify(session=session, history=history, text=text)
    intent_code = routing["code"]
    router_version = routing.get("version")

    intent, active_cfg = await _resolve_active_config(session, intent_code)
    system_prompt = config_service.compose_full_system_prompt(active_cfg)

    user_msg.intent_id = intent.id
    if thread.agent_config_id is None:
        thread.agent_config_id = active_cfg.id

    # 2. Retrieval + request to the branch.
    # NOTE(review): query() is called synchronously inside a coroutine; if it does
    # real I/O or heavy compute it will hold up the event loop - confirm.
    retrieved = vectorstore.query(query_text=text, top_k=top_k)
    sources = _retrieved_to_sources(retrieved)

    llm_result = await llm.chat(
        question=text,
        sources=retrieved,
        history=history,
        system_prompt=system_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    assistant_msg = Message(
        thread_id=thread.id,
        role="assistant",
        text=llm_result["text"],
        sources_json=json.dumps(sources, ensure_ascii=False),
        assembled_prompt=llm_result["assembled_prompt"],
        intent_id=intent.id,
    )
    session.add(assistant_msg)
    thread.updated_at = datetime.now(timezone.utc)

    # Read these before commit: with expire_on_commit enabled a commit expires
    # every loaded attribute, and touching one afterwards would need implicit IO,
    # which AsyncSession does not allow.
    resolved_code = intent.code
    resolved_name = intent.name
    config_version = active_cfg.version

    await session.commit()
    await session.refresh(assistant_msg)
    await session.refresh(thread)

    logger.info(
        "Chat: thread=%d, intent=%s (v%d), user_msg=%d, assistant_msg=%d, sources=%d",
        thread.id,
        resolved_code,
        config_version,
        user_msg_id,
        assistant_msg.id,
        len(sources),
    )

    return {
        "thread_id": thread.id,
        "thread_name": thread.name,
        "message_id": assistant_msg.id,
        "intent_code": resolved_code,
        "intent_name": resolved_name,
        "config_version": config_version,
        "router_version": router_version,
        "answer": llm_result["text"],
        "sources": sources,
        "model_used": llm.model,
        "assembled_prompt": llm_result["assembled_prompt"],
    }


async def list_threads(session: AsyncSession) -> list[dict]:
    """List all threads, newest update first, with message count and a preview of the first user message."""
    count_subq = (
        select(Message.thread_id, func.count(Message.id).label("cnt"))
        .group_by(Message.thread_id)
        .subquery()
    )
    # Lowest message id with role "user" per thread, i.e. the first user message.
    first_msg_subq = (
        select(Message.thread_id, func.min(Message.id).label("first_id"))
        .where(Message.role == "user")
        .group_by(Message.thread_id)
        .subquery()
    )
    stmt = (
        select(
            Thread,
            func.coalesce(count_subq.c.cnt, 0).label("messages_count"),
            Message.text.label("first_text"),
        )
        # Outer joins keep threads that have no messages at all.
        .outerjoin(count_subq, count_subq.c.thread_id == Thread.id)
        .outerjoin(first_msg_subq, first_msg_subq.c.thread_id == Thread.id)
        .outerjoin(Message, Message.id == first_msg_subq.c.first_id)
        .order_by(Thread.updated_at.desc())
    )
    rows = (await session.execute(stmt)).all()
    return [
        _thread_summary(thread, messages_count, first_text)
        for thread, messages_count, first_text in rows
    ]


async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | None:
    """Return the thread with all its messages in chronological order, or None if it does not exist."""
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None

    stmt = (
        select(Message, Intent.code, Intent.name)
        .outerjoin(Intent, Intent.id == Message.intent_id)
        .where(Message.thread_id == thread_id)
        # The id tiebreaker keeps the order deterministic when two messages share
        # a timestamp; it mirrors the ordering used for the LLM history.
        .order_by(Message.created_at, Message.id)
    )
    rows = (await session.execute(stmt)).all()

    messages = []
    for m, intent_code, intent_name in rows:
        sources = []
        if m.sources_json:
            try:
                sources = json.loads(m.sources_json)
            except json.JSONDecodeError:
                # Best effort: a corrupt payload yields an empty source list.
                logger.warning("Bad sources_json for message %d", m.id)
        messages.append({
            "id": m.id,
            "role": m.role,
            "text": m.text,
            "created_at": m.created_at.isoformat(),
            "sources": sources,
            "assembled_prompt": m.assembled_prompt or "",
            "intent_code": intent_code or "",
            "intent_name": intent_name or "",
        })

    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages": messages,
    }


async def rename_thread(session: AsyncSession, thread_id: int, name: str) -> dict | None:
    """Rename the thread and return its summary, or None if it does not exist.

    The summary has the same keys as the items of ``list_threads`` and carries
    the actual message count and first-user-message preview.
    """
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None

    thread.name = name
    thread.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(thread)

    messages_count = await _count_messages(session, thread_id)
    # Same "first user message" rule as list_threads: lowest id with role "user".
    first_stmt = (
        select(Message.text)
        .where(Message.thread_id == thread_id, Message.role == "user")
        .order_by(Message.id)
        .limit(1)
    )
    first_text = (await session.execute(first_stmt)).scalar_one_or_none()
    return _thread_summary(thread, messages_count, first_text)


async def delete_thread(session: AsyncSession, thread_id: int) -> int | None:
    """Delete the thread and all its messages.

    Returns the number of deleted messages, or None if the thread does not exist.
    """
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None

    # Count first: after the bulk delete there is nothing left to count.
    messages_count = await _count_messages(session, thread_id)
    await session.execute(delete(Message).where(Message.thread_id == thread_id))
    await session.delete(thread)
    await session.commit()
    return messages_count