| 
							
								 | 
							
							# medical_rag.py | 
						
						
						
						
							 | 
							
								 | 
							
							import os | 
						
						
						
						
							 | 
							
								 | 
							
							import json | 
						
						
						
						
							 | 
							
								 | 
							
							import tiktoken | 
						
						
						
						
							 | 
							
								 | 
							
							from typing import List, Tuple | 
						
						
						
						
							 | 
							
								 | 
							
							import chromadb | 
						
						
						
						
							 | 
							
								 | 
							
							from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction | 
						
						
						
						
							 | 
							
								 | 
							
							from llama_cpp import Llama | 
						
						
						
						
							 | 
							
								 | 
							
							import torch  # Добавляем импорт torch | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							class MedicalRAG: | 
						
						
						
						
							 | 
							
								 | 
							
							    def __init__( | 
						
						
						
						
							 | 
							
								 | 
							
							            self, | 
						
						
						
						
							 | 
							
								 | 
							
							            model_path: str, | 
						
						
						
						
							 | 
							
								 | 
							
							            corpus_path: str = "rag_corpus.json", | 
						
						
						
						
							 | 
							
								 | 
							
							            db_path: str = "./chroma_db", | 
						
						
						
						
							 | 
							
								 | 
							
							            embedding_model_name: str = "cointegrated/rubert-tiny2", | 
						
						
						
						
							 | 
							
								 | 
							
							            top_k: int = 3, | 
						
						
						
						
							 | 
							
								 | 
							
							            n_ctx: int = 4096, | 
						
						
						
						
							 | 
							
								 | 
							
							            n_threads: int = 4, | 
						
						
						
						
							 | 
							
								 | 
							
							            token_multiplier: int = 5, | 
						
						
						
						
							 | 
							
								 | 
							
							            n_gpu_layers: int = -1,  # Автоматическое определение слоев для GPU | 
						
						
						
						
							 | 
							
								 | 
							
							            main_gpu: int = 0,  # Основной GPU | 
						
						
						
						
							 | 
							
								 | 
							
							            tensor_split: List[float] = None,  # Разделение тензоров между GPU | 
						
						
						
						
							 | 
							
								 | 
							
							            use_gpu_for_embeddings: bool = True  # Использовать GPU для эмбеддингов | 
						
						
						
						
							 | 
							
								 | 
							
							    ): | 
						
						
						
						
							 | 
							
								 | 
							
							        self.corpus_path = corpus_path | 
						
						
						
						
							 | 
							
								 | 
							
							        self.top_k = top_k | 
						
						
						
						
							 | 
							
								 | 
							
							        self.token_multiplier = token_multiplier | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        # === Проверка доступности GPU === | 
						
						
						
						
							 | 
							
								 | 
							
							        self.has_gpu = torch.cuda.is_available() | 
						
						
						
						
							 | 
							
								 | 
							
							        if self.has_gpu: | 
						
						
						
						
							 | 
							
								 | 
							
							            print(f"✅ GPU доступен: {torch.cuda.get_device_name()}") | 
						
						
						
						
							 | 
							
								 | 
							
							            print(f"✅ Количество GPU: {torch.cuda.device_count()}") | 
						
						
						
						
							 | 
							
								 | 
							
							            print(f"✅ Память GPU: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB") | 
						
						
						
						
							 | 
							
								 | 
							
							        else: | 
						
						
						
						
							 | 
							
								 | 
							
							            print("⚠️ GPU не доступен, используется CPU") | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        # === Инициализация токенизатора === | 
						
						
						
						
							 | 
							
								 | 
							
							        print("Инициализация токенизатора...") | 
						
						
						
						
							 | 
							
								 | 
							
							        try: | 
						
						
						
						
							 | 
							
								 | 
							
							            self.encoding = tiktoken.get_encoding("cl100k_base") | 
						
						
						
						
							 | 
							
								 | 
							
							        except: | 
						
						
						
						
							 | 
							
								 | 
							
							            print("⚠️ Не удалось загрузить токенизатор, используется базовый подсчет символов") | 
						
						
						
						
							 | 
							
								 | 
							
							            self.encoding = None | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        # === Эмбеддинги с GPU === | 
						
						
						
						
							 | 
							
								 | 
							
							        print("Загрузка эмбеддинг-модели...") | 
						
						
						
						
							 | 
							
								 | 
							
							        device = "cuda" if self.has_gpu and use_gpu_for_embeddings else "cpu" | 
						
						
						
						
							 | 
							
								 | 
							
							        print(f"Эмбеддинг-модель будет использовать: {device.upper()}") | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        self.embedding_function = SentenceTransformerEmbeddingFunction( | 
						
						
						
						
							 | 
							
								 | 
							
							            model_name=embedding_model_name, | 
						
						
						
						
							 | 
							
								 | 
							
							            device=device | 
						
						
						
						
							 | 
							
								 | 
							
							        ) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        # === ChromaDB === | 
						
						
						
						
							 | 
							
								 | 
							
							        print("Инициализация ChromaDB...") | 
						
						
						
						
							 | 
							
								 | 
							
							        self.client = chromadb.PersistentClient(path=db_path) | 
						
						
						
						
							 | 
							
								 | 
							
							        self.collection = self.client.get_or_create_collection( | 
						
						
						
						
							 | 
							
								 | 
							
							            name="medical_anamnesis", | 
						
						
						
						
							 | 
							
								 | 
							
							            embedding_function=self.embedding_function, | 
						
						
						
						
							 | 
							
								 | 
							
							            metadata={"hnsw:space": "cosine"} | 
						
						
						
						
							 | 
							
								 | 
							
							        ) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        if self.collection.count() == 0: | 
						
						
						
						
							 | 
							
								 | 
							
							            print("Коллекция пуста. Загрузка данных...") | 
						
						
						
						
							 | 
							
								 | 
							
							            self._load_corpus() | 
						
						
						
						
							 | 
							
								 | 
							
							        else: | 
						
						
						
						
							 | 
							
								 | 
							
							            print(f"Коллекция уже содержит {self.collection.count()} записей.") | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        # === LLM (YandexGPT) с GPU === | 
						
						
						
						
							 | 
							
								 | 
							
							        if not os.path.exists(model_path): | 
						
						
						
						
							 | 
							
								 | 
							
							            raise FileNotFoundError( | 
						
						
						
						
							 | 
							
								 | 
							
							                f"Модель не найдена: {model_path}\n" | 
						
						
						
						
							 | 
							
								 | 
							
							                "Скачайте YandexGPT-5-Lite-8B-instruct-Q4_K_M.gguf и поместите в ./models/" | 
						
						
						
						
							 | 
							
								 | 
							
							            ) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        print("Загрузка языковой модели (YandexGPT)...") | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        # Параметры для GPU | 
						
						
						
						
							 | 
							
								 | 
							
							        gpu_params = {} | 
						
						
						
						
							 | 
							
								 | 
							
							        if self.has_gpu: | 
						
						
						
						
							 | 
							
								 | 
							
							            gpu_params.update({ | 
						
						
						
						
							 | 
							
								 | 
							
							                "n_gpu_layers": n_gpu_layers,  # -1 = все слои на GPU | 
						
						
						
						
							 | 
							
								 | 
							
							                "main_gpu": main_gpu, | 
						
						
						
						
							 | 
							
								 | 
							
							                "tensor_split": tensor_split, | 
						
						
						
						
							 | 
							
								 | 
							
							                "low_vram": False,  # Отключаем для лучшей производительности, если достаточно памяти | 
						
						
						
						
							 | 
							
								 | 
							
							                "flash_attn": True  # Включаем flash attention для ускорения | 
						
						
						
						
							 | 
							
								 | 
							
							            }) | 
						
						
						
						
							 | 
							
								 | 
							
							            print(f"Используется GPU с {n_gpu_layers} слоями") | 
						
						
						
						
							 | 
							
								 | 
							
							        else: | 
						
						
						
						
							 | 
							
								 | 
							
							            print("Используется CPU") | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        self.llm = Llama( | 
						
						
						
						
							 | 
							
								 | 
							
							            model_path=model_path, | 
						
						
						
						
							 | 
							
								 | 
							
							            n_ctx=n_ctx, | 
						
						
						
						
							 | 
							
								 | 
							
							            n_threads=n_threads, | 
						
						
						
						
							 | 
							
								 | 
							
							            verbose=False, | 
						
						
						
						
							 | 
							
								 | 
							
							            **gpu_params | 
						
						
						
						
							 | 
							
								 | 
							
							        ) | 
						
						
						
						
							 | 
							
								 | 
							
							        print("✅ Система готова к работе!") | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							    def _load_corpus(self): | 
						
						
						
						
							 | 
							
								 | 
							
							        with open(self.corpus_path, "r", encoding="utf-8") as f: | 
						
						
						
						
							 | 
							
								 | 
							
							            data = json.load(f) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        self.collection.add( | 
						
						
						
						
							 | 
							
								 | 
							
							            documents=[item["full"] for item in data], | 
						
						
						
						
							 | 
							
								 | 
							
							            metadatas=[{"short": item["short"]} for item in data], | 
						
						
						
						
							 | 
							
								 | 
							
							            ids=[f"id_{i}" for i in range(len(data))] | 
						
						
						
						
							 | 
							
								 | 
							
							        ) | 
						
						
						
						
							 | 
							
								 | 
							
							        print(f"✅ Загружено {len(data)} записей.") | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							    def count_tokens(self, text: str) -> int: | 
						
						
						
						
							 | 
							
								 | 
							
							        """Подсчет токенов в тексте""" | 
						
						
						
						
							 | 
							
								 | 
							
							        if self.encoding: | 
						
						
						
						
							 | 
							
								 | 
							
							            return len(self.encoding.encode(text)) | 
						
						
						
						
							 | 
							
								 | 
							
							        else: | 
						
						
						
						
							 | 
							
								 | 
							
							            return len(text) // 4 | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							    def build_prompt_with_token_management(self, short_note: str, max_context_tokens: int = 3000) -> Tuple[str, int]: | 
						
						
						
						
							 | 
							
								 | 
							
							        """Строит промпт с управлением токенами""" | 
						
						
						
						
							 | 
							
								 | 
							
							        examples = self.retrieve(short_note) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        system_msg = ( | 
						
						
						
						
							 | 
							
								 | 
							
							            "На основе приведённых клинических примеров напиши развёрнуто жалобы пациента, грамотно с медицинской точки зрения. " | 
						
						
						
						
							 | 
							
								 | 
							
							            "Напиши жалобы в одно предложение, одной строкой. " | 
						
						
						
						
							 | 
							
								 | 
							
							            "Не пиши вводных слов и фраз. Только жалобы пациента. " | 
						
						
						
						
							 | 
							
								 | 
							
							            "Неуместно писать диагнозы и план лечения. " | 
						
						
						
						
							 | 
							
								 | 
							
							            "Расшифруй все сокращения. " | 
						
						
						
						
							 | 
							
								 | 
							
							            "Отвечай сразу без размышлений." | 
						
						
						
						
							 | 
							
								 | 
							
							        ) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        system_tokens = self.count_tokens(system_msg) | 
						
						
						
						
							 | 
							
								 | 
							
							        note_tokens = self.count_tokens(short_note) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        available_tokens = max_context_tokens - system_tokens - note_tokens - 100 | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        selected_examples = [] | 
						
						
						
						
							 | 
							
								 | 
							
							        current_tokens = 0 | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        for example in examples: | 
						
						
						
						
							 | 
							
								 | 
							
							            example_tokens = self.count_tokens(example) | 
						
						
						
						
							 | 
							
								 | 
							
							            if current_tokens + example_tokens <= available_tokens: | 
						
						
						
						
							 | 
							
								 | 
							
							                selected_examples.append(example) | 
						
						
						
						
							 | 
							
								 | 
							
							                current_tokens += example_tokens | 
						
						
						
						
							 | 
							
								 | 
							
							            else: | 
						
						
						
						
							 | 
							
								 | 
							
							                print("⛔️ RAG: Token limit exceeded.") | 
						
						
						
						
							 | 
							
								 | 
							
							                break | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        context = "\n\n".join([f"Пример: {ex}" for ex in selected_examples]) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        user_msg = f"""Примеры развёрнутых описаний: | 
						
						
						
						
							 | 
							
								 | 
							
							{context} | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							Жалобы пациента: "{short_note}" | 
						
						
						
						
							 | 
							
								 | 
							
							""" | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        prompt = ( | 
						
						
						
						
							 | 
							
								 | 
							
							            f"<|im_start|>system\n{system_msg}<|im_end|>\n" | 
						
						
						
						
							 | 
							
								 | 
							
							            f"<|im_start|>user\n{user_msg}<|im_end|>\n" | 
						
						
						
						
							 | 
							
								 | 
							
							            "<|im_start|>assistant\n" | 
						
						
						
						
							 | 
							
								 | 
							
							        ) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        prompt_tokens = self.count_tokens(prompt) | 
						
						
						
						
							 | 
							
								 | 
							
							        return prompt, prompt_tokens | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							    def retrieve(self, query: str, n: int = None) -> List[str]: | 
						
						
						
						
							 | 
							
								 | 
							
							        n = n or self.top_k | 
						
						
						
						
							 | 
							
								 | 
							
							        results = self.collection.query(query_texts=[query], n_results=n) | 
						
						
						
						
							 | 
							
								 | 
							
							        return results["documents"][0] | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							    def generate(self, short_note: str) -> str: | 
						
						
						
						
							 | 
							
								 | 
							
							        prompt, prompt_tokens = self.build_prompt_with_token_management(short_note) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        available_tokens = 4096 - prompt_tokens - 50 | 
						
						
						
						
							 | 
							
								 | 
							
							        max_tokens = min(prompt_tokens * self.token_multiplier, available_tokens) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        print(f"📊 Токены: промпт={prompt_tokens}, макс.ответ={max_tokens}") | 
						
						
						
						
							 | 
							
								 | 
							
							        print(f"⚡️ Устройство: {'GPU' if self.has_gpu else 'CPU'}") | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        output = self.llm( | 
						
						
						
						
							 | 
							
								 | 
							
							            prompt, | 
						
						
						
						
							 | 
							
								 | 
							
							            max_tokens=max_tokens, | 
						
						
						
						
							 | 
							
								 | 
							
							            temperature=0.1, | 
						
						
						
						
							 | 
							
								 | 
							
							            stop=["<|im_end|>"], | 
						
						
						
						
							 | 
							
								 | 
							
							            echo=False | 
						
						
						
						
							 | 
							
								 | 
							
							        ) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							        result = output["choices"][0]["text"].strip() | 
						
						
						
						
							 | 
							
								 | 
							
							        return result | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							    def __call__(self, short_note: str) -> str: | 
						
						
						
						
							 | 
							
								 | 
							
							        return self.generate(short_note) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							# === Отключаем телеметрию Chroma === | 
						
						
						
						
							 | 
							
								 | 
							
							os.environ["CHROMA_TELEMETRY"] = "false" | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							import time | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							# === Запуск === | 
						
						
						
						
							 | 
							
								 | 
							
							if __name__ == "__main__": | 
						
						
						
						
							 | 
							
								 | 
							
							    rag = MedicalRAG( | 
						
						
						
						
							 | 
							
								 | 
							
							        model_path="./models/YandexGPT-5-Lite-8B-instruct-Q4_K_M.gguf", | 
						
						
						
						
							 | 
							
								 | 
							
							        n_ctx=8192, | 
						
						
						
						
							 | 
							
								 | 
							
							        n_gpu_layers=35,  # Количество слоев для GPU (можно настроить) | 
						
						
						
						
							 | 
							
								 | 
							
							        use_gpu_for_embeddings=True | 
						
						
						
						
							 | 
							
								 | 
							
							    ) | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							    # Промты для тестирования | 
						
						
						
						
							 | 
							
								 | 
							
							    test_notes = [ | 
						
						
						
						
							 | 
							
								 | 
							
							        "Кашель сухой, температура 38", | 
						
						
						
						
							 | 
							
								 | 
							
							        "А.д. 140, заложенность ушей, частичная потеря слуха", | 
						
						
						
						
							 | 
							
								 | 
							
							        "а.д. 140/80, т.36.6, ОСГО л. уха", | 
						
						
						
						
							 | 
							
								 | 
							
							        "ушные палочки 5 лет, снижение слуха 2 года" | 
						
						
						
						
							 | 
							
								 | 
							
							    ] | 
						
						
						
						
							 | 
							
								 | 
							
							 | 
						
						
						
						
							 | 
							
								 | 
							
							    for note in test_notes: | 
						
						
						
						
							 | 
							
								 | 
							
							        print(f"\n📥 Кратко: {note}") | 
						
						
						
						
							 | 
							
								 | 
							
							        t1 = time.time() | 
						
						
						
						
							 | 
							
								 | 
							
							        result = rag(note) | 
						
						
						
						
							 | 
							
								 | 
							
							        elapsed_time = time.time() - t1 | 
						
						
						
						
							 | 
							
								 | 
							
							        print(f"⏱ Время выполнения: {elapsed_time:.2f} сек") | 
						
						
						
						
							 | 
							
								 | 
							
							        if result: | 
						
						
						
						
							 | 
							
								 | 
							
							            print(f"📤 Развёрнуто:\n{result}") | 
						
						
						
						
							 | 
							
								 | 
							
							        else: | 
						
						
						
						
							 | 
							
								 | 
							
							            print("❌ Пустой ответ от модели.") | 
						
						
						
						
							 | 
							
								 | 
							
							        print("─" * 60) |