"""State machine треда: текущая ветка, шаг внутри ветки, собранные слоты. Используется в chat_service для ведения многошаговых сценариев (Спринт 5). Слоты — произвольный JSON-словарь, конкретные ключи определяются веткой. """ import json import logging from datetime import datetime, timezone from sqlalchemy.ext.asyncio import AsyncSession from db.models import ThreadState logger = logging.getLogger(__name__) async def get_state(session: AsyncSession, thread_id: int) -> ThreadState | None: return await session.get(ThreadState, thread_id) def _parse_slots(raw: str) -> dict: if not raw: return {} try: value = json.loads(raw) except json.JSONDecodeError: logger.warning("Bad slots_json for thread_state, resetting to {}") return {} return value if isinstance(value, dict) else {} async def load_snapshot(session: AsyncSession, thread_id: int) -> dict: """Снимок состояния диалога: текущая ветка/шаг/слоты + handoff_count + suspended_*.""" state = await get_state(session, thread_id) if state is None: return { "current_intent_code": None, "current_step": 0, "current_step_code": None, "slots": {}, "handoff_count": 0, "soft_insertion_count": 0, "suspended_intent": None, "resumable_step_code": None, "resumable_slots": {}, } resumable_slots = {} if state.resumable_slots_json: try: value = json.loads(state.resumable_slots_json) if isinstance(value, dict): resumable_slots = value except json.JSONDecodeError: logger.warning("Bad resumable_slots_json for thread_state, ignoring") return { "current_intent_code": state.current_intent_code, "current_step": state.current_step, "current_step_code": state.current_step_code, "slots": _parse_slots(state.slots_json), "handoff_count": state.handoff_count, "soft_insertion_count": state.soft_insertion_count, "suspended_intent": state.suspended_intent, "resumable_step_code": state.resumable_step_code, "resumable_slots": resumable_slots, } async def upsert( session: AsyncSession, thread_id: int, *, intent_code: str | None, step: int, slots: dict, step_code: str | None = None, handoff_count: int = 0, soft_insertion_count: int = 0, suspended_intent: str | None = None, resumable_step_code: str | None = None, resumable_slots: dict | None = None, ) -> ThreadState: """Создать или обновить состояние треда. Коммит — на совести вызывающего.""" state = await get_state(session, thread_id) now = datetime.now(timezone.utc) slots_raw = json.dumps(slots or {}, ensure_ascii=False) resumable_raw = ( json.dumps(resumable_slots, ensure_ascii=False) if resumable_slots is not None and len(resumable_slots) > 0 else None ) if state is None: state = ThreadState( thread_id=thread_id, current_intent_code=intent_code, current_step=step, current_step_code=step_code, slots_json=slots_raw, handoff_count=handoff_count, soft_insertion_count=soft_insertion_count, suspended_intent=suspended_intent, resumable_step_code=resumable_step_code, resumable_slots_json=resumable_raw, updated_at=now, ) session.add(state) else: state.current_intent_code = intent_code state.current_step = step state.current_step_code = step_code state.slots_json = slots_raw state.handoff_count = handoff_count state.soft_insertion_count = soft_insertion_count state.suspended_intent = suspended_intent state.resumable_step_code = resumable_step_code state.resumable_slots_json = resumable_raw state.updated_at = now return state async def reset( session: AsyncSession, thread_id: int, *, new_intent_code: str | None, new_step_code: str | None = None, ) -> ThreadState: """Сбросить шаг и слоты треда, выставить новую ветку (при смене intent).""" return await upsert( session, thread_id, intent_code=new_intent_code, step=0, step_code=new_step_code, slots={}, handoff_count=0, )