You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
88 lines
3.4 KiB
88 lines
3.4 KiB
import os |
|
import threading |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torchvision import models |
|
from ultralytics import YOLO |
|
|
|
from config import load_config |
|
|
|
_LOCK = threading.Lock() |
|
_CACHE = {} |
|
_CFG = load_config() |
|
|
|
def _cache_key(kind: str, path: str, device: Optional[str]) -> str: |
|
return f"{kind}|{os.path.abspath(path)}|{device or 'auto'}" |
|
|
|
def _load_shufflenet(path: str, device: Optional[str]): |
|
"""Загружает shufflenet‑классификатор из state_dict""" |
|
sd = torch.load(path, map_location=device or "cpu") |
|
num_classes = sd["fc.weight"].shape[0] |
|
candidates = [ |
|
models.shufflenet_v2_x0_5(weights=None), |
|
models.shufflenet_v2_x1_0(weights=None), |
|
] |
|
model = None |
|
for m in candidates: |
|
m.fc = nn.Linear(m.fc.in_features, num_classes) |
|
try: |
|
m.load_state_dict(sd, strict=True) |
|
model = m |
|
break |
|
except Exception: |
|
continue |
|
if model is None: |
|
raise RuntimeError(f"Cannot load state_dict from {path} into supported shufflenet variants") |
|
model.eval() |
|
if device: |
|
model.to(device) |
|
return model |
|
|
|
def _load_mobilenet_v3_large(path: str, device: Optional[str]): |
|
"""Загружает mobilenet_v3_large из чекпойнта, настраивает число классов, переносит на device, ставит eval()""" |
|
ckpt = torch.load(path, map_location=device or "cpu") |
|
sd = ckpt.get("model_state", ckpt) |
|
num_classes = sd["classifier.3.weight"].shape[0] |
|
model = models.mobilenet_v3_large(weights=None) |
|
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes) |
|
model.load_state_dict(sd, strict=True) |
|
model.eval() |
|
if device: |
|
model.to(device) |
|
return model |
|
|
|
def get_detection_model(device: Optional[str] = None, model_path: Optional[str] = None): |
|
"""Получает YOLO‑детектор: берёт путь (из аргумента или конфига), достаёт из кеша или загружает и кэширует""" |
|
path = model_path or _CFG.models.detection |
|
key = _cache_key("detection", path, device) |
|
with _LOCK: |
|
if key not in _CACHE: |
|
model = YOLO(path) |
|
if device: |
|
try: |
|
model.to(device) |
|
except Exception: |
|
if hasattr(model, "model"): |
|
model.model.to(device) |
|
_CACHE[key] = model |
|
return _CACHE[key] |
|
|
|
def get_classification_model(device: Optional[str] = None, model_path: Optional[str] = None): |
|
"""Получает классификатор stitched/single (mobilenet): кеш + загрузка при первом вызове""" |
|
path = model_path or _CFG.models.classification |
|
key = _cache_key("classification", path, device) |
|
with _LOCK: |
|
if key not in _CACHE: |
|
_CACHE[key] = _load_mobilenet_v3_large(path, device) |
|
return _CACHE[key] |
|
|
|
def get_projection_model(device: Optional[str] = None, model_path: Optional[str] = None): |
|
"""Получает классификатор проекции (shufflenet): кеш + загрузка при первом вызове""" |
|
path = model_path or _CFG.models.projection |
|
key = _cache_key("projection", path, device) |
|
with _LOCK: |
|
if key not in _CACHE: |
|
_CACHE[key] = _load_shufflenet(path, device) |
|
return _CACHE[key]
|
|
|