You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
374 lines
14 KiB
374 lines
14 KiB
import os |
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
|
|
# Anatomical label order used when the caller does not supply ``spine_sequence``.
_DEFAULT_SPINE_SEQUENCE = (
    "T1", "T2", "T3", "T4", "T5", "T6", "T7", "T8", "T9", "T10", "T11", "T12",
    "L1", "L2", "L3", "L4", "L5",
)

# Per-detection colours cycled in the debug image. Tuples are BGR (OpenCV's
# channel order); the comments name the colour as it is actually rendered.
_DEBUG_COLORS = (
    (0, 0, 255),  # Red
    (0, 255, 0),  # Green
    (255, 0, 0),  # Blue
    (0, 255, 255),  # Yellow
    (255, 0, 255),  # Magenta
    (255, 255, 0),  # Cyan
    (0, 165, 255),  # Orange
    (255, 20, 147),  # Violet
    (147, 20, 255),  # Deep pink
    (0, 215, 255),  # Gold
    (255, 215, 0),  # Sky blue
    (255, 105, 180),  # Light purple
    (0, 255, 127),  # Chartreuse
    (255, 69, 0),  # Bright blue
    (72, 61, 139),  # Dark rose
    (47, 255, 173),  # Green-yellow
    (255, 140, 0),  # Dodger blue
)


def _wrap_angle(angle):
    """Wrap an angle in degrees into [-90, 90); box orientation is modulo 180."""
    return (angle + 90) % 180 - 90


def _obb_corners(obb_data, i, image):
    """Return the corner coordinates of OBB detection ``i`` in pixels.

    Uses absolute ``xyxyxyxy`` corners when available, otherwise scales the
    normalised ``xyxyxyxyn`` corners by the image width/height.
    """
    if hasattr(obb_data, "xyxyxyxy"):
        return obb_data.xyxyxyxy[i].cpu().numpy()
    if hasattr(obb_data, "xyxyxyxyn"):
        h, w = image.shape[:2]
        normalized = obb_data.xyxyxyxyn[i].cpu().numpy().reshape(4, 2)
        # Scale out-of-place: ``.numpy()`` may share memory with the model's
        # tensor, which must not be modified.
        return normalized * np.array([w, h], dtype=normalized.dtype)
    raise AttributeError("OBB results expose neither 'xyxyxyxy' nor 'xyxyxyxyn'")


def _parse_obb_results(image, model_results, labelmap):
    """Collect the highest-confidence oriented box for each label.

    Only ``model_results[0]`` is read. Each box is a float32 array
    ``[x_center, y_center, width, height, angle, conf]`` taken from
    ``cv2.minAreaRect`` on the detection's corners, with width/height swapped
    (and the angle turned by 90 degrees) so that ``width >= height``.
    Detections that fail to parse are reported and skipped.
    """
    best_boxes = {}
    if not model_results:
        return best_boxes
    obb_data = getattr(model_results[0], "obb", None)
    if obb_data is None or len(obb_data) == 0:
        return best_boxes

    for i in range(len(obb_data)):
        try:
            conf = float(obb_data.conf[i].cpu().numpy())
            cls_id = int(obb_data.cls[i].cpu().numpy())
            label = labelmap[cls_id] if cls_id < len(labelmap) else f"Unknown_{cls_id}"

            corners = _obb_corners(obb_data, i, image).reshape(4, 2).astype(np.float32)
            (x_center, y_center), (width, height), angle = cv2.minAreaRect(corners)
            # Normalise so the long side is always reported as the width.
            if width < height:
                angle += 90
                width, height = height, width

            box = np.array([x_center, y_center, width, height, angle, conf], dtype=np.float32)
        except Exception as exc:  # best effort: one malformed detection must not drop the rest
            print(f"Box parse error: {exc}")
            continue

        if label not in best_boxes or conf > best_boxes[label][-1]:
            best_boxes[label] = box
    return best_boxes


def _suppress_duplicates(best_boxes):
    """Drop boxes centred too close to an already-kept, higher-confidence box.

    Boxes are visited in descending confidence. A box is a duplicate when its
    centre lies within ``0.3 * min(height, kept_height)`` of a kept box.
    Returns the kept ``(label, box)`` pairs sorted by ascending y-centre.
    """
    by_confidence = sorted(best_boxes.items(), key=lambda item: item[1][-1], reverse=True)

    kept = []
    for label, box in by_confidence:
        x_center, y_center, _, height, _, _ = box
        for kept_label, kept_box in kept:
            distance = np.sqrt((x_center - kept_box[0]) ** 2 + (y_center - kept_box[1]) ** 2)
            if distance < 0.3 * min(height, kept_box[3]):
                print(f"Duplicate removed: {label} overlaps with {kept_label}")
                break
        else:
            kept.append((label, box))

    kept.sort(key=lambda item: item[1][1])
    return kept


def _enforce_anatomical_order(kept_boxes, spine_sequence, label_to_index):
    """Keep boxes whose labels advance along ``spine_sequence`` as y increases.

    ``kept_boxes`` must be sorted by ascending y-centre. Labels absent from
    ``spine_sequence`` are skipped. A box whose label does not come after the
    last accepted one is relabelled as the next expected vertebra (confidence
    set to -1.0) when its y-centre is within 60% of the median spacing of
    where that vertebra is expected. Otherwise it is skipped.
    Returns the accepted ``(label, box)`` pairs in input order.
    """
    accepted = []
    last_index = -1
    last_y = -1
    y_gaps = []  # spacing between consecutively-labelled accepted boxes

    for label, box in kept_boxes:
        if label not in label_to_index:
            print(f"Unknown vertebra skipped: {label}")
            continue

        y_center = box[1]
        current_index = label_to_index[label]

        if current_index > last_index:
            # Only direct successors define the typical spacing; jumps over
            # missing vertebrae would inflate it.
            if last_index != -1 and current_index == last_index + 1:
                y_gaps.append(y_center - last_y)
            accepted.append((label, box))
            last_index, last_y = current_index, y_center
            continue

        expected_index = last_index + 1
        # Fallback spacing (image units) until a real gap has been measured.
        avg_gap = np.median(y_gaps) if y_gaps else 50.0
        y_diff = abs(y_center - (last_y + avg_gap))

        if y_diff < avg_gap * 0.6 and expected_index < len(spine_sequence):
            new_label = spine_sequence[expected_index]
            new_box = box.copy()
            new_box[5] = -1.0  # negative confidence marks an order correction
            accepted.append((new_label, new_box))
            y_gaps.append(y_center - last_y)
            print(f"Order corrected: {label} -> {new_label} (conf=-1)")
            last_index, last_y = expected_index, y_center
        else:
            print(f"Out-of-order skipped: {label} after {accepted[-1][0]}")

    return accepted


def _interpolate_missing_vertebrae(valid_boxes, spine_sequence, label_to_index):
    """Insert synthetic boxes for vertebrae skipped between two detections.

    A gap is filled when the neighbours' labels skip at least one entry of
    ``spine_sequence`` and their vertical distance exceeds 1.26x (0.7 * 1.8)
    the median spacing between consecutive boxes. Synthetic boxes are placed
    by linear interpolation between the neighbours and sized from the
    neighbours' mean size reshaped to the median aspect ratio. Their angle is
    pulled 30% toward the median angle, and they carry confidence -0.5.
    Returns all boxes sorted by ascending y-centre.
    """
    ordered = sorted(valid_boxes, key=lambda item: item[1][1])
    if len(ordered) < 2:
        return ordered

    gaps = [nxt[1][1] - cur[1][1] for cur, nxt in zip(ordered, ordered[1:])]
    avg_gap = np.median(gaps) if gaps else 0
    max_allowed_gap = avg_gap * 1.8 if avg_gap > 0 else float("inf")

    # Orientation/shape statistics come from model-scored boxes only (conf > 0);
    # order-corrected (-1.0) and interpolated (-0.5) boxes are excluded.
    reliable_angles = []
    reliable_ratios = []
    for _, box in ordered:
        _, _, width, height, angle, conf = box
        if conf > 0:
            norm_angle = angle % 180
            if norm_angle > 90:
                norm_angle -= 180  # fold into (-90, 90]
            reliable_angles.append(norm_angle)
            reliable_ratios.append(width / max(height, 1))

    median_angle = np.median(reliable_angles) if reliable_angles else 0.0
    median_ratio = np.median(reliable_ratios) if reliable_ratios else 2.0

    result = []
    for (cur_label, cur_box), (next_label, next_box) in zip(ordered, ordered[1:]):
        result.append((cur_label, cur_box))

        cur_idx = label_to_index[cur_label]
        index_gap = label_to_index[next_label] - cur_idx
        y_gap = next_box[1] - cur_box[1]
        if not (index_gap > 1 and y_gap > max_allowed_gap * 0.7):
            continue

        num_missing = index_gap - 1
        print(f"Missing between {cur_label} and {next_label}: {num_missing}")

        x1, y1, w1, h1, ang1, _ = cur_box
        x2, y2, w2, h2, ang2, _ = next_box

        # Size: neighbours' mean, reshaped to the median aspect ratio.
        if median_ratio > 1:
            base_width = (w1 + w2) / 2
            base_height = base_width / median_ratio
        else:
            base_height = (h1 + h2) / 2
            base_width = base_height * median_ratio

        # Unwrap the neighbours' angles (mod 180) so interpolation takes the
        # short way round.
        norm_ang1 = ang1 % 180
        norm_ang2 = ang2 % 180
        if abs(norm_ang1 - norm_ang2) > 90:
            if norm_ang1 > norm_ang2:
                norm_ang1 -= 180
            else:
                norm_ang2 -= 180

        for k in range(1, index_gap):
            missing_label = spine_sequence[cur_idx + k]
            fraction = k / index_gap

            # Interpolate both coordinates so boxes follow a tilted/curved spine.
            x_center = x1 + fraction * (x2 - x1)
            y_center = y1 + fraction * (y2 - y1)

            angle = _wrap_angle(norm_ang1 + fraction * (norm_ang2 - norm_ang1))
            if reliable_angles:
                # Pull 30% toward the median along the shortest arc (mod 180).
                angle = _wrap_angle(angle + 0.3 * _wrap_angle(median_angle - angle))

            width, height = base_width, base_height
            if height > width:
                width, height = height, width
                angle = (angle + 90) % 180

            interpolated_box = np.array(
                [x_center, y_center, width, height, angle, -0.5],
                dtype=np.float32,
            )
            print(
                "Interpolated: "
                f"{missing_label} x={x_center:.1f} y={y_center:.1f} "
                f"w={width:.1f} h={height:.1f} angle={angle:.2f}"
            )
            result.append((missing_label, interpolated_box))

    result.append(ordered[-1])
    result.sort(key=lambda item: item[1][1])
    return result


def _to_bgr(image):
    """Return a 3-channel copy of ``image`` suitable for drawing and saving."""
    if len(image.shape) == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 3:
        return image.copy()  # assumed to already be BGR
    # Any other channel count is treated as RGBA.
    return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)


def _draw_detection(vis_image, label, box, color):
    """Draw one box with its outline, fill, label and centre onto ``vis_image`` in place.

    Boxes with negative confidence (order-corrected or interpolated) are drawn
    in red, with a thicker outline, a stronger fill and a tinted label.
    """
    x_center, y_center, width, height, angle, conf = box
    is_corrected = conf < 0
    accent = (0, 0, 255) if is_corrected else color

    # OpenCV polygon APIs expect integer points; astype truncates toward zero.
    corners = cv2.boxPoints(((x_center, y_center), (width, height), angle)).astype(np.int32)
    cv2.drawContours(vis_image, [corners], 0, accent, 4 if is_corrected else 3)

    fill_alpha = 0.35 if is_corrected else 0.2
    overlay = vis_image.copy()
    cv2.fillPoly(overlay, [corners], accent)
    cv2.addWeighted(overlay, fill_alpha, vis_image, 1 - fill_alpha, 0, vis_image)

    text = f"{label} CV({conf:.1f})" if is_corrected else f"{label} ({conf:.3f})"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    text_thickness = 2
    (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, text_thickness)
    text_x = int(x_center - text_width / 2)
    text_y = int(y_center + text_height / 2)
    padding = 5

    # Semi-transparent background so the label stays legible over the image.
    text_alpha = 0.7
    text_overlay = vis_image.copy()
    cv2.rectangle(
        text_overlay,
        (text_x - padding, text_y - text_height - padding),
        (text_x + text_width + padding, text_y + padding),
        (0, 0, 100) if is_corrected else (0, 0, 0),
        -1,
    )
    cv2.addWeighted(text_overlay, text_alpha, vis_image, 1 - text_alpha, 0, vis_image)

    cv2.putText(
        vis_image,
        text,
        (text_x, text_y),
        font,
        font_scale,
        (255, 200, 200) if is_corrected else (255, 255, 255),
        text_thickness,
    )

    center = (int(x_center), int(y_center))
    cv2.circle(vis_image, center, 5, (0, 0, 255) if is_corrected else (255, 255, 255), -1)
    cv2.circle(vis_image, center, 2, (0, 0, 0), -1)


def _save_debug_visualization(image, best_boxes, debug_save_path):
    """Render ``best_boxes`` over a copy of ``image`` and write it to ``debug_save_path``.

    Per-box drawing errors and directory/write failures are reported rather
    than raised.
    """
    vis_image = _to_bgr(image)

    for idx, (label, box) in enumerate(best_boxes.items()):
        try:
            _draw_detection(vis_image, label, box, _DEBUG_COLORS[idx % len(_DEBUG_COLORS)])
        except Exception as exc:  # one bad box must not abort the whole debug image
            print(f"Draw error for {label}: {exc}")

    try:
        dir_path = os.path.dirname(debug_save_path)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        saved = cv2.imwrite(debug_save_path, vis_image, [cv2.IMWRITE_JPEG_QUALITY, 95])
    except Exception as exc:
        print(f"Debug save error: {exc}")
        return

    # cv2.imwrite may report failure (e.g. unwritable path) by returning False.
    if not saved:
        print(f"Debug save error: cv2.imwrite could not write {debug_save_path}")
        return

    print(f"Debug saved: {debug_save_path}")
    print(f"Detections: {len(best_boxes)}")
    corrected_count = sum(1 for box in best_boxes.values() if box[5] < 0)
    if corrected_count > 0:
        print(f"Corrected order: {corrected_count}")


def detect_vertebrae(
    image,
    model_results,
    labelmap,
    debug_save_path=None,
    enable_postprocessing=True,
    interpolate_missing=False,
    spine_sequence=None,
):
    """
    Detect vertebrae from oriented-box model output, with optional postprocessing.

    Args:
        image: Image array the model ran on. Read for its size when only
            normalised corners are available, and for the debug render.
        model_results: Sequence of model results. Only the first is used, via
            its ``obb`` attribute (``conf``, ``cls`` and ``xyxyxyxy`` or
            ``xyxyxyxyn`` tensors). Presumably Ultralytics OBB ``Results``;
            confirm against the caller.
        labelmap: Indexable mapping from class id to label name. Ids that are
            not below ``len(labelmap)`` become ``"Unknown_<id>"``.
        debug_save_path: If given, a visualisation is written to this path.
        enable_postprocessing: Suppress near-duplicate boxes and enforce the
            top-to-bottom order of ``spine_sequence``.
        interpolate_missing: With postprocessing, synthesise boxes for
            vertebrae skipped between two detections.
        spine_sequence: Label order from top to bottom; defaults to T1-T12, L1-L5.

    Returns:
        dict[label] = [x_center, y_center, width, height, angle, confidence],
        with ``width >= height`` and ``angle`` in degrees. Confidence is -1.0
        for order-corrected labels and -0.5 for interpolated vertebrae. With
        postprocessing enabled, labels not in ``spine_sequence`` are dropped.
    """
    if spine_sequence is None:
        spine_sequence = _DEFAULT_SPINE_SEQUENCE
    label_to_index = {label: idx for idx, label in enumerate(spine_sequence)}

    best_boxes = _parse_obb_results(image, model_results, labelmap)

    if enable_postprocessing and best_boxes:
        kept_boxes = _suppress_duplicates(best_boxes)
        valid_boxes = _enforce_anatomical_order(kept_boxes, spine_sequence, label_to_index)
        if interpolate_missing and len(valid_boxes) > 1:
            valid_boxes = _interpolate_missing_vertebrae(valid_boxes, spine_sequence, label_to_index)
        best_boxes = dict(valid_boxes)

    if debug_save_path:
        _save_debug_visualization(image, best_boxes, debug_save_path)

    return {label: box.astype(float).tolist() for label, box in best_boxes.items()}
|
|
|