import os import cv2 import torch import torch.nn.functional as F from PIL import Image, ImageOps from torchvision import transforms from torchvision.transforms import InterpolationMode STITCH_IMG_SIZE = 512 STITCH_THRESHOLD = float(os.environ.get("STITCH_THRESHOLD", 0.8)) TRIM_THR = 30 TRIM_BLACK_FRAC = 0.97 TRIM_MAX_PX = 600 def pad_to_square(img: Image.Image, fill: int = 0) -> Image.Image: """Дополненяет изображения полями до квадратного формата, чтобы не искажать при масштабировании.""" w, h = img.size if w == h: return img side = max(w, h) pl = (side - w) // 2 pr = side - w - pl pt = (side - h) // 2 pb = side - h - pt return ImageOps.expand(img, border=(pl, pt, pr, pb), fill=fill) def trim_black_frame(img: Image.Image) -> Image.Image: """Обрезает чёрную рамку по краям; если обрезка слишком агрессивная — возвращает оригинал.""" g = img.convert("L") w, h = g.size px = g.load() step_y = max(1, h // 180) step_x = max(1, w // 180) def col_black_frac(x: int) -> float: total = 0 black = 0 for y in range(0, h, step_y): total += 1 if px[x, y] <= TRIM_THR: black += 1 return black / max(1, total) def row_black_frac(y: int) -> float: total = 0 black = 0 for x in range(0, w, step_x): total += 1 if px[x, y] <= TRIM_THR: black += 1 return black / max(1, total) left = 0 for x in range(w): if col_black_frac(x) < TRIM_BLACK_FRAC: break left += 1 if left >= TRIM_MAX_PX: break right = 0 for x in range(w - 1, -1, -1): if col_black_frac(x) < TRIM_BLACK_FRAC: break right += 1 if right >= TRIM_MAX_PX: break top = 0 for y in range(h): if row_black_frac(y) < TRIM_BLACK_FRAC: break top += 1 if top >= TRIM_MAX_PX: break bottom = 0 for y in range(h - 1, -1, -1): if row_black_frac(y) < TRIM_BLACK_FRAC: break bottom += 1 if bottom >= TRIM_MAX_PX: break x1 = min(left, w - 2) y1 = min(top, h - 2) x2 = max(x1 + 2, w - right) y2 = max(y1 + 2, h - bottom) if x2 <= x1 + 1 or y2 <= y1 + 1: return img if (x2 - x1) < w * 0.35 or (y2 - y1) < h * 0.35: return img return img.crop((x1, y1, x2, y2)) STITCH_TFM = transforms.Compose([ transforms.Lambda(lambda arr: Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))), transforms.Lambda(trim_black_frame), transforms.Lambda(pad_to_square), transforms.Resize((STITCH_IMG_SIZE, STITCH_IMG_SIZE), interpolation=InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) PROJ_TFM = transforms.Compose([ transforms.Lambda(lambda arr: Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))), transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) def classify_stitched(model, img_bgr): """Модель распознает склеенное/одиночное изображение. Возвращает класс 0(склеенный)/1(одиночный) и уверенность""" device = next(model.parameters()).device x = STITCH_TFM(img_bgr).unsqueeze(0).to(device) with torch.no_grad(): logits = model(x) if logits.shape[1] == 1 or (logits.ndim == 2 and logits.shape[-1] == 1): prob = torch.sigmoid(logits)[0, 0].item() pred = 1 if prob >= STITCH_THRESHOLD else 0 conf = prob if pred == 1 else (1 - prob) return pred, float(conf) probs = F.softmax(logits, dim=1) conf, pred = probs.max(dim=1) return int(pred.item()), float(conf.item()) def classify_projection(model, img_bgr): """Определяет проекцию 0(боковая)/1(фронтальная), возвращает класс и уверенность.""" device = next(model.parameters()).device x = PROJ_TFM(img_bgr).unsqueeze(0).to(device) with torch.no_grad(): logits = model(x) probs = F.softmax(logits, dim=1) conf, pred = probs.max(dim=1) return int(pred.item()), float(conf.item())