You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

139 lines
4.5 KiB

import os
import cv2
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from torchvision import transforms
from torchvision.transforms import InterpolationMode
# Side length (px) of the square tensor fed to the stitched/single classifier
# (used by the Resize step of STITCH_TFM).
STITCH_IMG_SIZE: int = 512
# Sigmoid-probability cut-off used by classify_stitched for single-logit
# models; can be overridden through the STITCH_THRESHOLD environment variable.
STITCH_THRESHOLD: float = float(os.getenv("STITCH_THRESHOLD", 0.8))

# --- trim_black_frame tuning ---
# Grayscale level at or below which a sampled pixel is counted as "black".
TRIM_THR: int = 30
# Minimum fraction of black samples for a row/column to be treated as border.
TRIM_BLACK_FRAC: float = 0.97
# Upper bound on how many pixels may be trimmed from any single side.
TRIM_MAX_PX: int = 600
def pad_to_square(img: Image.Image, fill: int = 0) -> Image.Image:
    """Pad *img* with a uniform border so that its width equals its height.

    The original content is centred. When the size difference is odd, the
    extra pixel goes to the right/bottom edge. Padding instead of stretching
    keeps the aspect ratio intact for a later square resize.

    Args:
        img: source image.
        fill: border colour, forwarded to ``ImageOps.expand``.

    Returns:
        *img* itself when it is already square, otherwise a new padded image.
    """
    width, height = img.size
    if width == height:
        return img
    target = max(width, height)
    extra_w, extra_h = target - width, target - height
    left, top = extra_w // 2, extra_h // 2
    # Border tuple order expected by ImageOps.expand: (left, top, right, bottom).
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(img, border=border, fill=fill)
def _count_black_edge(indices, black_frac) -> int:
    """Count how many leading entries of *indices* are border lines.

    A line is border while ``black_frac(index) >= TRIM_BLACK_FRAC``. Counting
    stops at the first non-border line or once TRIM_MAX_PX lines were counted.
    """
    count = 0
    for idx in indices:
        if count >= TRIM_MAX_PX or black_frac(idx) < TRIM_BLACK_FRAC:
            break
        count += 1
    return count


def trim_black_frame(img: Image.Image) -> Image.Image:
    """Crop a dark border off the edges of *img*.

    Each edge is scanned inwards on a grayscale copy. A column/row counts as
    border when at least TRIM_BLACK_FRAC of its sampled pixels are <= TRIM_THR.
    At most TRIM_MAX_PX pixels are removed per side.

    Args:
        img: source image (any mode convertible to "L").

    Returns:
        The cropped image. Returns *img* unchanged in two cases: when it is
        smaller than 2 px in either dimension, or when the crop would keep
        less than 35% of the width or height (the trim is considered too
        aggressive).
    """
    g = img.convert("L")
    w, h = g.size
    # The clamps below (min(..., w - 2)) go negative for sub-2px images, which
    # would make crop() pad the image with black instead of trimming it.
    if w < 2 or h < 2:
        return img
    px = g.load()
    # Subsample each line (stride size // 180, at least 1) to bound the cost.
    ys = range(0, h, max(1, h // 180))
    xs = range(0, w, max(1, w // 180))

    def col_black_frac(x: int) -> float:
        # len(ys) >= 1 because h >= 2 here.
        return sum(px[x, y] <= TRIM_THR for y in ys) / len(ys)

    def row_black_frac(y: int) -> float:
        return sum(px[x, y] <= TRIM_THR for x in xs) / len(xs)

    left = _count_black_edge(range(w), col_black_frac)
    right = _count_black_edge(range(w - 1, -1, -1), col_black_frac)
    top = _count_black_edge(range(h), row_black_frac)
    bottom = _count_black_edge(range(h - 1, -1, -1), row_black_frac)

    # Keep at least 2 px in each dimension, even if the scans overlap.
    x1 = min(left, w - 2)
    y1 = min(top, h - 2)
    x2 = max(x1 + 2, w - right)
    y2 = max(y1 + 2, h - bottom)
    # Refuse crops that would remove more than 65% of either dimension.
    if (x2 - x1) < w * 0.35 or (y2 - y1) < h * 0.35:
        return img
    return img.crop((x1, y1, x2, y2))
def _bgr_array_to_pil(arr):
    """Convert an OpenCV-style BGR array into an RGB PIL image.

    A named function (rather than a lambda) is used so the pipelines below
    stay picklable.
    """
    return Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))


# Standard ImageNet channel mean/std, shared by both pipelines.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Stitched/single classifier input: BGR -> RGB, trim dark border, pad to a
# square (no distortion), resize to STITCH_IMG_SIZE, tensor, normalize.
STITCH_TFM = transforms.Compose([
    transforms.Lambda(_bgr_array_to_pil),
    transforms.Lambda(trim_black_frame),
    transforms.Lambda(pad_to_square),
    transforms.Resize(
        (STITCH_IMG_SIZE, STITCH_IMG_SIZE),
        interpolation=InterpolationMode.BILINEAR,
    ),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Projection classifier input: BGR -> RGB, direct resize to 224x224 (aspect
# ratio is not preserved here), tensor, normalize.
PROJ_TFM = transforms.Compose([
    transforms.Lambda(_bgr_array_to_pil),
    transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
def classify_stitched(model, img_bgr):
    """Classify an image as stitched (several images glued together) or single.

    Args:
        model: torch module; inference runs on the device of its first parameter.
        img_bgr: image array in OpenCV BGR channel order, preprocessed with
            STITCH_TFM.

    Returns:
        ``(pred, conf)``. ``pred`` is 0 (stitched) or 1 (single). ``conf`` is
        the probability assigned to the predicted class.

    A single-logit head is read through a sigmoid with the STITCH_THRESHOLD
    cut-off. A multi-logit head uses softmax + argmax; the threshold is not
    applied there.
    """
    device = next(model.parameters()).device
    x = STITCH_TFM(img_bgr).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)
    # The batch holds exactly one image. A model that squeezes its output can
    # return shape (C,) or (), and indexing shape[1] on that would raise.
    # Normalize such output to (1, C).
    if logits.ndim <= 1:
        logits = logits.reshape(1, -1)
    if logits.shape[1] == 1:
        # Single-logit head: sigmoid gives the probability of class 1 (single).
        prob = torch.sigmoid(logits)[0, 0].item()
        pred = 1 if prob >= STITCH_THRESHOLD else 0
        conf = prob if pred == 1 else 1.0 - prob
        return pred, float(conf)
    probs = F.softmax(logits, dim=1)
    conf, pred = probs.max(dim=1)
    return int(pred.item()), float(conf.item())
def classify_projection(model, img_bgr):
    """Predict the projection of an image: 0 (lateral) or 1 (frontal).

    Args:
        model: torch module; inference runs on the device of its first parameter.
        img_bgr: image array in OpenCV BGR channel order, preprocessed with
            PROJ_TFM.

    Returns:
        ``(pred, conf)``: the predicted class and the probability assigned to it.
    """
    device = next(model.parameters()).device
    x = PROJ_TFM(img_bgr).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)
    # Batch of one: accept squeezed (C,) / () outputs by restoring (1, C).
    if logits.ndim <= 1:
        logits = logits.reshape(1, -1)
    if logits.shape[1] == 1:
        # Softmax over a single logit is always 1.0, which would pin the result
        # to (0, 1.0). Read a single-logit head through a sigmoid instead,
        # mirroring classify_stitched but with a 0.5 cut-off.
        prob = torch.sigmoid(logits)[0, 0].item()
        pred = 1 if prob >= 0.5 else 0
        conf = prob if pred == 1 else 1.0 - prob
        return pred, float(conf)
    probs = F.softmax(logits, dim=1)
    conf, pred = probs.max(dim=1)
    return int(pred.item()), float(conf.item())