"""Lazily loaded, process-wide cache of the detection/classification models."""

import os
import threading
from typing import Any, Callable, Optional

import torch
import torch.nn as nn
from torchvision import models
from ultralytics import YOLO

from config import load_config

# Guards every read/write of _CACHE; it is also held while a model is being
# loaded, so two callers asking for the same key never load it twice.
_LOCK = threading.Lock()
# Maps the string built by _cache_key() to an already-loaded model object.
_CACHE = {}
_CFG = load_config()


def _cache_key(kind: str, path: str, device: Optional[str]) -> str:
    """Build the cache key from model kind, absolute path and device.

    A falsy ``device`` is recorded as ``'auto'``.
    """
    return f"{kind}|{os.path.abspath(path)}|{device or 'auto'}"


def _load_shufflenet(path: str, device: Optional[str]):
    """Load a ShuffleNetV2 classifier from a state_dict file.

    The number of classes is read from the first dimension of ``fc.weight``.
    The x0.5 variant is tried first, then x1.0; the first one whose parameters
    match the state_dict exactly (``strict=True``) is used.

    Args:
        path: Path to the file holding the state_dict.
        device: Target device; when falsy the weights are mapped to CPU and
            the model is not moved.

    Returns:
        The model in eval mode, moved to ``device`` when one is given.

    Raises:
        RuntimeError: If the state_dict fits none of the supported variants.
            The error from the last attempted variant is attached as the
            cause.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(path, map_location=device or "cpu")
    num_classes = state["fc.weight"].shape[0]

    # Constructors rather than instances: the second network is only built
    # when the first one does not fit.
    variants = (models.shufflenet_v2_x0_5, models.shufflenet_v2_x1_0)

    model = None
    last_error = None
    for build in variants:
        candidate = build(weights=None)
        candidate.fc = nn.Linear(candidate.fc.in_features, num_classes)
        try:
            candidate.load_state_dict(state, strict=True)
        except RuntimeError as err:
            # Strict loading reports key/shape mismatches as RuntimeError;
            # remember it and try the next variant.
            last_error = err
            continue
        model = candidate
        break

    if model is None:
        raise RuntimeError(
            f"Cannot load state_dict from {path} into supported shufflenet variants"
        ) from last_error

    model.eval()
    if device:
        model.to(device)
    return model


def _load_mobilenet_v3_large(path: str, device: Optional[str]):
    """Load a MobileNetV3-Large classifier from a checkpoint file.

    The checkpoint may be a bare state_dict or a mapping that stores the
    state_dict under the ``"model_state"`` key. The number of classes is read
    from the first dimension of ``classifier.3.weight``.

    Args:
        path: Path to the checkpoint file.
        device: Target device; when falsy the weights are mapped to CPU and
            the model is not moved.

    Returns:
        The model in eval mode, moved to ``device`` when one is given.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device or "cpu")
    state = checkpoint.get("model_state", checkpoint)
    num_classes = state["classifier.3.weight"].shape[0]

    model = models.mobilenet_v3_large(weights=None)
    head = model.classifier[3]
    model.classifier[3] = nn.Linear(head.in_features, num_classes)
    model.load_state_dict(state, strict=True)

    model.eval()
    if device:
        model.to(device)
    return model


def _load_yolo(path: str, device: Optional[str]):
    """Create a YOLO detector from ``path`` and try to move it to ``device``."""
    model = YOLO(path)
    if device:
        try:
            model.to(device)
        except Exception:
            # Deliberate best-effort: if the wrapper's own .to() fails, fall
            # back to moving the wrapped ``model`` attribute when it exists.
            if hasattr(model, "model"):
                model.model.to(device)
    return model


def _get_or_load(
    kind: str,
    path: str,
    device: Optional[str],
    loader: Callable[[str, Optional[str]], Any],
):
    """Return the cached model for (kind, path, device), loading it on a miss."""
    key = _cache_key(kind, path, device)
    with _LOCK:
        if key not in _CACHE:
            _CACHE[key] = loader(path, device)
        return _CACHE[key]


def get_detection_model(device: Optional[str] = None, model_path: Optional[str] = None):
    """Return the YOLO detector, loading and caching it on first use.

    Args:
        device: Device to move the model to; falsy leaves it where YOLO put it.
        model_path: Weights path; defaults to ``_CFG.models.detection``.
    """
    path = model_path or _CFG.models.detection
    return _get_or_load("detection", path, device, _load_yolo)


def get_classification_model(device: Optional[str] = None, model_path: Optional[str] = None):
    """Return the stitched/single classifier (MobileNetV3-Large), cached after first load.

    Args:
        device: Device to load the model onto; falsy means CPU.
        model_path: Checkpoint path; defaults to ``_CFG.models.classification``.
    """
    path = model_path or _CFG.models.classification
    return _get_or_load("classification", path, device, _load_mobilenet_v3_large)


def get_projection_model(device: Optional[str] = None, model_path: Optional[str] = None):
    """Return the projection classifier (ShuffleNetV2), cached after first load.

    Args:
        device: Device to load the model onto; falsy means CPU.
        model_path: State_dict path; defaults to ``_CFG.models.projection``.
    """
    path = model_path or _CFG.models.projection
    return _get_or_load("projection", path, device, _load_shufflenet)