You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

88 lines
3.4 KiB

import os
import threading
from typing import Optional
import torch
import torch.nn as nn
from torchvision import models
from ultralytics import YOLO
from config import load_config
_LOCK = threading.Lock()
_CACHE = {}
_CFG = load_config()
def _cache_key(kind: str, path: str, device: Optional[str]) -> str:
return f"{kind}|{os.path.abspath(path)}|{device or 'auto'}"
def _load_shufflenet(path: str, device: Optional[str]):
"""Загружает shufflenet‑классификатор из state_dict"""
sd = torch.load(path, map_location=device or "cpu")
num_classes = sd["fc.weight"].shape[0]
candidates = [
models.shufflenet_v2_x0_5(weights=None),
models.shufflenet_v2_x1_0(weights=None),
]
model = None
for m in candidates:
m.fc = nn.Linear(m.fc.in_features, num_classes)
try:
m.load_state_dict(sd, strict=True)
model = m
break
except Exception:
continue
if model is None:
raise RuntimeError(f"Cannot load state_dict from {path} into supported shufflenet variants")
model.eval()
if device:
model.to(device)
return model
def _load_mobilenet_v3_large(path: str, device: Optional[str]):
"""Загружает mobilenet_v3_large из чекпойнта, настраивает число классов, переносит на device, ставит eval()"""
ckpt = torch.load(path, map_location=device or "cpu")
sd = ckpt.get("model_state", ckpt)
num_classes = sd["classifier.3.weight"].shape[0]
model = models.mobilenet_v3_large(weights=None)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
model.load_state_dict(sd, strict=True)
model.eval()
if device:
model.to(device)
return model
def get_detection_model(device: Optional[str] = None, model_path: Optional[str] = None):
"""Получает YOLO‑детектор: берёт путь (из аргумента или конфига), достаёт из кеша или загружает и кэширует"""
path = model_path or _CFG.models.detection
key = _cache_key("detection", path, device)
with _LOCK:
if key not in _CACHE:
model = YOLO(path)
if device:
try:
model.to(device)
except Exception:
if hasattr(model, "model"):
model.model.to(device)
_CACHE[key] = model
return _CACHE[key]
def get_classification_model(device: Optional[str] = None, model_path: Optional[str] = None):
"""Получает классификатор stitched/single (mobilenet): кеш + загрузка при первом вызове"""
path = model_path or _CFG.models.classification
key = _cache_key("classification", path, device)
with _LOCK:
if key not in _CACHE:
_CACHE[key] = _load_mobilenet_v3_large(path, device)
return _CACHE[key]
def get_projection_model(device: Optional[str] = None, model_path: Optional[str] = None):
"""Получает классификатор проекции (shufflenet): кеш + загрузка при первом вызове"""
path = model_path or _CFG.models.projection
key = _cache_key("projection", path, device)
with _LOCK:
if key not in _CACHE:
_CACHE[key] = _load_shufflenet(path, device)
return _CACHE[key]