bb5e3f5eb3
Параллель к 8a, но проверяем не код intent от роутера, а содержимое ответа
конкретной ветки на одиночную реплику. Старт — general_info, 46 кейсов.
Логика pass/fail (для одного кейса):
- A — RAG-секция: среди retrieved-чанков есть кусок с
section == expected_doc_section (точное совпадение). Если поле не задано —
пропускаем.
- B — keywords: обязательные expected_keywords встречаются в predicted_answer
(case-insensitive). По умолчанию все; поддерживаются keywords_min: N
и keywords_any: true. Запрещённые expected_must_not — ни одного.
- Pass = A ∧ B. Незаданные поля не проверяются.
- Кэш: (text_hash, branch_config_id) → {answer_text, retrieved_sections}.
Привязан к версии промпта ветки. Смена версии = пустой кэш = свежий прогон.
Правка JSONL без изменения text → pass/fail пересчитывается без LLM.
Backend:
- Таблицы eval_branch_runs / eval_branch_run_cases / eval_branch_predictions.
Миграция m9g1f7e89j56.
- services/eval_branch_run_service.py: загрузка JSONL, фоновый прогон через
asyncio.create_task, кэш, оценка A+B с поддержкой keywords_min/keywords_any.
- chat_service.run_branch_single_turn — изолированный single-turn без
роутера и треда (использует существующий config_service + vectorstore + llm).
- API: POST /eval/branch-runs, GET /eval/branch-runs?intent_code=,
GET /eval/branch-runs/{id}, GET /eval/branch-cases-with-status?intent_code=.
UI (static/regression.html):
- Селектор режима «Роутер / Ветка · general_info». Логика пикера переиспользуется
(фильтры, диапазон, массовый выбор, счётчик «новые / в кэше»).
- Для режима «Ветка»: фильтр по coverage, колонки секция/coverage, keywords,
частота, кэш. Drill-down прогона: ожидание, retrieved-секции, причины fail,
полный ответ ветки.
База кейсов (eval/branch_cases_general_info.jsonl) — от пользователя, 46 кейсов
по схеме {text, intent, coverage, expected_doc_section?, expected_keywords?,
expected_must_not?, keywords_min?, keywords_any?, count?, note?}.
Связанная правка SQLite (нашли при удалении документа в этом спринте):
- db/session.py: connect-listener PRAGMA foreign_keys=ON на каждое подключение.
Без этого ondelete=CASCADE в SQLite не enforced, и удаление документа
оставляло подписки в intent_documents висячими (что давало пустой RAG
и fail регрессии).
- Миграция n0h2g8f9a0k67 — одноразовая чистка существующих висячих подписок.
docs/SPRINTS.md: Спринт 8b → ✅ Закрыт. Diff vs предыдущий прогон для веток
и кнопка «Сбросить кэш регрессии» вынесены в docs/BACKLOG.md.
Также включены обновлённые data/datasets/general_info.md и price_question.md
(рабочий материал оператора), и черновик eval/branch_cases_price_question.jsonl
для следующего захода (8b на price_question).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
332 lines
13 KiB
Python
332 lines
13 KiB
Python
"""Регрессия ответов веток в UI (Спринт 8b).
|
||
|
||
Параллельный сервис к `eval_run_service` (роутер): здесь оператор проверяет
|
||
содержимое ответа конкретной ветки. На старте 8b — только `general_info`,
|
||
но архитектура не привязана к коду ветки: добавление новой = положить
|
||
`eval/branch_cases_<code>.jsonl`.
|
||
|
||
Pass/fail для одного кейса:
|
||
- **A (RAG-секция):** среди retrieved-чанков есть кусок с
|
||
`section == expected_doc_section`. Если ожидание не задано — пропускаем.
|
||
- **B (keywords):** в `predicted_answer` встречаются обязательные подстроки
|
||
(с учётом `keywords_min` или `keywords_any`) и нет запрещённых
|
||
(`expected_must_not`). Сравнение case-insensitive.
|
||
- Pass = A ∧ B; failed_reasons собирает короткие причины для UI.
|
||
|
||
Кэш: `(text_hash, branch_config_id) → {answer_text, retrieved_sections}`.
|
||
Привязан к версии активного промпта ветки. Смена версии = свежий прогон.
|
||
"""
|
||
import asyncio
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from db.models import (
|
||
AgentConfig,
|
||
EvalBranchPrediction,
|
||
EvalBranchRun,
|
||
EvalBranchRunCase,
|
||
)
|
||
from db.session import SessionLocal
|
||
from services import chat_service, config_service
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
EVAL_DIR = Path(__file__).resolve().parent.parent / "eval"
|
||
|
||
|
||
def _branch_cases_filename(intent_code: str) -> str:
|
||
return f"branch_cases_{intent_code}.jsonl"
|
||
|
||
|
||
def _text_hash(text: str) -> str:
|
||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||
|
||
|
||
@dataclass
|
||
class BranchCase:
|
||
text: str
|
||
intent_code: str
|
||
coverage: str = "covered"
|
||
expected_doc_section: str | None = None
|
||
expected_keywords: list[str] = field(default_factory=list)
|
||
expected_must_not: list[str] = field(default_factory=list)
|
||
keywords_min: int | None = None # если задан — нужно совпадение ≥ N keywords
|
||
keywords_any: bool = False # alias для keywords_min=1
|
||
count: int = 1
|
||
note: str | None = None
|
||
|
||
def required_keyword_count(self) -> int:
|
||
"""Сколько keywords минимум должны встретиться в ответе."""
|
||
total = len(self.expected_keywords)
|
||
if total == 0:
|
||
return 0
|
||
if self.keywords_min is not None:
|
||
return max(1, min(self.keywords_min, total))
|
||
if self.keywords_any:
|
||
return 1
|
||
return total # дефолт — все обязательны
|
||
|
||
|
||
def load_branch_cases(intent_code: str) -> list[BranchCase]:
|
||
"""Прочитать JSONL для ветки. Если файл отсутствует — пустой список + warning."""
|
||
path = EVAL_DIR / _branch_cases_filename(intent_code)
|
||
if not path.exists():
|
||
logger.warning("Branch cases file not found: %s", path)
|
||
return []
|
||
cases: list[BranchCase] = []
|
||
with path.open(encoding="utf-8") as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
try:
|
||
obj = json.loads(line)
|
||
except json.JSONDecodeError:
|
||
logger.warning("Bad JSONL line in %s: %r", path.name, line[:120])
|
||
continue
|
||
cases.append(BranchCase(
|
||
text=str(obj["text"]),
|
||
intent_code=str(obj.get("intent", intent_code)),
|
||
coverage=str(obj.get("coverage", "covered")),
|
||
expected_doc_section=obj.get("expected_doc_section"),
|
||
expected_keywords=list(obj.get("expected_keywords") or []),
|
||
expected_must_not=list(obj.get("expected_must_not") or []),
|
||
keywords_min=obj.get("keywords_min"),
|
||
keywords_any=bool(obj.get("keywords_any", False)),
|
||
count=int(obj.get("count", 1)),
|
||
note=obj.get("note"),
|
||
))
|
||
cases.sort(key=lambda c: (-c.count, c.text))
|
||
return cases
|
||
|
||
|
||
async def _resolve_active_branch_config_id(
|
||
session: AsyncSession, intent_code: str
|
||
) -> int | None:
|
||
pair = await config_service.get_active_config_by_intent_code(session, intent_code)
|
||
if pair is None:
|
||
return None
|
||
_, cfg = pair
|
||
return cfg.id
|
||
|
||
|
||
async def cached_predictions(
|
||
session: AsyncSession, branch_config_id: int | None
|
||
) -> dict[str, dict]:
|
||
"""{ text_hash → {answer_text, retrieved_sections} } для активной версии."""
|
||
rows = (await session.execute(
|
||
select(
|
||
EvalBranchPrediction.text_hash,
|
||
EvalBranchPrediction.answer_text,
|
||
EvalBranchPrediction.retrieved_sections_json,
|
||
).where(EvalBranchPrediction.branch_config_id == branch_config_id)
|
||
)).all()
|
||
out: dict[str, dict] = {}
|
||
for th, answer, sections_json in rows:
|
||
try:
|
||
sections = json.loads(sections_json) if sections_json else []
|
||
except json.JSONDecodeError:
|
||
sections = []
|
||
out[th] = {"answer_text": answer or "", "retrieved_sections": sections}
|
||
return out
|
||
|
||
|
||
def _evaluate_case(
|
||
case: BranchCase, answer_text: str, retrieved_sections: list[dict]
|
||
) -> tuple[bool, list[str]]:
|
||
"""Возвращает (is_pass, fail_reasons)."""
|
||
reasons: list[str] = []
|
||
|
||
# A. RAG-секция.
|
||
if case.expected_doc_section:
|
||
sections_in_retrieved = {s.get("section", "") for s in retrieved_sections}
|
||
if case.expected_doc_section not in sections_in_retrieved:
|
||
reasons.append(f"section не найдена: {case.expected_doc_section!r}")
|
||
|
||
# B. keywords.
|
||
text_lower = (answer_text or "").lower()
|
||
if case.expected_keywords:
|
||
hits = [kw for kw in case.expected_keywords if kw.lower() in text_lower]
|
||
need = case.required_keyword_count()
|
||
if len(hits) < need:
|
||
missing = [kw for kw in case.expected_keywords if kw.lower() not in text_lower]
|
||
reasons.append(
|
||
f"keywords: совпало {len(hits)}/{len(case.expected_keywords)}, нужно {need}; "
|
||
f"не нашлись: {missing[:5]}"
|
||
)
|
||
|
||
# B. must_not.
|
||
if case.expected_must_not:
|
||
bad = [kw for kw in case.expected_must_not if kw.lower() in text_lower]
|
||
if bad:
|
||
reasons.append(f"в ответе есть запрещённое: {bad}")
|
||
|
||
return (len(reasons) == 0), reasons
|
||
|
||
|
||
async def start_branch_run(
|
||
session: AsyncSession, intent_code: str, text_hashes: list[str]
|
||
) -> EvalBranchRun:
|
||
"""Создаёт run в running и стартует фоновую корутину."""
|
||
if not text_hashes:
|
||
raise ValueError("text_hashes is empty")
|
||
branch_config_id = await _resolve_active_branch_config_id(session, intent_code)
|
||
all_cases = load_branch_cases(intent_code)
|
||
wanted = set(text_hashes)
|
||
cases = [c for c in all_cases if _text_hash(c.text) in wanted]
|
||
run = EvalBranchRun(
|
||
suite=f"branch:{intent_code}",
|
||
intent_code=intent_code,
|
||
branch_config_id=branch_config_id,
|
||
status="running",
|
||
total=len(cases),
|
||
)
|
||
session.add(run)
|
||
await session.commit()
|
||
await session.refresh(run)
|
||
asyncio.create_task(_run_branch_suite(run.id, intent_code, branch_config_id, cases))
|
||
return run
|
||
|
||
|
||
async def _run_branch_suite(
|
||
run_id: int,
|
||
intent_code: str,
|
||
branch_config_id: int | None,
|
||
cases: list[BranchCase],
|
||
) -> None:
|
||
"""Фоновая корутина: своя сессия, не объекты от вызывающего."""
|
||
# Импорт vectorstore + llm singletons из main по требованию: модуль грузится
|
||
# после lifespan, ссылки уже инициализированы.
|
||
import main as _main
|
||
|
||
passed = failed = cache_hits = 0
|
||
try:
|
||
async with SessionLocal() as session:
|
||
run = await session.get(EvalBranchRun, run_id)
|
||
if run is None:
|
||
logger.error("eval_branch_run %d disappeared before start", run_id)
|
||
return
|
||
for case in cases:
|
||
th = _text_hash(case.text)
|
||
cached = (await session.execute(
|
||
select(EvalBranchPrediction).where(
|
||
EvalBranchPrediction.text_hash == th,
|
||
EvalBranchPrediction.branch_config_id == branch_config_id,
|
||
)
|
||
)).scalar_one_or_none()
|
||
|
||
if cached is not None:
|
||
answer_text = cached.answer_text
|
||
try:
|
||
retrieved_sections = json.loads(cached.retrieved_sections_json or "[]")
|
||
except json.JSONDecodeError:
|
||
retrieved_sections = []
|
||
cache_hits += 1
|
||
else:
|
||
try:
|
||
result = await chat_service.run_branch_single_turn(
|
||
session=session,
|
||
vectorstore=_main.vectorstore_service,
|
||
llm=_main.llm_client,
|
||
intent_code=intent_code,
|
||
text=case.text,
|
||
)
|
||
answer_text = result["answer_text"]
|
||
retrieved_sections = result["retrieved_sections"]
|
||
except Exception as e:
|
||
logger.warning(
|
||
"branch single-turn failed for case %r: %s",
|
||
case.text[:60], e,
|
||
)
|
||
answer_text = ""
|
||
retrieved_sections = []
|
||
session.add(EvalBranchPrediction(
|
||
text_hash=th,
|
||
branch_config_id=branch_config_id,
|
||
answer_text=answer_text,
|
||
retrieved_sections_json=json.dumps(retrieved_sections, ensure_ascii=False),
|
||
))
|
||
|
||
is_pass, reasons = _evaluate_case(case, answer_text, retrieved_sections)
|
||
if is_pass:
|
||
passed += 1
|
||
else:
|
||
failed += 1
|
||
session.add(EvalBranchRunCase(
|
||
run_id=run_id,
|
||
text=case.text,
|
||
coverage=case.coverage,
|
||
expected_doc_section=case.expected_doc_section,
|
||
expected_keywords_json=json.dumps(case.expected_keywords, ensure_ascii=False),
|
||
expected_must_not_json=json.dumps(case.expected_must_not, ensure_ascii=False),
|
||
keywords_min=case.keywords_min if case.keywords_min is not None
|
||
else (1 if case.keywords_any else None),
|
||
predicted_answer=answer_text,
|
||
predicted_sections_json=json.dumps(retrieved_sections, ensure_ascii=False),
|
||
is_pass=is_pass,
|
||
fail_reasons_json=json.dumps(reasons, ensure_ascii=False),
|
||
count_weight=case.count,
|
||
))
|
||
|
||
if (passed + failed) % 10 == 0:
|
||
run.passed = passed
|
||
run.failed = failed
|
||
run.cache_hits = cache_hits
|
||
await session.commit()
|
||
|
||
run.passed = passed
|
||
run.failed = failed
|
||
run.cache_hits = cache_hits
|
||
run.status = "done"
|
||
run.finished_at = datetime.now(timezone.utc)
|
||
await session.commit()
|
||
logger.info(
|
||
"eval_branch_run %d done: total=%d passed=%d failed=%d cache_hits=%d",
|
||
run_id, len(cases), passed, failed, cache_hits,
|
||
)
|
||
except Exception as e:
|
||
logger.exception("eval_branch_run %d failed: %s", run_id, e)
|
||
try:
|
||
async with SessionLocal() as session:
|
||
run = await session.get(EvalBranchRun, run_id)
|
||
if run is not None:
|
||
run.status = "error"
|
||
run.error_text = f"{type(e).__name__}: {e}"
|
||
run.finished_at = datetime.now(timezone.utc)
|
||
await session.commit()
|
||
except Exception:
|
||
logger.exception("Failed to mark eval_branch_run %d as error", run_id)
|
||
|
||
|
||
async def list_runs(
|
||
session: AsyncSession, intent_code: str | None = None, limit: int = 50
|
||
) -> list[EvalBranchRun]:
|
||
stmt = select(EvalBranchRun).order_by(EvalBranchRun.id.desc()).limit(limit)
|
||
if intent_code:
|
||
stmt = stmt.where(EvalBranchRun.intent_code == intent_code)
|
||
return list((await session.execute(stmt)).scalars().all())
|
||
|
||
|
||
async def get_run(session: AsyncSession, run_id: int) -> EvalBranchRun | None:
|
||
return await session.get(EvalBranchRun, run_id)
|
||
|
||
|
||
async def list_run_cases(session: AsyncSession, run_id: int) -> list[EvalBranchRunCase]:
|
||
stmt = (
|
||
select(EvalBranchRunCase)
|
||
.where(EvalBranchRunCase.run_id == run_id)
|
||
.order_by(
|
||
EvalBranchRunCase.is_pass, # сначала fail
|
||
EvalBranchRunCase.count_weight.desc(),
|
||
EvalBranchRunCase.id,
|
||
)
|
||
)
|
||
return list((await session.execute(stmt)).scalars().all())
|