You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
139 lines
4.5 KiB
139 lines
4.5 KiB
import os |
|
|
|
import cv2 |
|
import torch |
|
import torch.nn.functional as F |
|
from PIL import Image, ImageOps |
|
from torchvision import transforms |
|
from torchvision.transforms import InterpolationMode |
|
|
|
# Side length (pixels) that images are resized to before the stitched/single classifier.
STITCH_IMG_SIZE = 512
# Decision threshold applied to a single-logit stitched/single classifier;
# can be overridden through the STITCH_THRESHOLD environment variable.
STITCH_THRESHOLD = float(os.getenv("STITCH_THRESHOLD", 0.8))

# Black-frame trimming parameters (compared against 8-bit grayscale pixels).
TRIM_THR = 30  # a pixel at or below this value counts as black
TRIM_BLACK_FRAC = 0.97  # minimum share of black samples for a row/column to be trimmed
TRIM_MAX_PX = 600  # at most this many rows/columns are trimmed from any one side
|
|
|
def pad_to_square(img: Image.Image, fill: int = 0) -> Image.Image:
    """Pad *img* with a uniform border so that it becomes square.

    Padding is split as evenly as possible between opposite sides (any odd
    extra pixel goes to the right/bottom), keeping the content centred so a
    later resize to a square target does not distort it.

    Args:
        img: Source PIL image.
        fill: Border colour, forwarded to ``ImageOps.expand``.

    Returns:
        ``img`` itself when it is already square, otherwise a new padded image.
    """
    width, height = img.size
    if width == height:
        return img

    target = max(width, height)
    extra_x, extra_y = target - width, target - height
    left, top = extra_x // 2, extra_y // 2
    border = (left, top, extra_x - left, extra_y - top)
    return ImageOps.expand(img, border=border, fill=fill)
|
|
|
def _count_black_edge(black_frac, indices) -> int:
    """Count consecutive near-black lines starting from one image edge.

    Args:
        black_frac: Callable mapping a row/column index to the fraction of its
            sampled pixels that are near-black.
        indices: Indices to scan, ordered from the edge inwards.

    Returns:
        Number of leading indices whose black fraction is at least
        ``TRIM_BLACK_FRAC``, capped at ``TRIM_MAX_PX``.
    """
    count = 0
    for i in indices:
        if black_frac(i) < TRIM_BLACK_FRAC:
            break
        count += 1
        if count >= TRIM_MAX_PX:
            break
    return count


def trim_black_frame(img: Image.Image) -> Image.Image:
    """Crop a near-black frame off the edges of *img*.

    The image is converted to grayscale. A column (row) is treated as frame
    when at least ``TRIM_BLACK_FRAC`` of its sampled pixels are ``<= TRIM_THR``.
    Only every ``max(1, size // 180)``-th pixel along the line is sampled. At
    most ``TRIM_MAX_PX`` lines are removed from each side.

    Args:
        img: Source PIL image.

    Returns:
        The cropped image, or ``img`` unchanged when it is smaller than 2x2
        pixels or when the crop would keep less than 35% of its width or
        height (the trim is considered too aggressive).
    """
    gray = img.convert("L")
    w, h = gray.size
    # A crop box needs at least 2 px per axis; for smaller images the clamping
    # below would produce a negative left/top coordinate and PIL would pad the
    # result with black instead of leaving the image alone.
    if w < 2 or h < 2:
        return img

    px = gray.load()
    # Subsample each line to bound the cost of a per-line check.
    row_samples = range(0, h, max(1, h // 180))
    col_samples = range(0, w, max(1, w // 180))

    def col_black_frac(x: int) -> float:
        return sum(px[x, y] <= TRIM_THR for y in row_samples) / len(row_samples)

    def row_black_frac(y: int) -> float:
        return sum(px[x, y] <= TRIM_THR for x in col_samples) / len(col_samples)

    left = _count_black_edge(col_black_frac, range(w))
    right = _count_black_edge(col_black_frac, range(w - 1, -1, -1))
    top = _count_black_edge(row_black_frac, range(h))
    bottom = _count_black_edge(row_black_frac, range(h - 1, -1, -1))

    # Clamp so the box is always at least 2 px wide/high (x2 >= x1 + 2 by construction).
    x1 = min(left, w - 2)
    y1 = min(top, h - 2)
    x2 = max(x1 + 2, w - right)
    y2 = max(y1 + 2, h - bottom)

    # Keeping under 35% of a dimension is treated as a false detection.
    if (x2 - x1) < w * 0.35 or (y2 - y1) < h * 0.35:
        return img

    return img.crop((x1, y1, x2, y2))
|
|
|
# Per-channel normalization shared by both pipelines (the commonly used ImageNet statistics).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Side length (pixels) that images are resized to before the projection classifier.
PROJ_IMG_SIZE = 224


def _bgr_to_pil(arr) -> Image.Image:
    """Convert an OpenCV image array to an RGB PIL image.

    3-channel input is treated as BGR. Single-channel input (2-D or
    ``(H, W, 1)``) is expanded to RGB, and 4-channel input is treated as BGRA.
    A named module-level function (unlike a lambda) keeps the transforms
    picklable.

    NOTE(review): assumes a NumPy array as returned by ``cv2.imread``; confirm
    no caller passes ``cv2.UMat``.
    """
    if arr.ndim == 2 or arr.shape[2] == 1:
        code = cv2.COLOR_GRAY2RGB
    elif arr.shape[2] == 4:
        code = cv2.COLOR_BGRA2RGB
    else:
        code = cv2.COLOR_BGR2RGB
    return Image.fromarray(cv2.cvtColor(arr, code))


# Stitched/single pipeline: trim a black frame, pad to square (preserving the
# aspect ratio), then resize and normalize.
STITCH_TFM = transforms.Compose([
    transforms.Lambda(_bgr_to_pil),
    transforms.Lambda(trim_black_frame),
    transforms.Lambda(pad_to_square),
    transforms.Resize((STITCH_IMG_SIZE, STITCH_IMG_SIZE), interpolation=InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Projection pipeline: resizes straight to a square (no trimming or padding,
# so the aspect ratio is not preserved), then normalizes.
PROJ_TFM = transforms.Compose([
    transforms.Lambda(_bgr_to_pil),
    transforms.Resize((PROJ_IMG_SIZE, PROJ_IMG_SIZE), interpolation=InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
|
|
|
def classify_stitched(model, img_bgr):
    """Classify an image as stitched (class 0) or single (class 1).

    Args:
        model: torch module run on a batch containing one preprocessed image.
        img_bgr: OpenCV (BGR) image array; preprocessed with ``STITCH_TFM``.

    Returns:
        ``(pred, conf)``: the predicted class index and the probability
        assigned to that class.

        * Single-logit output: ``sigmoid(logit)`` is taken as the probability of
          class 1, and class 1 is predicted when it is ``>= STITCH_THRESHOLD``.
        * Multi-logit output: softmax argmax over the classes.
    """
    # Parameterless models have no device to inherit; fall back to CPU.
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    x = STITCH_TFM(img_bgr).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(x)

    # A head that squeezes its output yields a 0-D/1-D tensor for a batch of
    # one; restore the (1, num_logits) layout the code below indexes into.
    if logits.ndim < 2:
        logits = logits.reshape(1, -1)

    if logits.shape[1] == 1:
        prob = torch.sigmoid(logits)[0, 0].item()
        pred = 1 if prob >= STITCH_THRESHOLD else 0
        conf = prob if pred == 1 else (1 - prob)
        return pred, float(conf)

    # NOTE(review): STITCH_THRESHOLD is not applied to multi-logit heads; confirm intended.
    probs = F.softmax(logits, dim=1)
    conf, pred = probs.max(dim=1)
    return int(pred.item()), float(conf.item())
|
|
|
def classify_projection(model, img_bgr):
    """Classify the view projection: 0 (lateral/side) or 1 (frontal).

    Args:
        model: torch module run on a batch containing one preprocessed image.
        img_bgr: OpenCV (BGR) image array; preprocessed with ``PROJ_TFM``.

    Returns:
        ``(pred, conf)``: the softmax argmax class index and its probability.
    """
    # Parameterless models have no device to inherit; fall back to CPU.
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    x = PROJ_TFM(img_bgr).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(x)

    # A head that squeezes its output yields a 0-D/1-D tensor for a batch of
    # one; restore the (1, num_classes) layout softmax(dim=1) expects.
    if logits.ndim < 2:
        logits = logits.reshape(1, -1)

    probs = F.softmax(logits, dim=1)
    conf, pred = probs.max(dim=1)
    return int(pred.item()), float(conf.item())
|
|
|
|