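# NOTE: the lines below are residue from the web repository viewer this file
# was copied from; they are not part of the program.
#   "You can not select more than 25 topics. Topics must start with a letter
#    or number, can include dashes ('-') and can be up to 35 characters long."
#   374 lines
#   14 KiB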

import os
import cv2
import numpy as np
# Label order the postprocessing expects for increasing y_center:
# T1..T12 followed by L1..L5.
_DEFAULT_SPINE_SEQUENCE = (
    "T1", "T2", "T3", "T4", "T5", "T6", "T7", "T8", "T9", "T10", "T11", "T12",
    "L1", "L2", "L3", "L4", "L5",
)

# Palette cycled per detection when rendering the debug image. The tuples are
# passed straight to OpenCV drawing calls on an image converted to BGR.
_DEBUG_COLORS = (
    (0, 0, 255),
    (0, 255, 0),
    (255, 0, 0),
    (0, 255, 255),
    (255, 0, 255),
    (255, 255, 0),
    (0, 165, 255),
    (255, 20, 147),
    (147, 20, 255),
    (0, 215, 255),
    (255, 215, 0),
    (255, 105, 180),
    (0, 255, 127),
    (255, 69, 0),
    (72, 61, 139),
    (47, 255, 173),
    (255, 140, 0),
)


def _parse_obb_detections(image, model_results, labelmap):
    """Return the highest-confidence oriented box per label from the first result."""
    best_boxes = {}
    if not model_results:
        return best_boxes

    result = model_results[0]
    if not (hasattr(result, "obb") and result.obb is not None and len(result.obb) > 0):
        return best_boxes

    obb_data = result.obb
    for i in range(len(obb_data)):
        # Any failure while reading one detection only drops that detection.
        try:
            conf = float(obb_data.conf[i].cpu().numpy())
            cls_id = int(obb_data.cls[i].cpu().numpy())
            label = labelmap[cls_id] if cls_id < len(labelmap) else f"Unknown_{cls_id}"

            if hasattr(obb_data, "xyxyxyxy"):
                box_points = obb_data.xyxyxyxy[i].cpu().numpy()
            elif hasattr(obb_data, "xyxyxyxyn"):
                # Normalised corners: scale back to pixel coordinates.
                box_points = obb_data.xyxyxyxyn[i].cpu().numpy()
                h, w = image.shape[:2]
                box_points = box_points.reshape(4, 2)
                box_points[:, 0] *= w
                box_points[:, 1] *= h
                box_points = box_points.flatten()
            else:
                # NOTE(review): reached only when `xyxyxyxy` looked absent, so
                # this access presumably raises and is reported by the handler
                # below. Kept as in the original.
                box_points = result.obb.xyxyxyxy[i].cpu().numpy()

            points = box_points.reshape(4, 2).astype(np.float32)
            (x_center, y_center), (width, height), angle = cv2.minAreaRect(points)

            # Normalise so that width is always the longer side.
            if width < height:
                angle += 90
                width, height = height, width

            result_box = np.array(
                [x_center, y_center, width, height, angle, conf],
                dtype=np.float32,
            )
            if label not in best_boxes or conf > best_boxes[label][-1]:
                best_boxes[label] = result_box
        except Exception as exc:
            print(f"Box parse error: {exc}")
            continue

    return best_boxes


def _remove_duplicate_boxes(best_boxes):
    """Drop boxes whose centre nearly coincides with a more confident one; sort by y."""
    by_confidence = sorted(best_boxes.items(), key=lambda item: item[1][-1], reverse=True)

    kept_boxes = []
    for label, box in by_confidence:
        x_center, y_center, _, height, _, _ = box
        for kept_label, kept_box in kept_boxes:
            xk, yk, _, hk, _, _ = kept_box
            dist = np.sqrt((x_center - xk) ** 2 + (y_center - yk) ** 2)
            threshold = 0.3 * min(height, hk)
            if dist < threshold:
                print(f"Duplicate removed: {label} overlaps with {kept_label}")
                break
        else:
            kept_boxes.append((label, box))

    kept_boxes.sort(key=lambda item: item[1][1])
    return kept_boxes


def _enforce_spine_order(kept_boxes, spine_sequence, label_to_index):
    """Walk boxes in y order, keeping, relabelling or skipping to keep labels increasing."""
    valid_boxes = []
    last_index = -1  # -1 means "no box accepted yet"
    last_y = -1
    y_gaps = []  # y distances observed between accepted neighbouring levels

    for label, box in kept_boxes:
        y_center = box[1]
        if label not in label_to_index:
            print(f"Unknown vertebra skipped: {label}")
            continue

        current_index = label_to_index[label]
        if last_index == -1:
            valid_boxes.append((label, box))
            last_index = current_index
            last_y = y_center
            continue

        expected_index = last_index + 1
        if current_index == expected_index:
            valid_boxes.append((label, box))
            y_gaps.append(y_center - last_y)
            last_index = current_index
            last_y = y_center
            continue

        if current_index > expected_index:
            # Labels jumped ahead: accept, but do not record the gap because
            # it spans more than one level.
            valid_boxes.append((label, box))
            last_index = current_index
            last_y = y_center
            continue

        # The label repeats or goes backwards. If the box sits roughly where
        # the next level is expected, relabel it; otherwise drop it.
        avg_gap = np.median(y_gaps) if y_gaps else 50.0
        expected_y = last_y + avg_gap
        y_diff = abs(y_center - expected_y)
        if y_diff < avg_gap * 0.6 and expected_index < len(spine_sequence):
            new_label = spine_sequence[expected_index]
            new_box = box.copy()
            new_box[5] = -1.0  # confidence -1 marks a relabelled detection
            valid_boxes.append((new_label, new_box))
            y_gaps.append(y_center - last_y)
            print(f"Order corrected: {label} -> {new_label} (conf=-1)")
            last_index = expected_index
            last_y = y_center
            continue

        print(f"Out-of-order skipped: {label} after {valid_boxes[-1][0]}")

    return valid_boxes


def _interpolate_missing_vertebrae(valid_boxes, spine_sequence, label_to_index):
    """Insert synthetic boxes (confidence -0.5) for labels skipped between neighbours."""
    ordered = sorted(valid_boxes, key=lambda item: item[1][1])

    # NOTE(review): gaps are raw y differences between neighbours and are not
    # divided by the label distance, so a missing level can raise the median
    # itself. Confirm the thresholds below are intended.
    gaps = [ordered[i][1][1] - ordered[i - 1][1][1] for i in range(1, len(ordered))]
    avg_gap = np.median(gaps) if gaps else 0
    max_allowed_gap = avg_gap * 1.8 if avg_gap > 0 else float("inf")

    # Shape statistics taken only from boxes with positive confidence.
    reliable_angles = []
    reliable_ratios = []
    for _, box in ordered:
        _, _, width, height, angle, conf = box
        if conf > 0:
            norm_angle = angle % 180
            if norm_angle > 90:
                norm_angle -= 180
            reliable_angles.append(norm_angle)
            reliable_ratios.append(width / max(height, 1))
    median_angle = np.median(reliable_angles) if reliable_angles else 0.0
    median_ratio = np.median(reliable_ratios) if reliable_ratios else 2.0

    new_boxes = []
    last_pos = len(ordered) - 1
    for i, (current_label, current_box) in enumerate(ordered):
        current_idx = label_to_index[current_label]
        new_boxes.append((current_label, current_box))
        if i == last_pos:
            continue

        next_label, next_box = ordered[i + 1]
        next_idx = label_to_index[next_label]
        y_gap = next_box[1] - current_box[1]
        index_gap = next_idx - current_idx
        if not (index_gap > 1 and y_gap > max_allowed_gap * 0.7):
            continue

        num_missing = index_gap - 1
        print(f"Missing between {current_label} and {next_label}: {num_missing}")
        x1, y1, w1, h1, ang1, _ = current_box
        x2, y2, w2, h2, ang2, _ = next_box

        # Values shared by every box inserted between this pair.
        x_center = (x1 + x2) / 2
        avg_width = (w1 + w2) / 2
        avg_height = (h1 + h2) / 2
        if median_ratio > 1:
            base_width = avg_width
            base_height = base_width / median_ratio
        else:
            base_height = avg_height
            base_width = base_height * median_ratio

        # Bring both angles onto the same branch so the blend takes the
        # short way round.
        norm_ang1 = ang1 % 180
        norm_ang2 = ang2 % 180
        if abs(norm_ang1 - norm_ang2) > 90:
            if norm_ang1 > norm_ang2:
                norm_ang1 -= 180
            else:
                norm_ang2 -= 180

        for k in range(1, num_missing + 1):
            missing_idx = current_idx + k
            if missing_idx >= len(spine_sequence):
                continue
            missing_label = spine_sequence[missing_idx]
            fraction = k / index_gap
            y_center = y1 + fraction * (y2 - y1)
            width, height = base_width, base_height

            angle = norm_ang1 + fraction * (norm_ang2 - norm_ang1)
            angle = (angle + 90) % 180 - 90
            if reliable_angles:
                angle = 0.7 * angle + 0.3 * median_angle
            if height > width:
                width, height = height, width
                angle = (angle + 90) % 180

            conf = -0.5
            interpolated_box = np.array(
                [x_center, y_center, width, height, angle, conf],
                dtype=np.float32,
            )
            print(
                "Interpolated: "
                f"{missing_label} x={x_center:.1f} y={y_center:.1f} "
                f"w={width:.1f} h={height:.1f} angle={angle:.2f}"
            )
            new_boxes.append((missing_label, interpolated_box))

    new_boxes.sort(key=lambda item: item[1][1])
    return new_boxes


def _save_debug_image(image, best_boxes, debug_save_path):
    """Render the boxes onto a copy of the image and write it to debug_save_path."""
    # Single-channel input, whether (H, W) or (H, W, 1), goes through GRAY2BGR;
    # previously the (H, W, 1) case fell into the RGB2BGR branch.
    if len(image.shape) == 2 or image.shape[2] == 1:
        vis_image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 3:
        vis_image = image.copy()
    else:
        vis_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    text_thickness = 2
    padding = 5
    text_alpha = 0.7

    for idx, (label, box) in enumerate(best_boxes.items()):
        # A failure while drawing one box must not prevent drawing the rest.
        try:
            x_center, y_center, width, height, angle, conf = box
            color = _DEBUG_COLORS[idx % len(_DEBUG_COLORS)]
            # Negative confidence marks relabelled (-1) or interpolated (-0.5) boxes.
            is_corrected = conf < 0
            outline_thickness = 4 if is_corrected else 3
            alpha = 0.35 if is_corrected else 0.2
            box_color = (0, 0, 255) if is_corrected else color

            rect = ((x_center, y_center), (width, height), angle)
            # Explicit int32 corner points instead of relying on an implicit
            # cast from int64 inside the OpenCV bindings.
            box_points = cv2.boxPoints(rect).astype(np.int32)

            cv2.drawContours(vis_image, [box_points], 0, box_color, outline_thickness)
            overlay = vis_image.copy()
            cv2.fillPoly(overlay, [box_points], box_color)
            cv2.addWeighted(overlay, alpha, vis_image, 1 - alpha, 0, vis_image)

            if is_corrected:
                text = f"{label} CV({conf:.1f})"
            else:
                text = f"{label} ({conf:.3f})"
            (text_width, text_height), _ = cv2.getTextSize(
                text, font, font_scale, text_thickness
            )
            text_x = int(x_center - text_width / 2)
            text_y = int(y_center + text_height / 2)
            top_left = (text_x - padding, text_y - text_height - padding)
            bottom_right = (text_x + text_width + padding, text_y + padding)

            text_overlay = vis_image.copy()
            bg_color = (0, 0, 100) if is_corrected else (0, 0, 0)
            cv2.rectangle(text_overlay, top_left, bottom_right, bg_color, -1)
            cv2.addWeighted(text_overlay, text_alpha, vis_image, 1 - text_alpha, 0, vis_image)

            text_color = (255, 255, 255) if not is_corrected else (255, 200, 200)
            cv2.putText(
                vis_image,
                text,
                (text_x, text_y),
                font,
                font_scale,
                text_color,
                text_thickness,
            )

            center_color = (0, 0, 255) if is_corrected else (255, 255, 255)
            cv2.circle(vis_image, (int(x_center), int(y_center)), 5, center_color, -1)
            cv2.circle(vis_image, (int(x_center), int(y_center)), 2, (0, 0, 0), -1)
        except Exception as exc:
            print(f"Draw error for {label}: {exc}")
            continue

    dir_path = os.path.dirname(debug_save_path)
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)

    try:
        saved = cv2.imwrite(debug_save_path, vis_image, [cv2.IMWRITE_JPEG_QUALITY, 95])
        if not saved:
            # imwrite reports some failures through its return value instead
            # of raising; do not claim success in that case.
            print(f"Debug save error: could not write {debug_save_path}")
            return
        print(f"Debug saved: {debug_save_path}")
        print(f"Detections: {len(best_boxes)}")
        corrected_count = sum(1 for box in best_boxes.values() if box[5] < 0)
        if corrected_count > 0:
            print(f"Corrected order: {corrected_count}")
    except Exception as exc:
        print(f"Debug save error: {exc}")


def detect_vertebrae(
    image,
    model_results,
    labelmap,
    debug_save_path=None,
    enable_postprocessing=True,
    interpolate_missing=False,
    spine_sequence=None,
):
    """
    Detect vertebrae with optional postprocessing.

    Args:
        image: Image array. It is read only to scale normalised box corners
            and to render the debug image, so it is untouched when neither
            applies.
        model_results: Sequence of detection results; only the first element
            is used. It is expected to expose an ``obb`` attribute with
            ``conf``, ``cls`` and ``xyxyxyxy`` (or ``xyxyxyxyn``) tensors.
            This looks like the Ultralytics OBB result layout (TODO confirm).
            A falsy value yields no detections.
        labelmap: Maps a class id to its label via ``labelmap[cls_id]``; ids
            not below ``len(labelmap)`` become ``"Unknown_<id>"``.
        debug_save_path: If truthy, an annotated image is written to this
            path and its parent directory is created when needed.
        enable_postprocessing: If true, remove near-coincident duplicates and
            enforce the label order of ``spine_sequence`` along increasing y.
        interpolate_missing: If true (and postprocessing is enabled), insert
            synthetic boxes for labels skipped between two accepted boxes.
        spine_sequence: Ordered labels; defaults to T1..T12 then L1..L5.

    Returns:
        dict[label] = [x_center, y_center, width, height, angle, confidence]
        Confidence is -1.0 for a relabelled box and -0.5 for an interpolated
        one. An empty dict is returned when nothing was detected.
    """
    if spine_sequence is None:
        spine_sequence = _DEFAULT_SPINE_SEQUENCE
    label_to_index = {label: idx for idx, label in enumerate(spine_sequence)}

    best_boxes = _parse_obb_detections(image, model_results, labelmap)

    if enable_postprocessing and best_boxes:
        kept_boxes = _remove_duplicate_boxes(best_boxes)
        if kept_boxes:
            valid_boxes = _enforce_spine_order(kept_boxes, spine_sequence, label_to_index)
            if interpolate_missing and len(valid_boxes) > 1:
                valid_boxes = _interpolate_missing_vertebrae(
                    valid_boxes, spine_sequence, label_to_index
                )
            best_boxes = {label: box for label, box in valid_boxes}

    if debug_save_path:
        _save_debug_image(image, best_boxes, debug_save_path)

    return {label: box.astype(float).tolist() for label, box in best_boxes.items()}