"""Application configuration assembled from environment variables."""

import os
from dataclasses import dataclass
from typing import Optional

# Spellings that turn a boolean flag off; compared after strip() + lower().
_FALSE_STRINGS = frozenset({"0", "false", ""})


def _as_bool(value: Optional[str], default: bool) -> bool:
    """Interpret an environment-variable string as a boolean flag.

    Args:
        value: Raw value of the variable, or ``None`` when it is not set.
        default: Result to use when ``value`` is ``None``.

    Returns:
        ``default`` for ``None``; ``False`` when the stripped value is empty,
        ``"0"`` or any capitalisation of ``"false"``; ``True`` otherwise.
    """
    if value is None:
        return default
    # Case-insensitive so that "FALSE" / "fAlSe" are not mistaken for True.
    return value.strip().lower() not in _FALSE_STRINGS


def _env_path(key: str, default_path: str) -> str:
    """Return the absolute form of ``$key``, or of ``default_path`` if unset."""
    return os.path.abspath(os.environ.get(key, default_path))


def _env_int(key: str, default: int) -> int:
    """Read ``$key`` as an int, falling back to ``default`` when unset.

    Raises:
        ValueError: If the variable is set but is not a valid integer; the
            message names the offending variable.
    """
    raw = os.environ.get(key)
    if raw is None:
        return int(default)
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {key}={raw!r} is not a valid integer"
        ) from err


def _env_float(key: str, default: float) -> float:
    """Read ``$key`` as a float, falling back to ``default`` when unset.

    Raises:
        ValueError: If the variable is set but is not a valid float; the
            message names the offending variable.
    """
    raw = os.environ.get(key)
    if raw is None:
        return float(default)
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {key}={raw!r} is not a valid float"
        ) from err


@dataclass(frozen=True)
class ModelPaths:
    """Absolute filesystem paths of the three model files."""

    classification: str  # MODEL_CLASSIFICATION, default models/classification.pt
    projection: str  # MODEL_PROJECTION, default models/projection.pt
    detection: str  # MODEL_DETECTION, default models/1_best_yolo8(L).pt


@dataclass(frozen=True)
class AppConfig:
    """Immutable bundle of settings produced by :func:`load_config`.

    Each field is filled from the environment variable whose name is the
    upper-cased field name. How the values are consumed is not visible in
    this module, so the comments below only record the defaults.
    """

    models: ModelPaths
    postprocessing: bool  # default True
    interpolate_missing: bool  # default False
    results_dir: str  # absolute path; default "<module dir>/results"

    # detect_* settings
    detect_eps_px: float  # default 1.0
    detect_min_len: int  # default 2
    detect_zero_gap: int  # default 2
    detect_edge_zero_extend: int  # default 3

    merge_max_gap: int  # default 1

    # min_cobb_* thresholds
    min_cobb_main: float  # default 1.0
    min_cobb_second_abs: float  # default 5.0
    min_cobb_second_rel: float  # default 0.35

    # min_len_struct_* / min_apex_margin_* settings
    min_len_struct_small: int  # default 3
    min_len_struct_large: int  # default 4
    min_apex_margin_small: int  # default 1
    min_apex_margin_large: int  # default 2


def load_config() -> AppConfig:
    """Build an :class:`AppConfig` from the current process environment.

    Unset variables fall back to built-in defaults. Default model and results
    paths are resolved relative to the directory containing this module.

    Returns:
        A frozen ``AppConfig`` instance.

    Raises:
        ValueError: If a numeric variable holds a value that cannot be parsed.
    """
    base_dir = os.path.dirname(__file__)
    models_dir = os.path.join(base_dir, "models")
    default_results = os.path.join(base_dir, "results")

    # NOTE(review): an empty RESULTS_DIR / MODEL_* value resolves to the
    # current working directory via os.path.abspath("") — confirm intended.
    models = ModelPaths(
        classification=_env_path(
            "MODEL_CLASSIFICATION",
            os.path.join(models_dir, "classification.pt"),
        ),
        projection=_env_path(
            "MODEL_PROJECTION",
            os.path.join(models_dir, "projection.pt"),
        ),
        detection=_env_path(
            "MODEL_DETECTION",
            os.path.join(models_dir, "1_best_yolo8(L).pt"),
        ),
    )

    return AppConfig(
        models=models,
        postprocessing=_as_bool(os.environ.get("POSTPROCESSING"), True),
        interpolate_missing=_as_bool(os.environ.get("INTERPOLATE_MISSING"), False),
        results_dir=_env_path("RESULTS_DIR", default_results),
        detect_eps_px=_env_float("DETECT_EPS_PX", 1.0),
        detect_min_len=_env_int("DETECT_MIN_LEN", 2),
        detect_zero_gap=_env_int("DETECT_ZERO_GAP", 2),
        detect_edge_zero_extend=_env_int("DETECT_EDGE_ZERO_EXTEND", 3),
        merge_max_gap=_env_int("MERGE_MAX_GAP", 1),
        min_cobb_main=_env_float("MIN_COBB_MAIN", 1.0),
        min_cobb_second_abs=_env_float("MIN_COBB_SECOND_ABS", 5.0),
        min_cobb_second_rel=_env_float("MIN_COBB_SECOND_REL", 0.35),
        min_len_struct_small=_env_int("MIN_LEN_STRUCT_SMALL", 3),
        min_len_struct_large=_env_int("MIN_LEN_STRUCT_LARGE", 4),
        min_apex_margin_small=_env_int("MIN_APEX_MARGIN_SMALL", 1),
        min_apex_margin_large=_env_int("MIN_APEX_MARGIN_LARGE", 2),
    )