You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

137 lines
4.9 KiB

"""Подключение к PostgreSQL и ORM-сессии.
Основная БД — `clinic_tests`.
Опциональная вторая БД — `hr_bot_test` (когда HR_AUTH=1).
"""
from __future__ import annotations

import os
import threading
from typing import Optional
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from sqlalchemy.pool import QueuePool
# Lazily-created process-wide singletons, each declared next to the lock
# that guards its first-time creation.

# Main-database engine (see get_engine).
_engine: Optional[Engine] = None
_engine_lock = threading.Lock()

# Thread-scoped session registry (see get_session).
_session_factory: Optional[scoped_session] = None
_session_lock = threading.Lock()

# Optional HR-database engine (see get_hr_engine).
_hr_engine: Optional[Engine] = None
_hr_engine_lock = threading.Lock()
# ─── URL helpers ─────────────────────────────────────────────────────────────
def get_database_url() -> str:
    """Return the SQLAlchemy URL for the main database.

    ``DATABASE_URL`` takes precedence when it holds a non-blank value
    (surrounding whitespace is stripped). Otherwise a psycopg2 URL is
    assembled from DB_HOST / DB_PORT / DB_NAME / DB_USER / DB_PASSWORD,
    falling back to the defaults below.

    The user name and password are percent-encoded, so characters such as
    ``@``, ``:`` or ``/`` in a password cannot corrupt the URL.
    """
    # Strip before testing: a whitespace-only DATABASE_URL must fall through
    # to the component-based URL instead of yielding an empty string.
    db_url = (os.environ.get('DATABASE_URL') or '').strip()
    if db_url:
        return db_url
    host = os.environ.get('DB_HOST', 'localhost')
    port = os.environ.get('DB_PORT', '5432')
    name = os.environ.get('DB_NAME', 'clinic_tests')
    # NOTE(review): hard-coded fallback credentials are only suitable for
    # local development; set DB_USER / DB_PASSWORD (or DATABASE_URL) elsewhere.
    # safe='' escapes '/', ':' and '@' too, which are URL delimiters here.
    user = quote(os.environ.get('DB_USER', 'hr_bot_user'), safe='')
    password = quote(os.environ.get('DB_PASSWORD', 'hrbot123'), safe='')
    return f'postgresql+psycopg2://{user}:{password}@{host}:{port}/{name}'
def _hr_auth_enabled() -> bool:
    """Report whether the HR_AUTH environment flag is switched on.

    The value is compared case-insensitively after trimming whitespace;
    only '1', 'true', 'yes' and 'on' count as enabled. A missing or empty
    variable means disabled.
    """
    raw_flag = os.environ.get('HR_AUTH') or ''
    normalized = raw_flag.strip().lower()
    return normalized in {'1', 'true', 'yes', 'on'}
def get_hr_database_url() -> Optional[str]:
    """Return the HR database URL, or None when it is not applicable.

    None is returned when HR_AUTH is not enabled, or when HR_DATABASE_URL
    is missing or blank. Otherwise the value is returned with surrounding
    whitespace stripped.
    """
    hr_url: Optional[str] = None
    if _hr_auth_enabled():
        hr_url = (os.environ.get('HR_DATABASE_URL') or '').strip() or None
    return hr_url
# ─── Main engine ─────────────────────────────────────────────────────────────
def get_engine() -> Engine:
    """Return the shared SQLAlchemy engine for the main database.

    The engine is built lazily on the first call from ``get_database_url()``
    and cached in the module-level ``_engine``. Creation happens under
    ``_engine_lock`` and is re-checked there, so concurrent first callers
    end up sharing one engine.
    """
    global _engine
    if _engine is None:
        with _engine_lock:
            # Re-check: another thread may have built it while we waited.
            if _engine is None:
                pool_options = {
                    'poolclass': QueuePool,
                    'pool_size': 5,
                    'max_overflow': 10,
                    # Test pooled connections before handing them out.
                    'pool_pre_ping': True,
                }
                _engine = create_engine(get_database_url(), **pool_options)
    return _engine
# ─── Scoped session ──────────────────────────────────────────────────────────
def get_session() -> Session:
    """Return the ORM session for the current thread (via ``scoped_session``).

    The ``scoped_session`` registry is created lazily on first use and bound
    to the main engine. What is returned is the registry itself, which
    proxies ``Session`` methods to the current thread's session. Call
    ``remove_session()`` to release that session.
    """
    global _session_factory
    if _session_factory is None:
        # Resolve the engine BEFORE taking _session_lock, so _engine_lock
        # (acquired inside get_engine) is never held nested under it.
        engine = get_engine()
        with _session_lock:
            # Re-check: another thread may have created the registry meanwhile.
            if _session_factory is None:
                _session_factory = scoped_session(
                    sessionmaker(bind=engine, autoflush=True, autocommit=False)
                )
    return _session_factory  # type: ignore[return-value]
def remove_session() -> None:
    """Release the current thread's session, if a registry exists.

    Intended to be called from ``teardown_appcontext``. It does nothing
    when ``get_session()`` has never been called.
    """
    registry = _session_factory
    if registry is None:
        return
    registry.remove()
# ─── HR engine (raw SQL only) ────────────────────────────────────────────────
def get_hr_engine() -> Optional[Engine]:
    """Return the cached engine for the optional HR database, or None.

    It returns None whenever HR_AUTH is disabled, even if an engine was
    created earlier. It also returns None when no HR URL is configured.
    Otherwise the engine is created once (under ``_hr_engine_lock``, with a
    re-check) and reused. The engine is meant for raw SQL only.
    """
    global _hr_engine
    if not _hr_auth_enabled():
        return None
    if _hr_engine is None:
        hr_url = get_hr_database_url()
        if not hr_url:
            return None
        with _hr_engine_lock:
            # Re-check: another thread may have built it while we waited.
            if _hr_engine is None:
                pool_options = {
                    'poolclass': QueuePool,
                    'pool_size': 3,
                    'max_overflow': 5,
                    'pool_pre_ping': True,
                }
                _hr_engine = create_engine(hr_url, **pool_options)
    return _hr_engine
# ─── Smoke check ─────────────────────────────────────────────────────────────
def ping() -> dict:
    """Run ``SELECT 1`` against each configured database and report status.

    The result always has a 'main' key. An 'hr' key is added only when
    HR_AUTH is enabled. Each value is 'ok', a 'disabled (...)' note for HR,
    or 'error: <ExceptionType>: <message>'. Exceptions are caught and
    reported, never propagated.
    """
    def _select_one(engine) -> None:
        # Open a pooled connection, run a trivial query, release it.
        with engine.connect() as connection:
            connection.exec_driver_sql('SELECT 1')

    def _describe(exc: Exception) -> str:
        return f'error: {type(exc).__name__}: {exc}'

    report: dict = {'main': 'unknown'}
    try:
        _select_one(get_engine())
    except Exception as exc:
        report['main'] = _describe(exc)
    else:
        report['main'] = 'ok'

    if not _hr_auth_enabled():
        return report

    report['hr'] = 'unknown'
    try:
        hr_engine = get_hr_engine()
        if hr_engine is None:
            report['hr'] = 'disabled (HR_DATABASE_URL not set)'
        else:
            _select_one(hr_engine)
            report['hr'] = 'ok'
    except Exception as exc:
        report['hr'] = _describe(exc)
    return report