import os

import cv2
import numpy as np

# Label order used when the caller does not pass ``spine_sequence``.
_DEFAULT_SPINE_SEQUENCE = (
    "T1", "T2", "T3", "T4", "T5", "T6", "T7", "T8", "T9", "T10", "T11", "T12",
    "L1", "L2", "L3", "L4", "L5",
)

# Per-detection drawing colours, cycled by detection index. The tuples are in
# BGR channel order, as used by the OpenCV drawing calls below.
_DEBUG_COLORS = (
    (0, 0, 255),
    (0, 255, 0),
    (255, 0, 0),
    (0, 255, 255),
    (255, 0, 255),
    (255, 255, 0),
    (0, 165, 255),
    (255, 20, 147),
    (147, 20, 255),
    (0, 215, 255),
    (255, 215, 0),
    (255, 105, 180),
    (0, 255, 127),
    (255, 69, 0),
    (72, 61, 139),
    (47, 255, 173),
    (255, 140, 0),
)

# Colour used for boxes whose confidence is negative (relabelled/interpolated).
_CORRECTED_COLOR = (0, 0, 255)

# Two centres closer than this fraction of the smaller box height are duplicates.
_DUPLICATE_DISTANCE_FACTOR = 0.3

# Vertical spacing (pixels) assumed before any consecutive pair has been seen.
_FALLBACK_Y_GAP = 50.0


def _parse_detections(image, model_results, labelmap):
    """Return the highest-confidence oriented box per label from the first result.

    Each value is a float32 array
    ``[x_center, y_center, width, height, angle, confidence]`` with
    ``width >= height``. A detection that cannot be parsed is reported on
    stdout and skipped.
    """
    best_boxes = {}
    if not model_results:
        return best_boxes

    obb_data = getattr(model_results[0], "obb", None)
    if obb_data is None or len(obb_data) == 0:
        return best_boxes

    for i in range(len(obb_data)):
        try:
            conf = float(obb_data.conf[i].cpu().numpy())
            cls_id = int(obb_data.cls[i].cpu().numpy())
            label = labelmap[cls_id] if cls_id < len(labelmap) else f"Unknown_{cls_id}"

            if hasattr(obb_data, "xyxyxyxy"):
                corners = obb_data.xyxyxyxy[i].cpu().numpy().reshape(4, 2)
            elif hasattr(obb_data, "xyxyxyxyn"):
                # Normalized corners: scale to pixels. The product is a new
                # array, so the buffer backing the model output (which
                # ``.numpy()`` may share with the tensor) is never modified.
                img_h, img_w = image.shape[:2]
                scale = np.array([img_w, img_h], dtype=np.float32)
                corners = obb_data.xyxyxyxyn[i].cpu().numpy().reshape(4, 2) * scale
            else:
                raise AttributeError(
                    "OBB data exposes neither 'xyxyxyxy' nor 'xyxyxyxyn'"
                )

            rect = cv2.minAreaRect(corners.astype(np.float32))
            (x_center, y_center), (width, height), angle = rect
            # Store the long side as the width and rotate the angle to match.
            if width < height:
                angle += 90
                width, height = height, width

            candidate = np.array(
                [x_center, y_center, width, height, angle, conf],
                dtype=np.float32,
            )
            if label not in best_boxes or conf > best_boxes[label][-1]:
                best_boxes[label] = candidate
        except Exception as exc:  # best effort: one bad box must not abort the rest
            print(f"Box parse error: {exc}")
            continue

    return best_boxes


def _remove_duplicate_boxes(best_boxes):
    """Drop boxes whose centre nearly coincides with a more confident box.

    Returns the surviving ``(label, box)`` pairs sorted by ascending y centre.
    """
    ranked = sorted(best_boxes.items(), key=lambda item: item[1][-1], reverse=True)

    kept = []
    for label, box in ranked:
        x_center, y_center, height = box[0], box[1], box[3]
        is_duplicate = False
        for kept_label, kept_box in kept:
            dist = np.sqrt((x_center - kept_box[0]) ** 2 + (y_center - kept_box[1]) ** 2)
            threshold = _DUPLICATE_DISTANCE_FACTOR * min(height, kept_box[3])
            if dist < threshold:
                is_duplicate = True
                print(f"Duplicate removed: {label} overlaps with {kept_label}")
                break
        if not is_duplicate:
            kept.append((label, box))

    kept.sort(key=lambda item: item[1][1])
    return kept


def _enforce_spine_order(boxes_by_y, spine_sequence, label_to_index):
    """Walk boxes in ascending y and keep labels strictly increasing in sequence.

    A box whose label does not advance the sequence is relabelled to the next
    expected label (confidence set to -1.0) when its y position is close to
    where the next box is expected; otherwise it is dropped. Labels missing
    from ``label_to_index`` are dropped.
    """
    valid_boxes = []
    last_index = None
    last_y = None
    y_gaps = []  # spacings observed between consecutively labelled neighbours

    for label, box in boxes_by_y:
        if label not in label_to_index:
            print(f"Unknown vertebra skipped: {label}")
            continue

        y_center = box[1]
        current_index = label_to_index[label]

        if last_index is None or current_index > last_index:
            # Only directly consecutive labels contribute to the spacing estimate.
            if last_index is not None and current_index == last_index + 1:
                y_gaps.append(y_center - last_y)
            valid_boxes.append((label, box))
            last_index = current_index
            last_y = y_center
            continue

        # The label repeats or goes backwards relative to the previous box.
        expected_index = last_index + 1
        avg_gap = np.median(y_gaps) if y_gaps else _FALLBACK_Y_GAP
        expected_y = last_y + avg_gap
        y_diff = abs(y_center - expected_y)
        if y_diff < avg_gap * 0.6 and expected_index < len(spine_sequence):
            new_label = spine_sequence[expected_index]
            new_box = box.copy()
            new_box[5] = -1.0  # negative confidence marks a relabelled box
            valid_boxes.append((new_label, new_box))
            y_gaps.append(y_center - last_y)
            print(f"Order corrected: {label} -> {new_label} (conf=-1)")
            last_index = expected_index
            last_y = y_center
            continue

        print(f"Out-of-order skipped: {label} after {valid_boxes[-1][0]}")

    return valid_boxes


def _interpolate_missing_boxes(valid_boxes, spine_sequence, label_to_index):
    """Insert synthetic boxes (confidence -0.5) for labels skipped between neighbours.

    A gap is filled only when the two neighbours are more than one label apart
    and their vertical distance exceeds 0.7 * 1.8 * the median spacing.
    Returns a new list sorted by ascending y centre.
    """
    valid_boxes = sorted(valid_boxes, key=lambda item: item[1][1])

    gaps = [
        curr[1][1] - prev[1][1]
        for prev, curr in zip(valid_boxes, valid_boxes[1:])
    ]
    avg_gap = np.median(gaps) if gaps else 0
    max_allowed_gap = avg_gap * 1.8 if avg_gap > 0 else float("inf")

    # Angle and aspect-ratio statistics come only from boxes with positive confidence.
    reliable_angles = []
    reliable_ratios = []
    for _, box in valid_boxes:
        _, _, width, height, angle, conf = box
        if conf > 0:
            norm_angle = angle % 180
            if norm_angle > 90:
                norm_angle -= 180
            reliable_angles.append(norm_angle)
            reliable_ratios.append(width / max(height, 1))

    median_angle = np.median(reliable_angles) if reliable_angles else 0.0
    median_ratio = np.median(reliable_ratios) if reliable_ratios else 2.0

    new_boxes = []
    last_position = len(valid_boxes) - 1
    for i, (current_label, current_box) in enumerate(valid_boxes):
        new_boxes.append((current_label, current_box))
        if i >= last_position:
            continue

        next_label, next_box = valid_boxes[i + 1]
        current_idx = label_to_index[current_label]
        index_gap = label_to_index[next_label] - current_idx
        y_gap = next_box[1] - current_box[1]
        if not (index_gap > 1 and y_gap > max_allowed_gap * 0.7):
            continue

        num_missing = index_gap - 1
        print(f"Missing between {current_label} and {next_label}: {num_missing}")

        x1, y1, w1, h1, ang1, _ = current_box
        x2, y2, w2, h2, ang2, _ = next_box

        for k in range(1, num_missing + 1):
            missing_idx = current_idx + k
            if missing_idx >= len(spine_sequence):
                continue
            missing_label = spine_sequence[missing_idx]
            fraction = k / index_gap

            # y is interpolated along the gap; x and the size use the plain
            # average of the two neighbours.
            x_center = (x1 + x2) / 2
            y_center = y1 + fraction * (y2 - y1)
            avg_width = (w1 + w2) / 2
            avg_height = (h1 + h2) / 2

            if median_ratio > 1:
                width = avg_width
                height = width / median_ratio
            else:
                height = avg_height
                width = height * median_ratio

            # Bring both angles onto the same branch before interpolating so
            # that e.g. 179 and 1 degrees blend through 0 rather than 90.
            norm_ang1 = ang1 % 180
            norm_ang2 = ang2 % 180
            if abs(norm_ang1 - norm_ang2) > 90:
                if norm_ang1 > norm_ang2:
                    norm_ang1 -= 180
                else:
                    norm_ang2 -= 180

            angle = norm_ang1 + fraction * (norm_ang2 - norm_ang1)
            angle = (angle + 90) % 180 - 90
            if reliable_angles:
                angle = 0.7 * angle + 0.3 * median_angle

            if height > width:
                width, height = height, width
                angle = (angle + 90) % 180

            conf = -0.5  # negative confidence marks an interpolated box
            interpolated_box = np.array(
                [x_center, y_center, width, height, angle, conf],
                dtype=np.float32,
            )
            print(
                "Interpolated: "
                f"{missing_label} x={x_center:.1f} y={y_center:.1f} "
                f"w={width:.1f} h={height:.1f} angle={angle:.2f}"
            )
            new_boxes.append((missing_label, interpolated_box))

    new_boxes.sort(key=lambda item: item[1][1])
    return new_boxes


def _to_bgr_canvas(image):
    """Return a 3-channel copy of ``image`` suitable for colour drawing."""
    if len(image.shape) == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    channels = image.shape[2]
    if channels == 3:
        return image.copy()
    if channels == 1:
        # (H, W, 1) is still grayscale; the RGB conversion rejects one channel.
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def _draw_box(canvas, label, box, color):
    """Draw one oriented box, its translucent fill, label text and centre mark."""
    x_center, y_center, width, height, angle, conf = (float(v) for v in box)
    is_corrected = conf < 0
    outline_thickness = 4 if is_corrected else 3
    alpha = 0.35 if is_corrected else 0.2
    box_color = _CORRECTED_COLOR if is_corrected else color

    rect = ((x_center, y_center), (width, height), angle)
    corners = cv2.boxPoints(rect).astype(np.int32)

    cv2.drawContours(canvas, [corners], 0, box_color, outline_thickness)
    overlay = canvas.copy()
    cv2.fillPoly(overlay, [corners], box_color)
    cv2.addWeighted(overlay, alpha, canvas, 1 - alpha, 0, canvas)

    if is_corrected:
        text = f"{label} CV({conf:.1f})"
    else:
        text = f"{label} ({conf:.3f})"

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    text_thickness = 2
    (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, text_thickness)
    text_x = int(x_center - text_width / 2)
    text_y = int(y_center + text_height / 2)

    padding = 5
    bg_top_left = (text_x - padding, text_y - text_height - padding)
    bg_bottom_right = (text_x + text_width + padding, text_y + padding)

    text_overlay = canvas.copy()
    bg_color = (0, 0, 100) if is_corrected else (0, 0, 0)
    cv2.rectangle(text_overlay, bg_top_left, bg_bottom_right, bg_color, -1)
    text_alpha = 0.7
    cv2.addWeighted(text_overlay, text_alpha, canvas, 1 - text_alpha, 0, canvas)

    text_color = (255, 255, 255) if not is_corrected else (255, 200, 200)
    cv2.putText(
        canvas,
        text,
        (text_x, text_y),
        font,
        font_scale,
        text_color,
        text_thickness,
    )

    center = (int(x_center), int(y_center))
    center_color = (0, 0, 255) if is_corrected else (255, 255, 255)
    cv2.circle(canvas, center, 5, center_color, -1)
    cv2.circle(canvas, center, 2, (0, 0, 0), -1)


def _save_debug_image(image, best_boxes, debug_save_path):
    """Render ``best_boxes`` over ``image`` and write the result to disk.

    Rendering and saving are best effort: failures are printed and never
    propagate, so a debug problem cannot discard the detections.
    """
    try:
        vis_image = _to_bgr_canvas(image)
    except Exception as exc:
        print(f"Debug save error: {exc}")
        return

    for idx, (label, box) in enumerate(best_boxes.items()):
        try:
            _draw_box(vis_image, label, box, _DEBUG_COLORS[idx % len(_DEBUG_COLORS)])
        except Exception as exc:
            print(f"Draw error for {label}: {exc}")
            continue

    try:
        dir_path = os.path.dirname(debug_save_path)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        # imwrite signals some failures by returning False instead of raising.
        if not cv2.imwrite(debug_save_path, vis_image, [cv2.IMWRITE_JPEG_QUALITY, 95]):
            raise OSError(f"cv2.imwrite could not write {debug_save_path}")
        print(f"Debug saved: {debug_save_path}")
        print(f"Detections: {len(best_boxes)}")
        corrected_count = sum(1 for box in best_boxes.values() if box[5] < 0)
        if corrected_count > 0:
            print(f"Corrected order: {corrected_count}")
    except Exception as exc:
        print(f"Debug save error: {exc}")


def detect_vertebrae(
    image,
    model_results,
    labelmap,
    debug_save_path=None,
    enable_postprocessing=True,
    interpolate_missing=False,
    spine_sequence=None,
):
    """Detect vertebrae from oriented-box model output, with optional postprocessing.

    Only ``model_results[0]`` is read, and only its ``obb`` attribute. For each
    label the highest-confidence box is kept.

    Args:
        image: Array with a ``shape`` attribute; used to scale normalized
            corners and as the background of the debug image.
        model_results: Sequence of results; may be empty or ``None``.
            NOTE(review): the attribute names used (``obb.conf``, ``obb.cls``,
            ``obb.xyxyxyxy``/``xyxyxyxyn``) look like Ultralytics OBB results;
            confirm against the caller.
        labelmap: Maps a class id to a label via ``labelmap[cls_id]``; ids not
            below ``len(labelmap)`` become ``"Unknown_<id>"``.
        debug_save_path: If truthy, an annotated image is written to this path
            (best effort; failures are printed, not raised).
        enable_postprocessing: If true, remove near-coincident boxes and
            enforce label order along increasing y. Labels absent from
            ``spine_sequence`` are dropped by this step.
        interpolate_missing: If true (and postprocessing is enabled), insert
            synthetic boxes for labels skipped between two detected neighbours.
        spine_sequence: Expected label order along increasing y; defaults to
            T1..T12 followed by L1..L5.

    Returns:
        dict mapping label to
        ``[x_center, y_center, width, height, angle, confidence]`` as Python
        floats. Confidence is -1.0 for relabelled boxes and -0.5 for
        interpolated ones.
    """
    if spine_sequence is None:
        spine_sequence = _DEFAULT_SPINE_SEQUENCE
    label_to_index = {label: idx for idx, label in enumerate(spine_sequence)}

    best_boxes = _parse_detections(image, model_results, labelmap)

    if enable_postprocessing and best_boxes:
        valid_boxes = _enforce_spine_order(
            _remove_duplicate_boxes(best_boxes), spine_sequence, label_to_index
        )
        if interpolate_missing and len(valid_boxes) > 1:
            valid_boxes = _interpolate_missing_boxes(
                valid_boxes, spine_sequence, label_to_index
            )
        best_boxes = dict(valid_boxes)

    if debug_save_path:
        _save_debug_image(image, best_boxes, debug_save_path)

    return {label: box.astype(float).tolist() for label, box in best_boxes.items()}