From 3c2657ab99397acc1a79402c8c73fc88f1ac0f81 Mon Sep 17 00:00:00 2001
From: AR 15 M4
Date: Thu, 23 Apr 2026 10:11:59 +0500
Subject: [PATCH] =?UTF-8?q?feat(sprint2):=20=D0=B4=D0=B8=D0=B0=D0=BB=D0=BE?=
 =?UTF-8?q?=D0=B3=20=D1=81=20=D0=BF=D0=B0=D0=BC=D1=8F=D1=82=D1=8C=D1=8E=20?=
 =?UTF-8?q?=D1=82=D1=80=D0=B5=D0=B4=D0=B0=20=E2=80=94=20POST=20/chat=20+?=
 =?UTF-8?q?=20CRUD=20=D1=82=D1=80=D0=B5=D0=B4=D0=BE=D0=B2?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Второй кусок Спринта 2: агент теперь помнит контекст. RAG-retrieval
делается по последней реплике пациента, в LLM уходит системный промпт +
последние 20 сообщений треда + новая реплика + найденные фрагменты.

Backend:
- services/chat_service: send_message — создаёт тред при необходимости
  (auto-имя из первой реплики + UTC-дата), сохраняет user-реплику до
  вызова LLM (чтобы не потерять при сбое), делает retrieval, грузит
  историю треда (desc/limit 20 → reversed для хронологии), зовёт
  llm.chat, сохраняет ответ ассистента вместе с sources_json и
  assembled_prompt, обновляет thread.updated_at. Плюс list_threads с
  JOIN-выборкой превью первой реплики и счётчика сообщений,
  get_thread_detail через selectinload, rename_thread, delete_thread
  (CASCADE на FK делает уборку сообщений автоматически, но explicit
  delete оставлен для подсчёта удалённых).
- services/llm_client.chat: принимает history=[{role, content}, ...],
  собирает messages = [system, ...history, user-с-RAG];
  assembled_prompt дампит всю цепочку в виде
  [SYSTEM]/[USER]/[ASSISTANT]-блоков для отображения в Debug UI.
- routers/chat: POST /chat, обрабатывает LookupError → 404.
- routers/threads: GET /threads, GET /threads/{id}, PATCH /threads/{id}
  (переименовать), DELETE /threads/{id}.
- models: ChatRequest, ThreadRenameRequest; ChatResponse, ThreadInfo,
  ThreadListResponse, ThreadDetailResponse, MessageInfo,
  ThreadDeleteResponse.

Запуск:
- В lifespan main.py: автоматический alembic upgrade head через
  asyncio.to_thread (сам alembic делает asyncio.run внутри, его нельзя
  звать из уже работающего event loop). LLMClient инициализируется один
  раз при старте — вместо создания на каждый запрос.

E2E проверено: новый тред → агент отвечает и просит представиться;
вторая реплика в том же треде — агент помнит контекст; PATCH
переименовывает; DELETE удаляет тред с каскадом на сообщения.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 main.py                  |  23 +++-
 models/requests.py       |  12 +++
 models/responses.py      |  46 ++++++++
 routers/chat.py          |  48 +++++++++
 routers/threads.py       |  73 +++++++++++++
 services/chat_service.py | 220 +++++++++++++++++++++++++++++++++++++++
 services/llm_client.py   |  70 +++++++++++++
 7 files changed, 490 insertions(+), 2 deletions(-)
 create mode 100644 routers/chat.py
 create mode 100644 routers/threads.py
 create mode 100644 services/chat_service.py

diff --git a/main.py b/main.py
index 11ada98..4ed41b5 100644
--- a/main.py
+++ b/main.py
@@ -1,24 +1,39 @@
+import asyncio
 import logging
+import os
 from contextlib import asynccontextmanager
 
+from alembic import command
+from alembic.config import Config as AlembicConfig
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 
 from config import settings
 from services.embeddings import EmbeddingService
+from services.llm_client import LLMClient
 from services.vectorstore import VectorStoreService
 
 logger = logging.getLogger(__name__)
 
 embedding_service: EmbeddingService | None = None
 vectorstore_service: VectorStoreService | None = None
+llm_client: LLMClient | None = None
+
+
+def _run_migrations() -> None:
+    """Upgrade the DB schema to the latest Alembic revision at startup."""
+    # dirname is "" for a bare filename and os.makedirs("") raises; "." is a no-op.
+    os.makedirs(os.path.dirname(settings.sqlite_path) or ".", exist_ok=True)
+    command.upgrade(AlembicConfig("alembic.ini"), "head")
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    global embedding_service, vectorstore_service
+    global embedding_service, vectorstore_service, llm_client
     logging.basicConfig(level=getattr(logging, settings.log_level.upper(), logging.INFO))
+    logger.info("Running DB migrations…")
+    await asyncio.to_thread(_run_migrations)
     logger.info("Loading embedding model: %s", settings.embedding_model)
     embedding_service = EmbeddingService(settings.embedding_model)
     logger.info("Embedding model loaded")
@@ -27,6 +42,8 @@ async def lifespan(app: FastAPI):
         embedding_service=embedding_service,
     )
     logger.info("ChromaDB initialized at %s", settings.chroma_persist_dir)
+    llm_client = LLMClient()
+    logger.info("LLM client ready (model=%s)", llm_client.model)
     yield
     logger.info("Shutting down")
 
@@ -46,10 +63,12 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-from routers import documents, health, query  # noqa: E402
+from routers import chat, documents, health, query, threads  # noqa: E402
 
 app.include_router(health.router)
 app.include_router(documents.router)
 app.include_router(query.router)
+app.include_router(chat.router)
+app.include_router(threads.router)
 
 app.mount("/", StaticFiles(directory="static", html=True), name="static")
diff --git a/models/requests.py b/models/requests.py
index afcc5d4..f4333e0 100644
--- a/models/requests.py
+++ b/models/requests.py
@@ -7,3 +7,15 @@ class QueryRequest(BaseModel):
     document_ids: list[str] | None = Field(None, description="Ограничить поиск конкретными документами")
     temperature: float | None = Field(None, ge=0.0, le=2.0)
     max_tokens: int | None = Field(None, ge=100, le=8000)
+
+
+class ChatRequest(BaseModel):
+    text: str = Field(..., min_length=1, description="Реплика пациента")
+    thread_id: int | None = Field(None, description="ID треда; если не передан — создаётся новый")
+    top_k: int = Field(5, ge=1, le=20)
+    temperature: float | None = Field(None, ge=0.0, le=2.0)
+    max_tokens: int | None = Field(None, ge=100, le=8000)
+
+
+class ThreadRenameRequest(BaseModel):
+    name: str = Field(..., min_length=1, max_length=200)
diff --git a/models/responses.py b/models/responses.py
index 5fc2f12..8bdc166 100644
--- a/models/responses.py
+++ b/models/responses.py
@@ -77,3 +77,49 @@ class HealthResponse(BaseModel):
     embedding_model: str
     documents_count: int
     chunks_count: int
+
+
+class MessageInfo(BaseModel):
+    id: int
+    role: str
+    text: str
+    created_at: str
+    sources: list[SourceInfo] = Field(default_factory=list)
+    assembled_prompt: str = ""
+
+
+class ThreadInfo(BaseModel):
+    id: int
+    name: str
+    created_at: str
+    updated_at: str
+    messages_count: int
+    first_message_preview: str = ""
+
+
+class ThreadListResponse(BaseModel):
+    threads: list[ThreadInfo]
+    total: int
+
+
+class ThreadDetailResponse(BaseModel):
+    id: int
+    name: str
+    created_at: str
+    updated_at: str
+    messages: list[MessageInfo] = Field(default_factory=list)
+
+
+class ChatResponse(BaseModel):
+    thread_id: int
+    thread_name: str
+    message_id: int
+    answer: str
+    sources: list[SourceInfo]
+    model_used: str
+    assembled_prompt: str = ""
+
+
+class ThreadDeleteResponse(BaseModel):
+    ok: bool = True
+    deleted_messages: int
diff --git a/routers/chat.py b/routers/chat.py
new file mode 100644
index 0000000..5e228e1
--- /dev/null
+++ b/routers/chat.py
@@ -0,0 +1,53 @@
+import logging
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from db.session import get_session
+from models.requests import ChatRequest
+from models.responses import ChatResponse, SourceInfo
+from services import chat_service
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/chat", tags=["chat"])
+
+
+# POST /chat: append the patient's message to a thread and return the assistant's reply.
+@router.post("", response_model=ChatResponse)
+async def chat(req: ChatRequest, session: AsyncSession = Depends(get_session)):
+    # Imported lazily: main.py imports this router module, and the services are
+    # only assigned inside its lifespan handler.
+    from main import llm_client, vectorstore_service
+
+    if vectorstore_service is None or llm_client is None:
+        raise HTTPException(status_code=503, detail="Service not ready")
+
+    try:
+        result = await chat_service.send_message(
+            session=session,
+            vectorstore=vectorstore_service,
+            llm=llm_client,
+            text=req.text,
+            thread_id=req.thread_id,
+            top_k=req.top_k,
+            temperature=req.temperature,
+            max_tokens=req.max_tokens,
+        )
+    except LookupError as e:
+        # send_message raises LookupError for an unknown thread_id.
+        raise HTTPException(status_code=404, detail=str(e)) from e
+    except Exception as e:
+        logger.exception("Chat failed")
+        # NOTE(review): the detail echoes the raw exception text to the client; confirm this is intended.
+        raise HTTPException(status_code=500, detail=f"Chat error: {e}") from e
+
+    return ChatResponse(
+        thread_id=result["thread_id"],
+        thread_name=result["thread_name"],
+        message_id=result["message_id"],
+        answer=result["answer"],
+        sources=[SourceInfo(**s) for s in result["sources"]],
+        model_used=result["model_used"],
+        assembled_prompt=result["assembled_prompt"],
+    )
diff --git a/routers/threads.py b/routers/threads.py
new file mode 100644
index 0000000..9e95aaa
--- /dev/null
+++ b/routers/threads.py
@@ -0,0 +1,73 @@
+import logging
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from db.session import get_session
+from models.requests import ThreadRenameRequest
+from models.responses import (
+    MessageInfo,
+    SourceInfo,
+    ThreadDeleteResponse,
+    ThreadDetailResponse,
+    ThreadInfo,
+    ThreadListResponse,
+)
+from services import chat_service
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/threads", tags=["threads"])
+
+
+@router.get("", response_model=ThreadListResponse)
+async def list_threads(session: AsyncSession = Depends(get_session)):
+    threads = await chat_service.list_threads(session)
+    return ThreadListResponse(
+        threads=[ThreadInfo(**t) for t in threads],
+        total=len(threads),
+    )
+
+
+@router.get("/{thread_id}", response_model=ThreadDetailResponse)
+async def get_thread(thread_id: int, session: AsyncSession = Depends(get_session)):
+    data = await chat_service.get_thread_detail(session, thread_id)
+    if data is None:
+        raise HTTPException(status_code=404, detail="Thread not found")
+    return ThreadDetailResponse(
+        id=data["id"],
+        name=data["name"],
+        created_at=data["created_at"],
+        updated_at=data["updated_at"],
+        messages=[
+            MessageInfo(
+                id=m["id"],
+                role=m["role"],
+                text=m["text"],
+                created_at=m["created_at"],
+                sources=[SourceInfo(**s) for s in m["sources"]],
+                assembled_prompt=m["assembled_prompt"],
+            )
+            for m in data["messages"]
+        ],
+    )
+
+
+@router.patch("/{thread_id}", response_model=ThreadInfo)
+async def rename_thread(
+    thread_id: int,
+    req: ThreadRenameRequest,
+    session: AsyncSession = Depends(get_session),
+):
+    data = await chat_service.rename_thread(session, thread_id, req.name)
+    if data is None:
+        raise HTTPException(status_code=404, detail="Thread not found")
+    return ThreadInfo(**data)
+
+
+@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
+async def delete_thread(thread_id: int, session: AsyncSession = Depends(get_session)):
+    deleted = await chat_service.delete_thread(session, thread_id)
+    if deleted is None:
+        raise HTTPException(status_code=404, detail="Thread not found")
+    return ThreadDeleteResponse(ok=True, deleted_messages=deleted)
diff --git a/services/chat_service.py b/services/chat_service.py
new file mode 100644
index 0000000..f2cb837
--- /dev/null
+++ b/services/chat_service.py
@@ -0,0 +1,266 @@
+import json
+import logging
+from datetime import datetime, timezone
+
+from sqlalchemy import delete, func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import selectinload
+
+from db.models import Message, Thread
+from services.llm_client import LLMClient
+from services.vectorstore import VectorStoreService
+
+logger = logging.getLogger(__name__)
+
+HISTORY_LIMIT = 20  # last N thread messages that are sent to the LLM
+PREVIEW_LIMIT = 120  # max length of a first-message preview
+
+
+def _auto_thread_name(first_user_text: str) -> str:
+    """Auto thread name: first 60 chars of the first message + UTC timestamp."""
+    preview = first_user_text.strip().replace("\n", " ")
+    if len(preview) > 60:
+        preview = preview[:60].rstrip() + "…"
+    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
+    return f"{preview} · {stamp}"
+
+
+def _message_preview(text: str | None) -> str:
+    """Single-line preview of a message, cut to PREVIEW_LIMIT characters."""
+    preview = (text or "").strip().replace("\n", " ")
+    if len(preview) > PREVIEW_LIMIT:
+        preview = preview[:PREVIEW_LIMIT].rstrip() + "…"
+    return preview
+
+
+def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
+    """Map vector-store hits to the source dicts stored with each assistant reply."""
+    sources = []
+    for item in retrieved:
+        # `or {}` also covers a hit whose metadata key is present but None.
+        meta = item.get("metadata") or {}
+        sources.append({
+            "document_id": meta.get("document_id", ""),
+            "document_name": meta.get("document_name", ""),
+            "chunk_text": item["text"][:500],
+            "section": meta.get("section", ""),
+            "page": meta.get("page_number", 0),
+            "relevance_score": round(item.get("relevance_score", 0), 3),
+        })
+    return sources
+
+
+async def _thread_stats(session: AsyncSession, thread_id: int) -> tuple[int, str]:
+    """Return (messages_count, first_message_preview) for a single thread."""
+    count_stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
+    messages_count = (await session.execute(count_stmt)).scalar_one() or 0
+    # Same "first message" as in list_threads: the user message with the lowest id.
+    first_stmt = (
+        select(Message.text)
+        .where(Message.thread_id == thread_id, Message.role == "user")
+        .order_by(Message.id)
+        .limit(1)
+    )
+    first_text = (await session.execute(first_stmt)).scalar_one_or_none()
+    return int(messages_count), _message_preview(first_text)
+
+
+async def send_message(
+    session: AsyncSession,
+    vectorstore: VectorStoreService,
+    llm: LLMClient,
+    text: str,
+    thread_id: int | None = None,
+    top_k: int = 5,
+    temperature: float | None = None,
+    max_tokens: int | None = None,
+) -> dict:
+    """Add the patient's message to a thread, get the assistant's reply, store both.
+
+    A new auto-named thread is created when `thread_id` is None.
+
+    Returns a dict with thread_id, thread_name, message_id, answer, sources,
+    model_used and assembled_prompt.
+
+    Raises:
+        LookupError: `thread_id` was given but no such thread exists.
+    """
+    if thread_id is None:
+        thread = Thread(name=_auto_thread_name(text))
+        session.add(thread)
+        await session.flush()
+    else:
+        thread = await session.get(Thread, thread_id)
+        if thread is None:
+            raise LookupError(f"Thread {thread_id} not found")
+
+    user_msg = Message(thread_id=thread.id, role="user", text=text)
+    session.add(user_msg)
+    await session.flush()
+    # Read the generated keys now: after a commit the ORM objects may be expired,
+    # and AsyncSession cannot lazily reload an attribute on access.
+    thread_pk, user_msg_pk = thread.id, user_msg.id
+    # Commit (not only flush) the patient's message before calling the LLM so that
+    # it stays in the history even if retrieval or the LLM call fails afterwards.
+    await session.commit()
+
+    retrieved = vectorstore.query(query_text=text, top_k=top_k)
+    sources = _retrieved_to_sources(retrieved)
+
+    # History for the LLM: the latest HISTORY_LIMIT messages of the thread without the
+    # message just added; fetched newest-first, then reversed into chronological order.
+    stmt = (
+        select(Message)
+        .where(Message.thread_id == thread_pk, Message.id != user_msg_pk)
+        .order_by(Message.created_at.desc(), Message.id.desc())
+        .limit(HISTORY_LIMIT)
+    )
+    rows = (await session.execute(stmt)).scalars().all()
+    history = [{"role": m.role, "content": m.text} for m in reversed(rows)]
+
+    llm_result = await llm.chat(
+        question=text,
+        sources=retrieved,
+        history=history,
+        temperature=temperature,
+        max_tokens=max_tokens,
+    )
+
+    assistant_msg = Message(
+        thread_id=thread_pk,
+        role="assistant",
+        text=llm_result["text"],
+        sources_json=json.dumps(sources, ensure_ascii=False),
+        assembled_prompt=llm_result["assembled_prompt"],
+    )
+    session.add(assistant_msg)
+
+    thread.updated_at = datetime.now(timezone.utc)
+
+    await session.commit()
+    await session.refresh(assistant_msg)
+    await session.refresh(thread)
+
+    logger.info("Chat: thread=%d, user_msg=%d, assistant_msg=%d, sources=%d",
+                thread_pk, user_msg_pk, assistant_msg.id, len(sources))
+
+    return {
+        "thread_id": thread_pk,
+        "thread_name": thread.name,
+        "message_id": assistant_msg.id,
+        "answer": llm_result["text"],
+        "sources": sources,
+        "model_used": llm.model,
+        "assembled_prompt": llm_result["assembled_prompt"],
+    }
+
+
+async def list_threads(session: AsyncSession) -> list[dict]:
+    """List all threads, most recently updated first, with message count and first-message preview."""
+    count_subq = (
+        select(Message.thread_id, func.count(Message.id).label("cnt"))
+        .group_by(Message.thread_id)
+        .subquery()
+    )
+    first_msg_subq = (
+        select(Message.thread_id, func.min(Message.id).label("first_id"))
+        .where(Message.role == "user")
+        .group_by(Message.thread_id)
+        .subquery()
+    )
+
+    stmt = (
+        select(
+            Thread,
+            func.coalesce(count_subq.c.cnt, 0).label("messages_count"),
+            Message.text.label("first_text"),
+        )
+        .outerjoin(count_subq, count_subq.c.thread_id == Thread.id)
+        .outerjoin(first_msg_subq, first_msg_subq.c.thread_id == Thread.id)
+        .outerjoin(Message, Message.id == first_msg_subq.c.first_id)
+        .order_by(Thread.updated_at.desc())
+    )
+    rows = (await session.execute(stmt)).all()
+
+    return [
+        {
+            "id": thread.id,
+            "name": thread.name,
+            "created_at": thread.created_at.isoformat(),
+            "updated_at": thread.updated_at.isoformat(),
+            "messages_count": messages_count,
+            "first_message_preview": _message_preview(first_text),
+        }
+        for thread, messages_count, first_text in rows
+    ]
+
+
+async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | None:
+    """Return a thread with all its messages, or None if the thread does not exist."""
+    stmt = select(Thread).where(Thread.id == thread_id).options(selectinload(Thread.messages))
+    thread = (await session.execute(stmt)).scalar_one_or_none()
+    if thread is None:
+        return None
+
+    messages = []
+    for m in thread.messages:
+        sources = []
+        if m.sources_json:
+            try:
+                sources = json.loads(m.sources_json)
+            except json.JSONDecodeError:
+                # Best effort: a corrupt payload must not break the whole thread view.
+                logger.warning("Bad sources_json for message %d", m.id)
+        messages.append({
+            "id": m.id,
+            "role": m.role,
+            "text": m.text,
+            "created_at": m.created_at.isoformat(),
+            "sources": sources,
+            "assembled_prompt": m.assembled_prompt or "",
+        })
+    return {
+        "id": thread.id,
+        "name": thread.name,
+        "created_at": thread.created_at.isoformat(),
+        "updated_at": thread.updated_at.isoformat(),
+        "messages": messages,
+    }
+
+
+async def rename_thread(session: AsyncSession, thread_id: int, name: str) -> dict | None:
+    """Rename a thread; return its ThreadInfo-shaped dict, or None if it does not exist."""
+    thread = await session.get(Thread, thread_id)
+    if thread is None:
+        return None
+    thread.name = name
+    thread.updated_at = datetime.now(timezone.utc)
+    await session.commit()
+    await session.refresh(thread)
+    # Report the real statistics instead of hard-coded zero/empty placeholders.
+    messages_count, preview = await _thread_stats(session, thread_id)
+    return {
+        "id": thread.id,
+        "name": thread.name,
+        "created_at": thread.created_at.isoformat(),
+        "updated_at": thread.updated_at.isoformat(),
+        "messages_count": messages_count,
+        "first_message_preview": preview,
+    }
+
+
+async def delete_thread(session: AsyncSession, thread_id: int) -> int | None:
+    """Delete a thread and all its messages.
+
+    Returns the number of deleted messages, or None if the thread does not exist.
+    """
+    thread = await session.get(Thread, thread_id)
+    if thread is None:
+        return None
+    count_stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
+    messages_count = (await session.execute(count_stmt)).scalar_one() or 0
+
+    await session.execute(delete(Message).where(Message.thread_id == thread_id))
+    await session.delete(thread)
+    await session.commit()
+    return int(messages_count)
diff --git a/services/llm_client.py b/services/llm_client.py
index fc7745c..f2c9ade 100644
--- a/services/llm_client.py
+++ b/services/llm_client.py
@@ -27,6 +27,15 @@ DEFAULT_USER_TEMPLATE = """Вопрос пациента:
 Ответь пациенту в чате по правилам из системного сообщения."""
 
 
+CHAT_USER_TEMPLATE = """Новая реплика пациента:
+{question}
+
+Выдержки из базы знаний операторов (по последней реплике):
+{sources}
+
+Ответь пациенту с учётом истории диалога выше и правил из системного сообщения."""
+
+
 class LLMClient:
     def __init__(
         self,
@@ -102,3 +111,66 @@ class LLMClient:
         content = data["choices"][0]["message"]["content"]
         logger.info("LLM response: %d chars, model=%s, temp=%.2f", len(content), self.model, effective_temp)
         return {"text": content.strip(), "assembled_prompt": assembled_prompt}
+
+    async def chat(
+        self,
+        question: str,
+        sources: list[dict],
+        history: list[dict],
+        system_prompt: str | None = None,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+    ) -> dict:
+        """Generate a patient-facing answer using RAG + conversation history.
+
+        `history` is the list of previous thread messages in the form
+        [{"role": "user"|"assistant", "content": str}, ...] (without the current message).
+
+        Returns dict with 'text' and 'assembled_prompt'.
+        """
+        effective_system = system_prompt or DEFAULT_SYSTEM_PROMPT
+        effective_temp = temperature if temperature is not None else 0.2
+        effective_max_tokens = max_tokens or 1200
+
+        formatted_sources = self._format_sources(sources)
+        user_message = CHAT_USER_TEMPLATE.format(
+            question=question,
+            sources=formatted_sources,
+        )
+
+        messages: list[dict] = [{"role": "system", "content": effective_system}]
+        messages.extend(history)
+        messages.append({"role": "user", "content": user_message})
+
+        # Human-readable dump of the whole message chain, returned as 'assembled_prompt'.
+        assembled_prompt_parts = [f"[SYSTEM]\n{effective_system}"]
+        for m in history:
+            tag = "USER" if m["role"] == "user" else "ASSISTANT"
+            assembled_prompt_parts.append(f"[{tag}]\n{m['content']}")
+        assembled_prompt_parts.append(f"[USER]\n{user_message}")
+        assembled_prompt = "\n\n".join(assembled_prompt_parts)
+
+        url = f"{self.base_url}/chat/completions"
+        payload = {
+            "model": self.model,
+            "messages": messages,
+            "temperature": effective_temp,
+            "max_tokens": effective_max_tokens,
+        }
+
+        async with httpx.AsyncClient(timeout=60.0) as client:
+            response = await client.post(
+                url,
+                json=payload,
+                headers={
+                    "Authorization": f"Bearer {self.api_key}",
+                    "Content-Type": "application/json",
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+
+        # "content" can be null in chat-completions responses; treat that as an empty answer.
+        content = data["choices"][0]["message"]["content"] or ""
+        logger.info("LLM chat response: %d chars, history=%d, model=%s", len(content), len(history), self.model)
+        return {"text": content.strip(), "assembled_prompt": assembled_prompt}