feat(sprint2): диалог с памятью треда — POST /chat + CRUD тредов

Второй кусок Спринта 2: агент теперь помнит контекст. RAG-retrieval
делается по последней реплике пациента, в LLM уходит системный промпт +
последние 20 сообщений треда + новая реплика + найденные фрагменты.

Backend:
- services/chat_service: send_message — создаёт тред при необходимости
  (auto-имя из первой реплики + UTC-дата), сохраняет user-реплику до
  вызова LLM (чтобы не потерять при сбое), делает retrieval, грузит
  историю треда (desc/limit 20 → reversed для хронологии), зовёт
  llm.chat, сохраняет ответ ассистента вместе с sources_json и
  assembled_prompt, обновляет thread.updated_at. Плюс list_threads с
  JOIN-выборкой превью первой реплики и счётчика сообщений,
  get_thread_detail через selectinload, rename_thread, delete_thread
  (CASCADE на FK делает уборку сообщений автоматически, но
  explicit delete оставлен для подсчёта удалённых).
- services/llm_client.chat: принимает history=[{role, content}, ...],
  собирает messages = [system, ...history, user-с-RAG]; assembled_prompt
  дампит всю цепочку в виде [SYSTEM]/[USER]/[ASSISTANT]-блоков для
  отображения в Debug UI.
- routers/chat: POST /chat, обрабатывает LookupError → 404.
- routers/threads: GET /threads, GET /threads/{id}, PATCH /threads/{id}
  (переименовать), DELETE /threads/{id}.
- models: ChatRequest, ThreadRenameRequest; ChatResponse, ThreadInfo,
  ThreadListResponse, ThreadDetailResponse, MessageInfo,
  ThreadDeleteResponse.

Запуск:
- В lifespan main.py: автоматический alembic upgrade head через
  asyncio.to_thread (сам alembic делает asyncio.run внутри, его нельзя
  звать из уже работающего event loop). LLMClient инициализируется
  один раз при старте — вместо создания на каждый запрос.

E2E проверено: новый тред → агент отвечает и просит представиться;
вторая реплика в том же треде — агент помнит контекст; PATCH
переименовывает; DELETE удаляет тред с каскадом на сообщения.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
AR 15 M4
2026-04-23 10:11:59 +05:00
parent 75048bb88e
commit 3c2657ab99
7 changed files with 490 additions and 2 deletions
+21 -2
View File
@@ -1,24 +1,39 @@
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from alembic import command
from alembic.config import Config as AlembicConfig
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from config import settings
from services.embeddings import EmbeddingService
from services.llm_client import LLMClient
from services.vectorstore import VectorStoreService
logger = logging.getLogger(__name__)
embedding_service: EmbeddingService | None = None
vectorstore_service: VectorStoreService | None = None
llm_client: LLMClient | None = None
def _run_migrations() -> None:
"""Автоматически подтягиваем схему до последней ревизии при старте."""
os.makedirs(os.path.dirname(settings.sqlite_path), exist_ok=True)
cfg = AlembicConfig("alembic.ini")
command.upgrade(cfg, "head")
@asynccontextmanager
async def lifespan(app: FastAPI):
global embedding_service, vectorstore_service
global embedding_service, vectorstore_service, llm_client
logging.basicConfig(level=getattr(logging, settings.log_level.upper(), logging.INFO))
logger.info("Running DB migrations…")
await asyncio.to_thread(_run_migrations)
logger.info("Loading embedding model: %s", settings.embedding_model)
embedding_service = EmbeddingService(settings.embedding_model)
logger.info("Embedding model loaded")
@@ -27,6 +42,8 @@ async def lifespan(app: FastAPI):
embedding_service=embedding_service,
)
logger.info("ChromaDB initialized at %s", settings.chroma_persist_dir)
llm_client = LLMClient()
logger.info("LLM client ready (model=%s)", llm_client.model)
yield
logger.info("Shutting down")
@@ -46,10 +63,12 @@ app.add_middleware(
allow_headers=["*"],
)
from routers import documents, health, query # noqa: E402
from routers import chat, documents, health, query, threads # noqa: E402
app.include_router(health.router)
app.include_router(documents.router)
app.include_router(query.router)
app.include_router(chat.router)
app.include_router(threads.router)
app.mount("/", StaticFiles(directory="static", html=True), name="static")
+12
View File
@@ -7,3 +7,15 @@ class QueryRequest(BaseModel):
document_ids: list[str] | None = Field(None, description="Ограничить поиск конкретными документами")
temperature: float | None = Field(None, ge=0.0, le=2.0)
max_tokens: int | None = Field(None, ge=100, le=8000)
class ChatRequest(BaseModel):
text: str = Field(..., description="Реплика пациента")
thread_id: int | None = Field(None, description="ID треда; если не передан — создаётся новый")
top_k: int = Field(5, ge=1, le=20)
temperature: float | None = Field(None, ge=0.0, le=2.0)
max_tokens: int | None = Field(None, ge=100, le=8000)
class ThreadRenameRequest(BaseModel):
name: str = Field(..., min_length=1, max_length=200)
+46
View File
@@ -77,3 +77,49 @@ class HealthResponse(BaseModel):
embedding_model: str
documents_count: int
chunks_count: int
class MessageInfo(BaseModel):
id: int
role: str
text: str
created_at: str
sources: list[SourceInfo] = Field(default_factory=list)
assembled_prompt: str = ""
class ThreadInfo(BaseModel):
id: int
name: str
created_at: str
updated_at: str
messages_count: int
first_message_preview: str = ""
class ThreadListResponse(BaseModel):
threads: list[ThreadInfo]
total: int
class ThreadDetailResponse(BaseModel):
id: int
name: str
created_at: str
updated_at: str
messages: list[MessageInfo] = Field(default_factory=list)
class ChatResponse(BaseModel):
thread_id: int
thread_name: str
message_id: int
answer: str
sources: list[SourceInfo]
model_used: str
assembled_prompt: str = ""
class ThreadDeleteResponse(BaseModel):
ok: bool = True
deleted_messages: int
+48
View File
@@ -0,0 +1,48 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from db.session import get_session
from models.requests import ChatRequest
from models.responses import ChatResponse, SourceInfo
from services import chat_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("", response_model=ChatResponse)
async def chat(req: ChatRequest, session: AsyncSession = Depends(get_session)):
from main import llm_client, vectorstore_service
if vectorstore_service is None or llm_client is None:
raise HTTPException(status_code=503, detail="Service not ready")
try:
result = await chat_service.send_message(
session=session,
vectorstore=vectorstore_service,
llm=llm_client,
text=req.text,
thread_id=req.thread_id,
top_k=req.top_k,
temperature=req.temperature,
max_tokens=req.max_tokens,
)
except LookupError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.exception("Chat failed")
raise HTTPException(status_code=500, detail=f"Chat error: {e}")
return ChatResponse(
thread_id=result["thread_id"],
thread_name=result["thread_name"],
message_id=result["message_id"],
answer=result["answer"],
sources=[SourceInfo(**s) for s in result["sources"]],
model_used=result["model_used"],
assembled_prompt=result["assembled_prompt"],
)
+73
View File
@@ -0,0 +1,73 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from db.session import get_session
from models.requests import ThreadRenameRequest
from models.responses import (
MessageInfo,
SourceInfo,
ThreadDeleteResponse,
ThreadDetailResponse,
ThreadInfo,
ThreadListResponse,
)
from services import chat_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/threads", tags=["threads"])
@router.get("", response_model=ThreadListResponse)
async def list_threads(session: AsyncSession = Depends(get_session)):
threads = await chat_service.list_threads(session)
return ThreadListResponse(
threads=[ThreadInfo(**t) for t in threads],
total=len(threads),
)
@router.get("/{thread_id}", response_model=ThreadDetailResponse)
async def get_thread(thread_id: int, session: AsyncSession = Depends(get_session)):
data = await chat_service.get_thread_detail(session, thread_id)
if data is None:
raise HTTPException(status_code=404, detail="Thread not found")
return ThreadDetailResponse(
id=data["id"],
name=data["name"],
created_at=data["created_at"],
updated_at=data["updated_at"],
messages=[
MessageInfo(
id=m["id"],
role=m["role"],
text=m["text"],
created_at=m["created_at"],
sources=[SourceInfo(**s) for s in m["sources"]],
assembled_prompt=m["assembled_prompt"],
)
for m in data["messages"]
],
)
@router.patch("/{thread_id}", response_model=ThreadInfo)
async def rename_thread(
thread_id: int,
req: ThreadRenameRequest,
session: AsyncSession = Depends(get_session),
):
data = await chat_service.rename_thread(session, thread_id, req.name)
if data is None:
raise HTTPException(status_code=404, detail="Thread not found")
return ThreadInfo(**data)
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
async def delete_thread(thread_id: int, session: AsyncSession = Depends(get_session)):
deleted = await chat_service.delete_thread(session, thread_id)
if deleted is None:
raise HTTPException(status_code=404, detail="Thread not found")
return ThreadDeleteResponse(ok=True, deleted_messages=deleted)
+220
View File
@@ -0,0 +1,220 @@
import json
import logging
from datetime import datetime, timezone
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from db.models import Message, Thread
from services.llm_client import LLMClient
from services.vectorstore import VectorStoreService
logger = logging.getLogger(__name__)
HISTORY_LIMIT = 20 # последние N сообщений треда, которые улетают в LLM
def _auto_thread_name(first_user_text: str) -> str:
"""Авто-имя треда: первые 60 символов первой реплики + дата."""
preview = first_user_text.strip().replace("\n", " ")
if len(preview) > 60:
preview = preview[:60].rstrip() + ""
stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
return f"{preview} · {stamp}"
def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
sources = []
for item in retrieved:
meta = item.get("metadata", {})
sources.append({
"document_id": meta.get("document_id", ""),
"document_name": meta.get("document_name", ""),
"chunk_text": item["text"][:500],
"section": meta.get("section", ""),
"page": meta.get("page_number", 0),
"relevance_score": round(item.get("relevance_score", 0), 3),
})
return sources
async def send_message(
session: AsyncSession,
vectorstore: VectorStoreService,
llm: LLMClient,
text: str,
thread_id: int | None = None,
top_k: int = 5,
temperature: float | None = None,
max_tokens: int | None = None,
) -> dict:
"""Добавить реплику пациента в тред, получить ответ ассистента, сохранить оба сообщения."""
if thread_id is None:
thread = Thread(name=_auto_thread_name(text))
session.add(thread)
await session.flush()
else:
thread = await session.get(Thread, thread_id)
if thread is None:
raise LookupError(f"Thread {thread_id} not found")
# Сохраняем реплику пациента до вызова LLM — чтобы она осталась в истории даже при ошибке.
user_msg = Message(thread_id=thread.id, role="user", text=text)
session.add(user_msg)
await session.flush()
retrieved = vectorstore.query(query_text=text, top_k=top_k)
sources = _retrieved_to_sources(retrieved)
# История для LLM: все сообщения треда, кроме только что добавленной user-реплики.
stmt = (
select(Message)
.where(Message.thread_id == thread.id, Message.id != user_msg.id)
.order_by(Message.created_at.desc(), Message.id.desc())
.limit(HISTORY_LIMIT)
)
rows = (await session.execute(stmt)).scalars().all()
history = [{"role": m.role, "content": m.text} for m in reversed(rows)]
llm_result = await llm.chat(
question=text,
sources=retrieved,
history=history,
temperature=temperature,
max_tokens=max_tokens,
)
assistant_msg = Message(
thread_id=thread.id,
role="assistant",
text=llm_result["text"],
sources_json=json.dumps(sources, ensure_ascii=False),
assembled_prompt=llm_result["assembled_prompt"],
)
session.add(assistant_msg)
thread.updated_at = datetime.now(timezone.utc)
await session.commit()
await session.refresh(assistant_msg)
await session.refresh(thread)
logger.info("Chat: thread=%d, user_msg=%d, assistant_msg=%d, sources=%d",
thread.id, user_msg.id, assistant_msg.id, len(sources))
return {
"thread_id": thread.id,
"thread_name": thread.name,
"message_id": assistant_msg.id,
"answer": llm_result["text"],
"sources": sources,
"model_used": llm.model,
"assembled_prompt": llm_result["assembled_prompt"],
}
async def list_threads(session: AsyncSession) -> list[dict]:
"""Список всех тредов с превью первой реплики и количеством сообщений."""
count_subq = (
select(Message.thread_id, func.count(Message.id).label("cnt"))
.group_by(Message.thread_id)
.subquery()
)
first_msg_subq = (
select(Message.thread_id, func.min(Message.id).label("first_id"))
.where(Message.role == "user")
.group_by(Message.thread_id)
.subquery()
)
stmt = (
select(
Thread,
func.coalesce(count_subq.c.cnt, 0).label("messages_count"),
Message.text.label("first_text"),
)
.outerjoin(count_subq, count_subq.c.thread_id == Thread.id)
.outerjoin(first_msg_subq, first_msg_subq.c.thread_id == Thread.id)
.outerjoin(Message, Message.id == first_msg_subq.c.first_id)
.order_by(Thread.updated_at.desc())
)
rows = (await session.execute(stmt)).all()
result = []
for thread, messages_count, first_text in rows:
preview = (first_text or "").strip().replace("\n", " ")
if len(preview) > 120:
preview = preview[:120].rstrip() + ""
result.append({
"id": thread.id,
"name": thread.name,
"created_at": thread.created_at.isoformat(),
"updated_at": thread.updated_at.isoformat(),
"messages_count": messages_count,
"first_message_preview": preview,
})
return result
async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | None:
stmt = select(Thread).where(Thread.id == thread_id).options(selectinload(Thread.messages))
thread = (await session.execute(stmt)).scalar_one_or_none()
if thread is None:
return None
messages = []
for m in thread.messages:
sources = []
if m.sources_json:
try:
sources = json.loads(m.sources_json)
except json.JSONDecodeError:
logger.warning("Bad sources_json for message %d", m.id)
messages.append({
"id": m.id,
"role": m.role,
"text": m.text,
"created_at": m.created_at.isoformat(),
"sources": sources,
"assembled_prompt": m.assembled_prompt or "",
})
return {
"id": thread.id,
"name": thread.name,
"created_at": thread.created_at.isoformat(),
"updated_at": thread.updated_at.isoformat(),
"messages": messages,
}
async def rename_thread(session: AsyncSession, thread_id: int, name: str) -> dict | None:
thread = await session.get(Thread, thread_id)
if thread is None:
return None
thread.name = name
thread.updated_at = datetime.now(timezone.utc)
await session.commit()
await session.refresh(thread)
return {
"id": thread.id,
"name": thread.name,
"created_at": thread.created_at.isoformat(),
"updated_at": thread.updated_at.isoformat(),
"messages_count": 0,
"first_message_preview": "",
}
async def delete_thread(session: AsyncSession, thread_id: int) -> int | None:
"""Удалить тред и все его сообщения. Возвращает число удалённых сообщений или None, если треда нет."""
thread = await session.get(Thread, thread_id)
if thread is None:
return None
count_stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
messages_count = (await session.execute(count_stmt)).scalar_one() or 0
await session.execute(delete(Message).where(Message.thread_id == thread_id))
await session.delete(thread)
await session.commit()
return int(messages_count)
+70
View File
@@ -27,6 +27,15 @@ DEFAULT_USER_TEMPLATE = """Вопрос пациента:
Ответь пациенту в чате по правилам из системного сообщения."""
CHAT_USER_TEMPLATE = """Новая реплика пациента:
{question}
Выдержки из базы знаний операторов (по последней реплике):
{sources}
Ответь пациенту с учётом истории диалога выше и правил из системного сообщения."""
class LLMClient:
def __init__(
self,
@@ -102,3 +111,64 @@ class LLMClient:
content = data["choices"][0]["message"]["content"]
logger.info("LLM response: %d chars, model=%s, temp=%.2f", len(content), self.model, effective_temp)
return {"text": content.strip(), "assembled_prompt": assembled_prompt}
async def chat(
self,
question: str,
sources: list[dict],
history: list[dict],
system_prompt: str | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
) -> dict:
"""Generate a patient-facing answer using RAG + conversation history.
`history` — список предыдущих сообщений треда в формате
[{"role": "user"|"assistant", "content": str}, ...] (без текущей реплики).
Returns dict with 'text' and 'assembled_prompt'.
"""
effective_system = system_prompt or DEFAULT_SYSTEM_PROMPT
effective_temp = temperature if temperature is not None else 0.2
effective_max_tokens = max_tokens or 1200
formatted_sources = self._format_sources(sources)
user_message = CHAT_USER_TEMPLATE.format(
question=question,
sources=formatted_sources,
)
messages: list[dict] = [{"role": "system", "content": effective_system}]
messages.extend(history)
messages.append({"role": "user", "content": user_message})
assembled_prompt_parts = [f"[SYSTEM]\n{effective_system}"]
for m in history:
tag = "USER" if m["role"] == "user" else "ASSISTANT"
assembled_prompt_parts.append(f"[{tag}]\n{m['content']}")
assembled_prompt_parts.append(f"[USER]\n{user_message}")
assembled_prompt = "\n\n".join(assembled_prompt_parts)
url = f"{self.base_url}/chat/completions"
payload = {
"model": self.model,
"messages": messages,
"temperature": effective_temp,
"max_tokens": effective_max_tokens,
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
url,
json=payload,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
data = response.json()
content = data["choices"][0]["message"]["content"]
logger.info("LLM chat response: %d chars, history=%d, model=%s", len(content), len(history), self.model)
return {"text": content.strip(), "assembled_prompt": assembled_prompt}