Files
RAG_helper/services/chat_service.py
T
AR 15 M4 3c2657ab99 feat(sprint2): диалог с памятью треда — POST /chat + CRUD тредов
Второй кусок Спринта 2: агент теперь помнит контекст. RAG-retrieval
делается по последней реплике пациента, в LLM уходит системный промпт +
последние 20 сообщений треда + новая реплика + найденные фрагменты.

Backend:
- services/chat_service: send_message — создаёт тред при необходимости
  (auto-имя из первой реплики + UTC-дата), сохраняет user-реплику до
  вызова LLM (чтобы не потерять при сбое), делает retrieval, грузит
  историю треда (desc/limit 20 → reversed для хронологии), зовёт
  llm.chat, сохраняет ответ ассистента вместе с sources_json и
  assembled_prompt, обновляет thread.updated_at. Плюс list_threads с
  JOIN-выборкой превью первой реплики и счётчика сообщений,
  get_thread_detail через selectinload, rename_thread, delete_thread
  (CASCADE на FK делает уборку сообщений автоматически, но
  explicit delete оставлен для подсчёта удалённых).
- services/llm_client.chat: принимает history=[{role, content}, ...],
  собирает messages = [system, ...history, user-с-RAG]; assembled_prompt
  дампит всю цепочку в виде [SYSTEM]/[USER]/[ASSISTANT]-блоков для
  отображения в Debug UI.
- routers/chat: POST /chat, обрабатывает LookupError → 404.
- routers/threads: GET /threads, GET /threads/{id}, PATCH /threads/{id}
  (переименовать), DELETE /threads/{id}.
- models: ChatRequest, ThreadRenameRequest; ChatResponse, ThreadInfo,
  ThreadListResponse, ThreadDetailResponse, MessageInfo,
  ThreadDeleteResponse.

Запуск:
- В lifespan main.py: автоматический alembic upgrade head через
  asyncio.to_thread (сам alembic делает asyncio.run внутри, его нельзя
  звать из уже работающего event loop). LLMClient инициализируется
  один раз при старте — вместо создания на каждый запрос.

E2E проверено: новый тред → агент отвечает и просит представиться;
вторая реплика в том же треде — агент помнит контекст; PATCH
переименовывает; DELETE удаляет тред с каскадом на сообщения.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-23 10:11:59 +05:00

221 lines
7.8 KiB
Python

import json
import logging
from datetime import datetime, timezone
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from db.models import Message, Thread
from services.llm_client import LLMClient
from services.vectorstore import VectorStoreService
logger = logging.getLogger(__name__)
HISTORY_LIMIT = 20 # последние N сообщений треда, которые улетают в LLM
def _auto_thread_name(first_user_text: str) -> str:
"""Авто-имя треда: первые 60 символов первой реплики + дата."""
preview = first_user_text.strip().replace("\n", " ")
if len(preview) > 60:
preview = preview[:60].rstrip() + ""
stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
return f"{preview} · {stamp}"
def _retrieved_to_sources(retrieved: list[dict]) -> list[dict]:
sources = []
for item in retrieved:
meta = item.get("metadata", {})
sources.append({
"document_id": meta.get("document_id", ""),
"document_name": meta.get("document_name", ""),
"chunk_text": item["text"][:500],
"section": meta.get("section", ""),
"page": meta.get("page_number", 0),
"relevance_score": round(item.get("relevance_score", 0), 3),
})
return sources
async def send_message(
session: AsyncSession,
vectorstore: VectorStoreService,
llm: LLMClient,
text: str,
thread_id: int | None = None,
top_k: int = 5,
temperature: float | None = None,
max_tokens: int | None = None,
) -> dict:
"""Добавить реплику пациента в тред, получить ответ ассистента, сохранить оба сообщения."""
if thread_id is None:
thread = Thread(name=_auto_thread_name(text))
session.add(thread)
await session.flush()
else:
thread = await session.get(Thread, thread_id)
if thread is None:
raise LookupError(f"Thread {thread_id} not found")
# Сохраняем реплику пациента до вызова LLM — чтобы она осталась в истории даже при ошибке.
user_msg = Message(thread_id=thread.id, role="user", text=text)
session.add(user_msg)
await session.flush()
retrieved = vectorstore.query(query_text=text, top_k=top_k)
sources = _retrieved_to_sources(retrieved)
# История для LLM: все сообщения треда, кроме только что добавленной user-реплики.
stmt = (
select(Message)
.where(Message.thread_id == thread.id, Message.id != user_msg.id)
.order_by(Message.created_at.desc(), Message.id.desc())
.limit(HISTORY_LIMIT)
)
rows = (await session.execute(stmt)).scalars().all()
history = [{"role": m.role, "content": m.text} for m in reversed(rows)]
llm_result = await llm.chat(
question=text,
sources=retrieved,
history=history,
temperature=temperature,
max_tokens=max_tokens,
)
assistant_msg = Message(
thread_id=thread.id,
role="assistant",
text=llm_result["text"],
sources_json=json.dumps(sources, ensure_ascii=False),
assembled_prompt=llm_result["assembled_prompt"],
)
session.add(assistant_msg)
thread.updated_at = datetime.now(timezone.utc)
await session.commit()
await session.refresh(assistant_msg)
await session.refresh(thread)
logger.info("Chat: thread=%d, user_msg=%d, assistant_msg=%d, sources=%d",
thread.id, user_msg.id, assistant_msg.id, len(sources))
return {
"thread_id": thread.id,
"thread_name": thread.name,
"message_id": assistant_msg.id,
"answer": llm_result["text"],
"sources": sources,
"model_used": llm.model,
"assembled_prompt": llm_result["assembled_prompt"],
}
async def list_threads(session: AsyncSession) -> list[dict]:
"""Список всех тредов с превью первой реплики и количеством сообщений."""
count_subq = (
select(Message.thread_id, func.count(Message.id).label("cnt"))
.group_by(Message.thread_id)
.subquery()
)
first_msg_subq = (
select(Message.thread_id, func.min(Message.id).label("first_id"))
.where(Message.role == "user")
.group_by(Message.thread_id)
.subquery()
)
stmt = (
select(
Thread,
func.coalesce(count_subq.c.cnt, 0).label("messages_count"),
Message.text.label("first_text"),
)
.outerjoin(count_subq, count_subq.c.thread_id == Thread.id)
.outerjoin(first_msg_subq, first_msg_subq.c.thread_id == Thread.id)
.outerjoin(Message, Message.id == first_msg_subq.c.first_id)
.order_by(Thread.updated_at.desc())
)
rows = (await session.execute(stmt)).all()
result = []
for thread, messages_count, first_text in rows:
preview = (first_text or "").strip().replace("\n", " ")
if len(preview) > 120:
preview = preview[:120].rstrip() + ""
result.append({
"id": thread.id,
"name": thread.name,
"created_at": thread.created_at.isoformat(),
"updated_at": thread.updated_at.isoformat(),
"messages_count": messages_count,
"first_message_preview": preview,
})
return result
async def get_thread_detail(session: AsyncSession, thread_id: int) -> dict | None:
stmt = select(Thread).where(Thread.id == thread_id).options(selectinload(Thread.messages))
thread = (await session.execute(stmt)).scalar_one_or_none()
if thread is None:
return None
messages = []
for m in thread.messages:
sources = []
if m.sources_json:
try:
sources = json.loads(m.sources_json)
except json.JSONDecodeError:
logger.warning("Bad sources_json for message %d", m.id)
messages.append({
"id": m.id,
"role": m.role,
"text": m.text,
"created_at": m.created_at.isoformat(),
"sources": sources,
"assembled_prompt": m.assembled_prompt or "",
})
return {
"id": thread.id,
"name": thread.name,
"created_at": thread.created_at.isoformat(),
"updated_at": thread.updated_at.isoformat(),
"messages": messages,
}
async def rename_thread(session: AsyncSession, thread_id: int, name: str) -> dict | None:
thread = await session.get(Thread, thread_id)
if thread is None:
return None
thread.name = name
thread.updated_at = datetime.now(timezone.utc)
await session.commit()
await session.refresh(thread)
return {
"id": thread.id,
"name": thread.name,
"created_at": thread.created_at.isoformat(),
"updated_at": thread.updated_at.isoformat(),
"messages_count": 0,
"first_message_preview": "",
}
async def delete_thread(session: AsyncSession, thread_id: int) -> int | None:
"""Удалить тред и все его сообщения. Возвращает число удалённых сообщений или None, если треда нет."""
thread = await session.get(Thread, thread_id)
if thread is None:
return None
count_stmt = select(func.count(Message.id)).where(Message.thread_id == thread_id)
messages_count = (await session.execute(count_stmt)).scalar_one() or 0
await session.execute(delete(Message).where(Message.thread_id == thread_id))
await session.delete(thread)
await session.commit()
return int(messages_count)