# medical_rag.py import os import json import tiktoken from typing import List, Tuple import chromadb from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction from llama_cpp import Llama import torch # Добавляем импорт torch class MedicalRAG: def __init__( self, model_path: str, corpus_path: str = "rag_corpus.json", db_path: str = "./chroma_db", embedding_model_name: str = "cointegrated/rubert-tiny2", top_k: int = 3, n_ctx: int = 4096, n_threads: int = 4, token_multiplier: int = 5, n_gpu_layers: int = -1, # Автоматическое определение слоев для GPU main_gpu: int = 0, # Основной GPU tensor_split: List[float] = None, # Разделение тензоров между GPU use_gpu_for_embeddings: bool = True # Использовать GPU для эмбеддингов ): self.corpus_path = corpus_path self.top_k = top_k self.token_multiplier = token_multiplier # === Проверка доступности GPU === self.has_gpu = torch.cuda.is_available() if self.has_gpu: print(f"✅ GPU доступен: {torch.cuda.get_device_name()}") print(f"✅ Количество GPU: {torch.cuda.device_count()}") print(f"✅ Память GPU: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB") else: print("⚠️ GPU не доступен, используется CPU") # === Инициализация токенизатора === print("Инициализация токенизатора...") try: self.encoding = tiktoken.get_encoding("cl100k_base") except: print("⚠️ Не удалось загрузить токенизатор, используется базовый подсчет символов") self.encoding = None # === Эмбеддинги с GPU === print("Загрузка эмбеддинг-модели...") device = "cuda" if self.has_gpu and use_gpu_for_embeddings else "cpu" print(f"Эмбеддинг-модель будет использовать: {device.upper()}") self.embedding_function = SentenceTransformerEmbeddingFunction( model_name=embedding_model_name, device=device ) # === ChromaDB === print("Инициализация ChromaDB...") self.client = chromadb.PersistentClient(path=db_path) self.collection = self.client.get_or_create_collection( name="medical_anamnesis", embedding_function=self.embedding_function, metadata={"hnsw:space": "cosine"} ) if self.collection.count() == 0: print("Коллекция пуста. Загрузка данных...") self._load_corpus() else: print(f"Коллекция уже содержит {self.collection.count()} записей.") # === LLM (YandexGPT) с GPU === if not os.path.exists(model_path): raise FileNotFoundError( f"Модель не найдена: {model_path}\n" "Скачайте YandexGPT-5-Lite-8B-instruct-Q4_K_M.gguf и поместите в ./models/" ) print("Загрузка языковой модели (YandexGPT)...") # Параметры для GPU gpu_params = {} if self.has_gpu: gpu_params.update({ "n_gpu_layers": n_gpu_layers, # -1 = все слои на GPU "main_gpu": main_gpu, "tensor_split": tensor_split, "low_vram": False, # Отключаем для лучшей производительности, если достаточно памяти "flash_attn": True # Включаем flash attention для ускорения }) print(f"Используется GPU с {n_gpu_layers} слоями") else: print("Используется CPU") self.llm = Llama( model_path=model_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False, **gpu_params ) print("✅ Система готова к работе!") def _load_corpus(self): with open(self.corpus_path, "r", encoding="utf-8") as f: data = json.load(f) self.collection.add( documents=[item["full"] for item in data], metadatas=[{"short": item["short"]} for item in data], ids=[f"id_{i}" for i in range(len(data))] ) print(f"✅ Загружено {len(data)} записей.") def count_tokens(self, text: str) -> int: """Подсчет токенов в тексте""" if self.encoding: return len(self.encoding.encode(text)) else: return len(text) // 4 def build_prompt_with_token_management(self, short_note: str, max_context_tokens: int = 3000) -> Tuple[str, int]: """Строит промпт с управлением токенами""" examples = self.retrieve(short_note) system_msg = ( "На основе приведённых клинических примеров напиши развёрнуто жалобы пациента, грамотно с медицинской точки зрения. " "Напиши жалобы в одно предложение, одной строкой. " "Не пиши вводных слов и фраз. Только жалобы пациента. " "Неуместно писать диагнозы и план лечения. " "Расшифруй все сокращения. " "Отвечай сразу без размышлений." ) system_tokens = self.count_tokens(system_msg) note_tokens = self.count_tokens(short_note) available_tokens = max_context_tokens - system_tokens - note_tokens - 100 selected_examples = [] current_tokens = 0 for example in examples: example_tokens = self.count_tokens(example) if current_tokens + example_tokens <= available_tokens: selected_examples.append(example) current_tokens += example_tokens else: print("⛔️ RAG: Token limit exceeded.") break context = "\n\n".join([f"Пример: {ex}" for ex in selected_examples]) user_msg = f"""Примеры развёрнутых описаний: {context} Жалобы пациента: "{short_note}" """ prompt = ( f"<|im_start|>system\n{system_msg}<|im_end|>\n" f"<|im_start|>user\n{user_msg}<|im_end|>\n" "<|im_start|>assistant\n" ) prompt_tokens = self.count_tokens(prompt) return prompt, prompt_tokens def retrieve(self, query: str, n: int = None) -> List[str]: n = n or self.top_k results = self.collection.query(query_texts=[query], n_results=n) return results["documents"][0] def generate(self, short_note: str) -> str: prompt, prompt_tokens = self.build_prompt_with_token_management(short_note) available_tokens = 4096 - prompt_tokens - 50 max_tokens = min(prompt_tokens * self.token_multiplier, available_tokens) print(f"📊 Токены: промпт={prompt_tokens}, макс.ответ={max_tokens}") print(f"⚡️ Устройство: {'GPU' if self.has_gpu else 'CPU'}") output = self.llm( prompt, max_tokens=max_tokens, temperature=0.1, stop=["<|im_end|>"], echo=False ) result = output["choices"][0]["text"].strip() return result def __call__(self, short_note: str) -> str: return self.generate(short_note) # === Отключаем телеметрию Chroma === os.environ["CHROMA_TELEMETRY"] = "false" import time # === Запуск === if __name__ == "__main__": rag = MedicalRAG( model_path="./models/YandexGPT-5-Lite-8B-instruct-Q4_K_M.gguf", n_ctx=8192, n_gpu_layers=35, # Количество слоев для GPU (можно настроить) use_gpu_for_embeddings=True ) # Промты для тестирования test_notes = [ "Кашель сухой, температура 38", "А.д. 140, заложенность ушей, частичная потеря слуха", "а.д. 140/80, т.36.6, ОСГО л. уха", "ушные палочки 5 лет, снижение слуха 2 года" ] for note in test_notes: print(f"\n📥 Кратко: {note}") t1 = time.time() result = rag(note) elapsed_time = time.time() - t1 print(f"⏱ Время выполнения: {elapsed_time:.2f} сек") if result: print(f"📤 Развёрнуто:\n{result}") else: print("❌ Пустой ответ от модели.") print("─" * 60)