import json
import logging
from datetime import datetime, timezone

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from db.models import Message, Thread
from services import config_service
from services.llm_client import LLMClient
from services.vectorstore import VectorStoreService

logger = logging.getLogger(__name__)

HISTORY_LIMIT = 20  # how many of the thread's most recent messages are sent to the LLM


def _preview(text: str | None, limit: int) -> str:
    """Flatten *text* to a single line and truncate it to *limit* characters.

    Newlines become spaces; when truncated, trailing whitespace is trimmed and
    "…" is appended. ``None`` is treated as an empty string.
    """
    flat = (text or "").strip().replace("\n", " ")
    if len(flat) > limit:
        flat = flat[:limit].rstrip() + "…"
    return flat


def _auto_thread_name(first_user_text: str) -> str:
    """Build an automatic thread name: first 60 chars of the opening message + UTC timestamp."""
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
    return f"{_preview(first_user_text, 60)} · {stamp}"


def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
    """Convert raw vector-store hits into the source dicts returned to clients and stored on messages.

    Each hit must carry ``text``; ``metadata`` and ``relevance_score`` are
    optional and may also be present with a ``None`` value. Chunk text is
    truncated to 500 characters and the score is rounded to 3 decimals.
    """
    sources = []
    for item in retrieved:
        # `metadata` may be absent or explicitly None; both mean "no metadata".
        meta = item.get("metadata") or {}
        score = item.get("relevance_score")
        sources.append({
            "document_id": meta.get("document_id", ""),
            "document_name": meta.get("document_name", ""),
            "chunk_text": item["text"][:500],
            "section": meta.get("section", ""),
            "page": meta.get("page_number", 0),
            "relevance_score": round(score if score is not None else 0, 3),
        })
    return sources


async def send_message(
    session: AsyncSession,
    vectorstore: VectorStoreService,
    llm: LLMClient,
    text: str,
    thread_id: int | None = None,
    top_k: int = 5,
    temperature: float | None = None,
    max_tokens: int | None = None,
) -> dict:
    """Append a patient message to a thread, get the assistant's reply, and persist both.

    When *thread_id* is None, a new thread is created, auto-named from *text* and
    linked to the active agent config (if any). The user message is committed
    before retrieval and the LLM call, so it stays in the thread even if either
    fails. The assistant message is committed afterwards.

    Returns:
        dict with ``thread_id``, ``thread_name``, ``message_id`` (the assistant
        message), ``answer``, ``sources``, ``model_used`` and ``assembled_prompt``.

    Raises:
        LookupError: if *thread_id* is given but no such thread exists.
    """
    active_cfg = await config_service.get_active_config(session)
    system_prompt = (
        config_service.compose_full_system_prompt(active_cfg) if active_cfg else None
    )

    if thread_id is None:
        thread = Thread(
            name=_auto_thread_name(text),
            agent_config_id=active_cfg.id if active_cfg else None,
        )
        session.add(thread)
        await session.flush()
    else:
        thread = await session.get(Thread, thread_id)
        if thread is None:
            raise LookupError(f"Thread {thread_id} not found")

    # Persist the patient's message *before* calling the LLM so it stays in the
    # thread history even on error. A flush alone is not enough: an exception
    # would roll the still-uncommitted transaction back.
    user_msg = Message(thread_id=thread.id, role="user", text=text)
    session.add(user_msg)
    await session.flush()
    # Capture ids before committing: commit may expire ORM instances
    # (expire_on_commit), and reading an expired attribute outside an awaited
    # call triggers implicit IO, which fails under asyncio.
    current_thread_id = thread.id
    user_msg_id = user_msg.id
    await session.commit()
    await session.refresh(thread)

    # NOTE(review): vectorstore.query is a synchronous call, so it blocks the
    # event loop while it runs — consider offloading it if searches are slow.
    retrieved = vectorstore.query(query_text=text, top_k=top_k)
    sources = _retrieved_to_sources(retrieved)

    # History for the LLM: the latest HISTORY_LIMIT messages of the thread,
    # excluding the user message just added (it is passed as `question`).
    # Fetched newest-first so LIMIT keeps the most recent ones, then reversed
    # into chronological order.
    stmt = (
        select(Message)
        .where(Message.thread_id == current_thread_id, Message.id != user_msg_id)
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(HISTORY_LIMIT)
    )
    rows = (await session.execute(stmt)).scalars().all()
    history = [{"role": m.role, "content": m.text} for m in reversed(rows)]

    llm_result = await llm.chat(
        question=text,
        sources=retrieved,
        history=history,
        system_prompt=system_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    assistant_msg = Message(
        thread_id=current_thread_id,
        role="assistant",
        text=llm_result["text"],
        sources_json=json.dumps(sources, ensure_ascii=False),
        assembled_prompt=llm_result["assembled_prompt"],
    )
    session.add(assistant_msg)
    thread.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(assistant_msg)
    await session.refresh(thread)

    logger.info(
        "Chat: thread=%d, user_msg=%d, assistant_msg=%d, sources=%d",
        thread.id, user_msg_id, assistant_msg.id, len(sources),
    )

    return {
        "thread_id": thread.id,
        "thread_name": thread.name,
        "message_id": assistant_msg.id,
        "answer": llm_result["text"],
        "sources": sources,
        "model_used": llm.model,
        "assembled_prompt": llm_result["assembled_prompt"],
    }


def _thread_summary(thread: Thread, messages_count: int, first_text: str | None) -> dict:
    """Serialize a thread in the list-item shape (id, name, timestamps, count, preview)."""
    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages_count": messages_count,
        "first_message_preview": _preview(first_text, 120),
    }


async def _thread_stats(session: AsyncSession, thread_id: int) -> tuple[int, str | None]:
    """Return (message count, text of the first user message or None) for a single thread."""
    count_stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
    messages_count = (await session.execute(count_stmt)).scalar_one() or 0
    # "First" = lowest id among user messages, matching list_threads.
    first_stmt = (
        select(Message.text)
        .where(Message.thread_id == thread_id, Message.role == "user")
        .order_by(Message.id)
        .limit(1)
    )
    first_text = (await session.execute(first_stmt)).scalar_one_or_none()
    return int(messages_count), first_text


async def list_threads(session: AsyncSession) -> list[dict]:
    """List all threads, most recently updated first, with message count and first-message preview.

    The preview is the first user message (lowest id), flattened to one line
    and truncated to 120 characters. Threads without messages get a count of 0
    and an empty preview.
    """
    count_subq = (
        select(Message.thread_id, func.count(Message.id).label("cnt"))
        .group_by(Message.thread_id)
        .subquery()
    )
    first_msg_subq = (
        select(Message.thread_id, func.min(Message.id).label("first_id"))
        .where(Message.role == "user")
        .group_by(Message.thread_id)
        .subquery()
    )
    stmt = (
        select(
            Thread,
            func.coalesce(count_subq.c.cnt, 0).label("messages_count"),
            Message.text.label("first_text"),
        )
        .outerjoin(count_subq, count_subq.c.thread_id == Thread.id)
        .outerjoin(first_msg_subq, first_msg_subq.c.thread_id == Thread.id)
        .outerjoin(Message, Message.id == first_msg_subq.c.first_id)
        # Thread.id breaks ties so the order is deterministic.
        .order_by(Thread.updated_at.desc(), Thread.id.desc())
    )
    rows = (await session.execute(stmt)).all()
    return [
        _thread_summary(thread, messages_count, first_text)
        for thread, messages_count, first_text in rows
    ]


async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | None:
    """Return a thread with all its messages, or None if the thread does not exist.

    Messages come in the order of the ``Thread.messages`` relationship. Each
    message's ``sources_json`` is decoded. Malformed JSON is logged and yields
    an empty source list instead of failing the whole request.
    """
    stmt = select(Thread).where(Thread.id == thread_id).options(selectinload(Thread.messages))
    thread = (await session.execute(stmt)).scalar_one_or_none()
    if thread is None:
        return None

    messages = []
    for m in thread.messages:
        sources = []
        if m.sources_json:
            try:
                sources = json.loads(m.sources_json)
            except json.JSONDecodeError:
                logger.warning("Bad sources_json for message %d", m.id)
        messages.append({
            "id": m.id,
            "role": m.role,
            "text": m.text,
            "created_at": m.created_at.isoformat(),
            "sources": sources,
            "assembled_prompt": m.assembled_prompt or "",
        })

    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages": messages,
    }


async def rename_thread(session: AsyncSession, thread_id: int, name: str) -> dict | None:
    """Rename a thread and bump its ``updated_at``.

    Returns the thread in the same shape as a ``list_threads`` item, with the
    real message count and first-message preview, or None if the thread does
    not exist.
    """
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None
    thread.name = name
    thread.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(thread)

    messages_count, first_text = await _thread_stats(session, thread_id)
    return _thread_summary(thread, messages_count, first_text)


async def delete_thread(session: AsyncSession, thread_id: int) -> int | None:
    """Delete a thread and all its messages.

    Returns the number of deleted messages, or None if the thread does not exist.
    """
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None
    count_stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
    messages_count = (await session.execute(count_stmt)).scalar_one() or 0
    # Messages are deleted explicitly rather than relying on ORM/DB cascades.
    await session.execute(delete(Message).where(Message.thread_id == thread_id))
    await session.delete(thread)
    await session.commit()
    return int(messages_count)