"""Регрессия ответов веток в UI (Спринт 8b). Параллельный сервис к `eval_run_service` (роутер): здесь оператор проверяет содержимое ответа конкретной ветки. На старте 8b — только `general_info`, но архитектура не привязана к коду ветки: добавление новой = положить `eval/branch_cases_.jsonl`. Pass/fail для одного кейса: - **A (RAG-секция):** среди retrieved-чанков есть кусок с `section == expected_doc_section`. Если ожидание не задано — пропускаем. - **B (keywords):** в `predicted_answer` встречаются обязательные подстроки (с учётом `keywords_min` или `keywords_any`) и нет запрещённых (`expected_must_not`). Сравнение case-insensitive. - Pass = A ∧ B; failed_reasons собирает короткие причины для UI. Кэш: `(text_hash, branch_config_id) → {answer_text, retrieved_sections}`. Привязан к версии активного промпта ветки. Смена версии = свежий прогон. """ import asyncio import hashlib import json import logging from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from db.models import ( AgentConfig, EvalBranchPrediction, EvalBranchRun, EvalBranchRunCase, ) from db.session import SessionLocal from services import chat_service, config_service logger = logging.getLogger(__name__) EVAL_DIR = Path(__file__).resolve().parent.parent / "eval" def _branch_cases_filename(intent_code: str) -> str: return f"branch_cases_{intent_code}.jsonl" def _text_hash(text: str) -> str: return hashlib.sha256(text.encode("utf-8")).hexdigest() @dataclass class BranchCase: text: str intent_code: str coverage: str = "covered" expected_doc_section: str | None = None expected_keywords: list[str] = field(default_factory=list) expected_must_not: list[str] = field(default_factory=list) keywords_min: int | None = None # если задан — нужно совпадение ≥ N keywords keywords_any: bool = False # alias для keywords_min=1 count: int = 1 note: str | None = None def required_keyword_count(self) -> int: """Сколько keywords минимум должны встретиться в ответе.""" total = len(self.expected_keywords) if total == 0: return 0 if self.keywords_min is not None: return max(1, min(self.keywords_min, total)) if self.keywords_any: return 1 return total # дефолт — все обязательны def load_branch_cases(intent_code: str) -> list[BranchCase]: """Прочитать JSONL для ветки. Если файл отсутствует — пустой список + warning.""" path = EVAL_DIR / _branch_cases_filename(intent_code) if not path.exists(): logger.warning("Branch cases file not found: %s", path) return [] cases: list[BranchCase] = [] with path.open(encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: obj = json.loads(line) except json.JSONDecodeError: logger.warning("Bad JSONL line in %s: %r", path.name, line[:120]) continue cases.append(BranchCase( text=str(obj["text"]), intent_code=str(obj.get("intent", intent_code)), coverage=str(obj.get("coverage", "covered")), expected_doc_section=obj.get("expected_doc_section"), expected_keywords=list(obj.get("expected_keywords") or []), expected_must_not=list(obj.get("expected_must_not") or []), keywords_min=obj.get("keywords_min"), keywords_any=bool(obj.get("keywords_any", False)), count=int(obj.get("count", 1)), note=obj.get("note"), )) cases.sort(key=lambda c: (-c.count, c.text)) return cases async def _resolve_active_branch_config_id( session: AsyncSession, intent_code: str ) -> int | None: pair = await config_service.get_active_config_by_intent_code(session, intent_code) if pair is None: return None _, cfg = pair return cfg.id async def cached_predictions( session: AsyncSession, branch_config_id: int | None ) -> dict[str, dict]: """{ text_hash → {answer_text, retrieved_sections} } для активной версии.""" rows = (await session.execute( select( EvalBranchPrediction.text_hash, EvalBranchPrediction.answer_text, EvalBranchPrediction.retrieved_sections_json, ).where(EvalBranchPrediction.branch_config_id == branch_config_id) )).all() out: dict[str, dict] = {} for th, answer, sections_json in rows: try: sections = json.loads(sections_json) if sections_json else [] except json.JSONDecodeError: sections = [] out[th] = {"answer_text": answer or "", "retrieved_sections": sections} return out def _evaluate_case( case: BranchCase, answer_text: str, retrieved_sections: list[dict] ) -> tuple[bool, list[str]]: """Возвращает (is_pass, fail_reasons).""" reasons: list[str] = [] # A. RAG-секция. if case.expected_doc_section: sections_in_retrieved = {s.get("section", "") for s in retrieved_sections} if case.expected_doc_section not in sections_in_retrieved: reasons.append(f"section не найдена: {case.expected_doc_section!r}") # B. keywords. text_lower = (answer_text or "").lower() if case.expected_keywords: hits = [kw for kw in case.expected_keywords if kw.lower() in text_lower] need = case.required_keyword_count() if len(hits) < need: missing = [kw for kw in case.expected_keywords if kw.lower() not in text_lower] reasons.append( f"keywords: совпало {len(hits)}/{len(case.expected_keywords)}, нужно {need}; " f"не нашлись: {missing[:5]}" ) # B. must_not. if case.expected_must_not: bad = [kw for kw in case.expected_must_not if kw.lower() in text_lower] if bad: reasons.append(f"в ответе есть запрещённое: {bad}") return (len(reasons) == 0), reasons async def start_branch_run( session: AsyncSession, intent_code: str, text_hashes: list[str] ) -> EvalBranchRun: """Создаёт run в running и стартует фоновую корутину.""" if not text_hashes: raise ValueError("text_hashes is empty") branch_config_id = await _resolve_active_branch_config_id(session, intent_code) all_cases = load_branch_cases(intent_code) wanted = set(text_hashes) cases = [c for c in all_cases if _text_hash(c.text) in wanted] run = EvalBranchRun( suite=f"branch:{intent_code}", intent_code=intent_code, branch_config_id=branch_config_id, status="running", total=len(cases), ) session.add(run) await session.commit() await session.refresh(run) asyncio.create_task(_run_branch_suite(run.id, intent_code, branch_config_id, cases)) return run async def _run_branch_suite( run_id: int, intent_code: str, branch_config_id: int | None, cases: list[BranchCase], ) -> None: """Фоновая корутина: своя сессия, не объекты от вызывающего.""" # Импорт vectorstore + llm singletons из main по требованию: модуль грузится # после lifespan, ссылки уже инициализированы. import main as _main passed = failed = cache_hits = 0 try: async with SessionLocal() as session: run = await session.get(EvalBranchRun, run_id) if run is None: logger.error("eval_branch_run %d disappeared before start", run_id) return for case in cases: th = _text_hash(case.text) cached = (await session.execute( select(EvalBranchPrediction).where( EvalBranchPrediction.text_hash == th, EvalBranchPrediction.branch_config_id == branch_config_id, ) )).scalar_one_or_none() if cached is not None: answer_text = cached.answer_text try: retrieved_sections = json.loads(cached.retrieved_sections_json or "[]") except json.JSONDecodeError: retrieved_sections = [] cache_hits += 1 else: try: result = await chat_service.run_branch_single_turn( session=session, vectorstore=_main.vectorstore_service, llm=_main.llm_client, intent_code=intent_code, text=case.text, ) answer_text = result["answer_text"] retrieved_sections = result["retrieved_sections"] except Exception as e: logger.warning( "branch single-turn failed for case %r: %s", case.text[:60], e, ) answer_text = "" retrieved_sections = [] session.add(EvalBranchPrediction( text_hash=th, branch_config_id=branch_config_id, answer_text=answer_text, retrieved_sections_json=json.dumps(retrieved_sections, ensure_ascii=False), )) is_pass, reasons = _evaluate_case(case, answer_text, retrieved_sections) if is_pass: passed += 1 else: failed += 1 session.add(EvalBranchRunCase( run_id=run_id, text=case.text, coverage=case.coverage, expected_doc_section=case.expected_doc_section, expected_keywords_json=json.dumps(case.expected_keywords, ensure_ascii=False), expected_must_not_json=json.dumps(case.expected_must_not, ensure_ascii=False), keywords_min=case.keywords_min if case.keywords_min is not None else (1 if case.keywords_any else None), predicted_answer=answer_text, predicted_sections_json=json.dumps(retrieved_sections, ensure_ascii=False), is_pass=is_pass, fail_reasons_json=json.dumps(reasons, ensure_ascii=False), count_weight=case.count, )) if (passed + failed) % 10 == 0: run.passed = passed run.failed = failed run.cache_hits = cache_hits await session.commit() run.passed = passed run.failed = failed run.cache_hits = cache_hits run.status = "done" run.finished_at = datetime.now(timezone.utc) await session.commit() logger.info( "eval_branch_run %d done: total=%d passed=%d failed=%d cache_hits=%d", run_id, len(cases), passed, failed, cache_hits, ) except Exception as e: logger.exception("eval_branch_run %d failed: %s", run_id, e) try: async with SessionLocal() as session: run = await session.get(EvalBranchRun, run_id) if run is not None: run.status = "error" run.error_text = f"{type(e).__name__}: {e}" run.finished_at = datetime.now(timezone.utc) await session.commit() except Exception: logger.exception("Failed to mark eval_branch_run %d as error", run_id) async def list_runs( session: AsyncSession, intent_code: str | None = None, limit: int = 50 ) -> list[EvalBranchRun]: stmt = select(EvalBranchRun).order_by(EvalBranchRun.id.desc()).limit(limit) if intent_code: stmt = stmt.where(EvalBranchRun.intent_code == intent_code) return list((await session.execute(stmt)).scalars().all()) async def get_run(session: AsyncSession, run_id: int) -> EvalBranchRun | None: return await session.get(EvalBranchRun, run_id) async def list_run_cases(session: AsyncSession, run_id: int) -> list[EvalBranchRunCase]: stmt = ( select(EvalBranchRunCase) .where(EvalBranchRunCase.run_id == run_id) .order_by( EvalBranchRunCase.is_pass, # сначала fail EvalBranchRunCase.count_weight.desc(), EvalBranchRunCase.id, ) ) return list((await session.execute(stmt)).scalars().all())