You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

78 lines
2.6 KiB

import os
from dataclasses import dataclass
def _as_bool(value: str, default: bool) -> bool:
if value is None:
return default
return value.strip() not in {"0", "false", "False", ""}
def _env_path(key: str, default_path: str) -> str:
return os.path.abspath(os.environ.get(key, default_path))
@dataclass(eq=True, frozen=True)
class ModelPaths:
    """Immutable group of the three model file locations.

    Instances are frozen (attribute assignment raises) and hashable.
    ``load_config`` populates each field with an absolute path taken from an
    environment variable or a bundled ``.pt`` default.
    """

    # Location of the classification model file.
    classification: str
    # Location of the projection model file.
    projection: str
    # Location of the detection model file.
    detection: str
@dataclass(eq=True, frozen=True)
class AppConfig:
    """Immutable application settings, as assembled by ``load_config``.

    Each field (other than ``models``) mirrors the upper-cased environment
    variable of the same name that ``load_config`` reads. Domain meanings
    noted below are inferred from the names only; the consumers of these
    settings are not visible here.
    """

    # Locations of the model files.
    models: ModelPaths

    # Feature switches.
    postprocessing: bool
    interpolate_missing: bool

    # Output directory (absolute path).
    results_dir: str

    # Detection tuning. NOTE(review): the "_px" suffix suggests pixels —
    # confirm against the detection code.
    detect_eps_px: float
    detect_min_len: int
    detect_zero_gap: int
    detect_edge_zero_extend: int
    merge_max_gap: int

    # Cobb thresholds. NOTE(review): presumably Cobb angles; units and the
    # meaning of the "_rel" ratio are not visible here.
    min_cobb_main: float
    min_cobb_second_abs: float
    min_cobb_second_rel: float

    # Structure-size limits for "small" vs "large" cases (semantics defined
    # by the consumer, not visible here).
    min_len_struct_small: int
    min_len_struct_large: int
    min_apex_margin_small: int
    min_apex_margin_large: int
def _env_int(key: str, default: int) -> int:
    """Read an integer setting from the environment.

    Unset or blank values yield ``default``. Any other value must parse with
    ``int()``; otherwise a ``ValueError`` naming ``key`` is raised.
    """
    raw = os.environ.get(key)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {key} must be an integer, got {raw!r}"
        ) from err


def _env_float(key: str, default: float) -> float:
    """Read a float setting from the environment.

    Unset or blank values yield ``default``. Any other value must parse with
    ``float()``; otherwise a ``ValueError`` naming ``key`` is raised.
    """
    raw = os.environ.get(key)
    if raw is None or not raw.strip():
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {key} must be a number, got {raw!r}"
        ) from err


def load_config() -> AppConfig:
    """Build the application configuration from environment variables.

    Model paths default to files under ``<module dir>/models`` and the results
    directory defaults to ``<module dir>/results``. All paths are made
    absolute by ``_env_path``. Boolean switches are parsed by ``_as_bool``.
    Numeric settings fall back to their defaults when unset or blank.

    Returns:
        A frozen ``AppConfig`` instance.

    Raises:
        ValueError: If a numeric environment variable is set to a value that
            cannot be parsed. The message names the offending variable.
    """
    base_dir = os.path.dirname(__file__)
    models_dir = os.path.join(base_dir, "models")
    default_results = os.path.join(base_dir, "results")

    models = ModelPaths(
        classification=_env_path(
            "MODEL_CLASSIFICATION",
            os.path.join(models_dir, "classification.pt"),
        ),
        projection=_env_path(
            "MODEL_PROJECTION",
            os.path.join(models_dir, "projection.pt"),
        ),
        detection=_env_path(
            "MODEL_DETECTION",
            os.path.join(models_dir, "1_best_yolo8(L).pt"),
        ),
    )

    return AppConfig(
        models=models,
        postprocessing=_as_bool(os.environ.get("POSTPROCESSING"), True),
        interpolate_missing=_as_bool(os.environ.get("INTERPOLATE_MISSING"), False),
        results_dir=_env_path("RESULTS_DIR", default_results),
        detect_eps_px=_env_float("DETECT_EPS_PX", 1.0),
        detect_min_len=_env_int("DETECT_MIN_LEN", 2),
        detect_zero_gap=_env_int("DETECT_ZERO_GAP", 2),
        detect_edge_zero_extend=_env_int("DETECT_EDGE_ZERO_EXTEND", 3),
        merge_max_gap=_env_int("MERGE_MAX_GAP", 1),
        min_cobb_main=_env_float("MIN_COBB_MAIN", 1.0),
        min_cobb_second_abs=_env_float("MIN_COBB_SECOND_ABS", 5.0),
        min_cobb_second_rel=_env_float("MIN_COBB_SECOND_REL", 0.35),
        min_len_struct_small=_env_int("MIN_LEN_STRUCT_SMALL", 3),
        min_len_struct_large=_env_int("MIN_LEN_STRUCT_LARGE", 4),
        min_apex_margin_small=_env_int("MIN_APEX_MARGIN_SMALL", 1),
        min_apex_margin_large=_env_int("MIN_APEX_MARGIN_LARGE", 2),
    )