"""Application configuration loaded from environment variables.

Defines the model-path and settings containers and ``load_config``, which
fills them from the process environment, falling back to built-in defaults.
"""
import os
from dataclasses import dataclass
from typing import Optional
|
|
|
|
|
def _as_bool(value: Optional[str], default: bool) -> bool:
    """Interpret an environment-variable string as a boolean.

    Args:
        value: Raw value as returned by ``os.environ.get``; ``None`` means the
            variable is unset.
        default: Returned when *value* is ``None``.

    Returns:
        ``False`` when the whitespace-trimmed, case-insensitive value is one of
        ``"0"``, ``"false"``, ``"no"``, ``"off"`` or the empty string;
        ``True`` for any other value.
    """
    if value is None:
        return default
    # Case-insensitive matching so spellings such as "FALSE" or "Off" disable
    # the flag instead of being treated as truthy.
    return value.strip().lower() not in {"0", "false", "no", "off", ""}
|
|
|
|
|
def _env_path(key: str, default_path: str) -> str:
    """Resolve an absolute filesystem path from environment variable *key*.

    Args:
        key: Name of the environment variable to read.
        default_path: Used when the variable is unset or set to an empty
            string. Without this check, an empty value would silently resolve
            to the current working directory.

    Returns:
        The chosen path with a leading ``~`` expanded to the user's home
        directory, made absolute via ``os.path.abspath``.
    """
    raw = os.environ.get(key) or default_path
    return os.path.abspath(os.path.expanduser(raw))
|
|
|
|
|
@dataclass(frozen=True, eq=True)
class ModelPaths:
    """Immutable set of filesystem paths to the model files.

    ``load_config`` populates every field with an absolute path (the built-in
    defaults point at ``.pt`` files under a ``models`` directory). Field order
    defines the positional constructor signature, so keep it stable.
    """

    # Model file used for classification.
    classification: str
    # Model file used for projection.
    projection: str
    # Model file used for detection.
    detection: str
|
|
|
|
|
@dataclass(frozen=True, eq=True)
class AppConfig:
    """Immutable, fully resolved application settings.

    ``load_config`` builds instances from environment variables. The field
    order below is the positional constructor signature and must stay stable.
    """

    # Paths to the model files.
    models: ModelPaths

    # Boolean feature switches.
    postprocessing: bool
    interpolate_missing: bool

    # Output directory (absolute when produced by load_config).
    results_dir: str

    # Detection / merging tuning parameters.
    # NOTE(review): the "_px" suffix suggests a pixel tolerance; confirm units.
    detect_eps_px: float
    detect_min_len: int
    detect_zero_gap: int
    detect_edge_zero_extend: int
    merge_max_gap: int

    # Cobb-angle thresholds ("_abs" absolute, "_rel" relative).
    # NOTE(review): units are not visible here, presumably degrees; confirm.
    min_cobb_main: float
    min_cobb_second_abs: float
    min_cobb_second_rel: float

    # Minimum length / apex-margin limits, each with a "small" and a "large"
    # variant (the criterion choosing between them is not visible here).
    min_len_struct_small: int
    min_len_struct_large: int
    min_apex_margin_small: int
    min_apex_margin_large: int
|
|
|
|
|
def _env_int(key: str, default: int) -> int:
    """Return environment variable *key* parsed as ``int``, or *default* if unset.

    Raises:
        ValueError: If the variable is set but is not a valid integer literal.
            The message names the offending variable and its value.
    """
    raw = os.environ.get(key)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {key}={raw!r} is not a valid integer"
        ) from err


def _env_float(key: str, default: float) -> float:
    """Return environment variable *key* parsed as ``float``, or *default* if unset.

    Raises:
        ValueError: If the variable is set but is not a valid number. The
            message names the offending variable and its value.
    """
    raw = os.environ.get(key)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {key}={raw!r} is not a valid number"
        ) from err


def load_config() -> AppConfig:
    """Build the application configuration from environment variables.

    Model paths default to files in a ``models`` directory next to this module
    and the results directory defaults to ``results`` next to this module.
    Each can be overridden via ``MODEL_CLASSIFICATION``, ``MODEL_PROJECTION``,
    ``MODEL_DETECTION`` and ``RESULTS_DIR``. Feature switches come from
    ``POSTPROCESSING`` (default on) and ``INTERPOLATE_MISSING`` (default off).
    Numeric tuning values come from the correspondingly named upper-case
    variables, with the defaults listed below.

    Returns:
        A frozen ``AppConfig``.

    Raises:
        ValueError: If a numeric variable is set to an unparsable value; the
            message identifies which variable.
    """
    base_dir = os.path.dirname(__file__)
    models_dir = os.path.join(base_dir, "models")

    models = ModelPaths(
        classification=_env_path(
            "MODEL_CLASSIFICATION",
            os.path.join(models_dir, "classification.pt"),
        ),
        projection=_env_path(
            "MODEL_PROJECTION",
            os.path.join(models_dir, "projection.pt"),
        ),
        detection=_env_path(
            "MODEL_DETECTION",
            os.path.join(models_dir, "1_best_yolo8(L).pt"),
        ),
    )

    return AppConfig(
        models=models,
        postprocessing=_as_bool(os.environ.get("POSTPROCESSING"), True),
        interpolate_missing=_as_bool(os.environ.get("INTERPOLATE_MISSING"), False),
        results_dir=_env_path("RESULTS_DIR", os.path.join(base_dir, "results")),
        detect_eps_px=_env_float("DETECT_EPS_PX", 1.0),
        detect_min_len=_env_int("DETECT_MIN_LEN", 2),
        detect_zero_gap=_env_int("DETECT_ZERO_GAP", 2),
        detect_edge_zero_extend=_env_int("DETECT_EDGE_ZERO_EXTEND", 3),
        merge_max_gap=_env_int("MERGE_MAX_GAP", 1),
        min_cobb_main=_env_float("MIN_COBB_MAIN", 1.0),
        min_cobb_second_abs=_env_float("MIN_COBB_SECOND_ABS", 5.0),
        min_cobb_second_rel=_env_float("MIN_COBB_SECOND_REL", 0.35),
        min_len_struct_small=_env_int("MIN_LEN_STRUCT_SMALL", 3),
        min_len_struct_large=_env_int("MIN_LEN_STRUCT_LARGE", 4),
        min_apex_margin_small=_env_int("MIN_APEX_MARGIN_SMALL", 1),
        min_apex_margin_large=_env_int("MIN_APEX_MARGIN_LARGE", 2),
    )
|
|
|