import json
import logging
from datetime import datetime, timezone

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from db.models import Message, Thread
from services.llm_client import LLMClient
from services.vectorstore import VectorStoreService

logger = logging.getLogger(__name__)

HISTORY_LIMIT = 20  # the thread's last N messages that are sent to the LLM

# Preview lengths (characters) for auto-generated thread names and thread-list previews.
_THREAD_NAME_PREVIEW_CHARS = 60
_LIST_PREVIEW_CHARS = 120


def _preview(text: str, limit: int) -> str:
    """Return *text* flattened to one line and truncated to *limit* characters.

    Every whitespace run (newlines, carriage returns, tabs, repeated spaces)
    collapses to a single space, so previews never contain line breaks.
    Truncated text ends with "…".
    """
    flat = " ".join(text.split())
    if len(flat) > limit:
        flat = flat[:limit].rstrip() + "…"
    return flat


def _auto_thread_name(first_user_text: str) -> str:
    """Build an automatic thread name: first-message preview plus a UTC timestamp."""
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
    return f"{_preview(first_user_text, _THREAD_NAME_PREVIEW_CHARS)} · {stamp}"


def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
    """Convert vector-store hits into the source dicts returned to clients and stored.

    Each hit must have a ``text`` key; ``metadata`` and ``relevance_score`` are
    optional, and a missing *or None* value falls back to an empty default
    instead of raising.
    """
    sources = []
    for item in retrieved:
        meta = item.get("metadata") or {}
        score = item.get("relevance_score")
        sources.append({
            "document_id": meta.get("document_id", ""),
            "document_name": meta.get("document_name", ""),
            "chunk_text": item["text"][:500],
            "section": meta.get("section", ""),
            "page": meta.get("page_number", 0),
            "relevance_score": round(score if score is not None else 0, 3),
        })
    return sources


def _thread_summary(thread: Thread, messages_count: int, first_text: str | None) -> dict:
    """Serialize a thread into the summary shape shared by list and rename responses."""
    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages_count": messages_count,
        "first_message_preview": _preview(first_text or "", _LIST_PREVIEW_CHARS),
    }


async def _count_messages(session: AsyncSession, thread_id: int) -> int:
    """Return the number of messages stored in the given thread."""
    stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
    return int((await session.execute(stmt)).scalar_one() or 0)


async def _thread_stats(session: AsyncSession, thread_id: int) -> tuple[int, str | None]:
    """Return (message count, text of the first user message or None) for one thread.

    "First" means the user message with the smallest id, matching list_threads.
    """
    messages_count = await _count_messages(session, thread_id)
    first_stmt = (
        select(Message.text)
        .where(Message.thread_id == thread_id, Message.role == "user")
        .order_by(Message.id)
        .limit(1)
    )
    first_text = (await session.execute(first_stmt)).scalar_one_or_none()
    return messages_count, first_text


async def send_message(
    session: AsyncSession,
    vectorstore: VectorStoreService,
    llm: LLMClient,
    text: str,
    thread_id: int | None = None,
    top_k: int = 5,
    temperature: float | None = None,
    max_tokens: int | None = None,
) -> dict:
    """Add the patient's message to a thread, get the assistant's reply, and persist both.

    Args:
        session: Database session; this function commits it (twice, see below).
        vectorstore: Retrieval backend queried with the patient's text.
        llm: Chat client producing the answer.
        text: The patient's message.
        thread_id: Existing thread to append to; ``None`` creates a new thread
            named after *text*.
        top_k: Number of chunks to retrieve.
        temperature, max_tokens: Passed through to ``llm.chat``.

    Returns:
        Dict with thread id/name, the assistant message id, answer text, sources,
        model name and the assembled prompt.

    Raises:
        LookupError: *thread_id* was given but no such thread exists.

    The patient's message (and a newly created thread) is committed BEFORE the
    vector-store/LLM calls, so it stays in the history even if those fail.
    """
    if thread_id is None:
        thread = Thread(name=_auto_thread_name(text))
        session.add(thread)
        await session.flush()
    else:
        thread = await session.get(Thread, thread_id)
        if thread is None:
            raise LookupError(f"Thread {thread_id} not found")

    current_thread_id = thread.id
    user_msg = Message(thread_id=current_thread_id, role="user", text=text)
    session.add(user_msg)
    await session.flush()
    user_msg_id = user_msg.id
    # Commit (a flush alone would be rolled back on error) so the patient's turn
    # survives a failure below. Ids were captured above because, with
    # expire_on_commit=True (SQLAlchemy's default), reading attributes of the
    # expired instances here would be implicit IO, which AsyncSession forbids.
    await session.commit()

    # NOTE(review): query() is called synchronously and blocks the event loop
    # while it runs; consider offloading it if the store is thread-safe.
    retrieved = vectorstore.query(query_text=text, top_k=top_k)
    sources = _retrieved_to_sources(retrieved)

    # LLM history: the newest HISTORY_LIMIT messages of the thread, excluding the
    # message just added (it is passed separately as `question`).
    stmt = (
        select(Message)
        .where(Message.thread_id == current_thread_id, Message.id != user_msg_id)
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(HISTORY_LIMIT)
    )
    rows = (await session.execute(stmt)).scalars().all()
    history = [{"role": m.role, "content": m.text} for m in reversed(rows)]

    llm_result = await llm.chat(
        question=text,
        sources=retrieved,
        history=history,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    assistant_msg = Message(
        thread_id=current_thread_id,
        role="assistant",
        text=llm_result["text"],
        sources_json=json.dumps(sources, ensure_ascii=False),
        assembled_prompt=llm_result["assembled_prompt"],
    )
    session.add(assistant_msg)
    # Assigning a plain column attribute does not load the expired instance.
    thread.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(assistant_msg)
    await session.refresh(thread)

    logger.info("Chat: thread=%d, user_msg=%d, assistant_msg=%d, sources=%d",
                current_thread_id, user_msg_id, assistant_msg.id, len(sources))

    return {
        "thread_id": thread.id,
        "thread_name": thread.name,
        "message_id": assistant_msg.id,
        "answer": llm_result["text"],
        "sources": sources,
        "model_used": llm.model,
        "assembled_prompt": llm_result["assembled_prompt"],
    }


async def list_threads(session: AsyncSession) -> list[dict]:
    """List all threads, most recently updated first, with message count and first-message preview."""
    count_subq = (
        select(Message.thread_id, func.count(Message.id).label("cnt"))
        .group_by(Message.thread_id)
        .subquery()
    )
    # The thread's first user message is the one with the smallest id.
    first_msg_subq = (
        select(Message.thread_id, func.min(Message.id).label("first_id"))
        .where(Message.role == "user")
        .group_by(Message.thread_id)
        .subquery()
    )
    stmt = (
        select(
            Thread,
            func.coalesce(count_subq.c.cnt, 0).label("messages_count"),
            Message.text.label("first_text"),
        )
        .outerjoin(count_subq, count_subq.c.thread_id == Thread.id)
        .outerjoin(first_msg_subq, first_msg_subq.c.thread_id == Thread.id)
        .outerjoin(Message, Message.id == first_msg_subq.c.first_id)
        # id as a tie-breaker keeps the order stable when updated_at values coincide
        .order_by(Thread.updated_at.desc(), Thread.id.desc())
    )
    rows = (await session.execute(stmt)).all()
    return [
        _thread_summary(thread, messages_count, first_text)
        for thread, messages_count, first_text in rows
    ]


async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | None:
    """Return a thread with all its messages in chronological order, or None if absent.

    A message whose stored ``sources_json`` cannot be parsed is returned with an
    empty source list and a warning is logged.
    """
    stmt = select(Thread).where(Thread.id == thread_id).options(selectinload(Thread.messages))
    thread = (await session.execute(stmt)).scalar_one_or_none()
    if thread is None:
        return None

    # Order explicitly (same key as the LLM-history query) rather than relying on
    # however the relationship is configured.
    ordered = sorted(thread.messages, key=lambda m: (m.created_at, m.id))
    messages = []
    for m in ordered:
        sources = []
        if m.sources_json:
            try:
                sources = json.loads(m.sources_json)
            except json.JSONDecodeError:
                logger.warning("Bad sources_json for message %d", m.id)
        messages.append({
            "id": m.id,
            "role": m.role,
            "text": m.text,
            "created_at": m.created_at.isoformat(),
            "sources": sources,
            "assembled_prompt": m.assembled_prompt or "",
        })
    return {
        "id": thread.id,
        "name": thread.name,
        "created_at": thread.created_at.isoformat(),
        "updated_at": thread.updated_at.isoformat(),
        "messages": messages,
    }


async def rename_thread(session: AsyncSession, thread_id: int, name: str) -> dict | None:
    """Rename a thread and bump its updated_at.

    Returns the same summary shape as list_threads (with real message count and
    preview), or None if the thread does not exist.
    """
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None
    thread.name = name
    thread.updated_at = datetime.now(timezone.utc)
    await session.commit()
    await session.refresh(thread)
    messages_count, first_text = await _thread_stats(session, thread_id)
    return _thread_summary(thread, messages_count, first_text)


async def delete_thread(session: AsyncSession, thread_id: int) -> int | None:
    """Delete a thread and all its messages.

    Returns the number of deleted messages, or None if the thread does not exist.
    """
    thread = await session.get(Thread, thread_id)
    if thread is None:
        return None
    messages_count = await _count_messages(session, thread_id)
    # Messages are bulk-deleted before the thread row, presumably so no message
    # still references the thread when it is removed.
    await session.execute(delete(Message).where(Message.thread_id == thread_id))
    await session.delete(thread)
    await session.commit()
    return messages_count