|
|
# medical_rag.py |
|
|
import os |
|
|
import json |
|
|
import tiktoken |
|
|
from typing import List, Tuple |
|
|
import chromadb |
|
|
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction |
|
|
from llama_cpp import Llama |
|
|
import torch # Добавляем импорт torch |
|
|
|
|
|
|
|
|
class MedicalRAG: |
|
|
def __init__( |
|
|
self, |
|
|
model_path: str, |
|
|
corpus_path: str = "rag_corpus.json", |
|
|
db_path: str = "./chroma_db", |
|
|
embedding_model_name: str = "cointegrated/rubert-tiny2", |
|
|
top_k: int = 3, |
|
|
n_ctx: int = 4096, |
|
|
n_threads: int = 4, |
|
|
token_multiplier: int = 5, |
|
|
n_gpu_layers: int = -1, # Автоматическое определение слоев для GPU |
|
|
main_gpu: int = 0, # Основной GPU |
|
|
tensor_split: List[float] = None, # Разделение тензоров между GPU |
|
|
use_gpu_for_embeddings: bool = True # Использовать GPU для эмбеддингов |
|
|
): |
|
|
self.corpus_path = corpus_path |
|
|
self.top_k = top_k |
|
|
self.token_multiplier = token_multiplier |
|
|
|
|
|
# === Проверка доступности GPU === |
|
|
self.has_gpu = torch.cuda.is_available() |
|
|
if self.has_gpu: |
|
|
print(f"✅ GPU доступен: {torch.cuda.get_device_name()}") |
|
|
print(f"✅ Количество GPU: {torch.cuda.device_count()}") |
|
|
print(f"✅ Память GPU: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB") |
|
|
else: |
|
|
print("⚠️ GPU не доступен, используется CPU") |
|
|
|
|
|
# === Инициализация токенизатора === |
|
|
print("Инициализация токенизатора...") |
|
|
try: |
|
|
self.encoding = tiktoken.get_encoding("cl100k_base") |
|
|
except: |
|
|
print("⚠️ Не удалось загрузить токенизатор, используется базовый подсчет символов") |
|
|
self.encoding = None |
|
|
|
|
|
# === Эмбеддинги с GPU === |
|
|
print("Загрузка эмбеддинг-модели...") |
|
|
device = "cuda" if self.has_gpu and use_gpu_for_embeddings else "cpu" |
|
|
print(f"Эмбеддинг-модель будет использовать: {device.upper()}") |
|
|
|
|
|
self.embedding_function = SentenceTransformerEmbeddingFunction( |
|
|
model_name=embedding_model_name, |
|
|
device=device |
|
|
) |
|
|
|
|
|
# === ChromaDB === |
|
|
print("Инициализация ChromaDB...") |
|
|
self.client = chromadb.PersistentClient(path=db_path) |
|
|
self.collection = self.client.get_or_create_collection( |
|
|
name="medical_anamnesis", |
|
|
embedding_function=self.embedding_function, |
|
|
metadata={"hnsw:space": "cosine"} |
|
|
) |
|
|
|
|
|
if self.collection.count() == 0: |
|
|
print("Коллекция пуста. Загрузка данных...") |
|
|
self._load_corpus() |
|
|
else: |
|
|
print(f"Коллекция уже содержит {self.collection.count()} записей.") |
|
|
|
|
|
# === LLM (YandexGPT) с GPU === |
|
|
if not os.path.exists(model_path): |
|
|
raise FileNotFoundError( |
|
|
f"Модель не найдена: {model_path}\n" |
|
|
"Скачайте YandexGPT-5-Lite-8B-instruct-Q4_K_M.gguf и поместите в ./models/" |
|
|
) |
|
|
|
|
|
print("Загрузка языковой модели (YandexGPT)...") |
|
|
|
|
|
# Параметры для GPU |
|
|
gpu_params = {} |
|
|
if self.has_gpu: |
|
|
gpu_params.update({ |
|
|
"n_gpu_layers": n_gpu_layers, # -1 = все слои на GPU |
|
|
"main_gpu": main_gpu, |
|
|
"tensor_split": tensor_split, |
|
|
"low_vram": False, # Отключаем для лучшей производительности, если достаточно памяти |
|
|
"flash_attn": True # Включаем flash attention для ускорения |
|
|
}) |
|
|
print(f"Используется GPU с {n_gpu_layers} слоями") |
|
|
else: |
|
|
print("Используется CPU") |
|
|
|
|
|
self.llm = Llama( |
|
|
model_path=model_path, |
|
|
n_ctx=n_ctx, |
|
|
n_threads=n_threads, |
|
|
verbose=False, |
|
|
**gpu_params |
|
|
) |
|
|
print("✅ Система готова к работе!") |
|
|
|
|
|
def _load_corpus(self): |
|
|
with open(self.corpus_path, "r", encoding="utf-8") as f: |
|
|
data = json.load(f) |
|
|
|
|
|
self.collection.add( |
|
|
documents=[item["full"] for item in data], |
|
|
metadatas=[{"short": item["short"]} for item in data], |
|
|
ids=[f"id_{i}" for i in range(len(data))] |
|
|
) |
|
|
print(f"✅ Загружено {len(data)} записей.") |
|
|
|
|
|
def count_tokens(self, text: str) -> int: |
|
|
"""Подсчет токенов в тексте""" |
|
|
if self.encoding: |
|
|
return len(self.encoding.encode(text)) |
|
|
else: |
|
|
return len(text) // 4 |
|
|
|
|
|
def build_prompt_with_token_management(self, short_note: str, max_context_tokens: int = 3000) -> Tuple[str, int]: |
|
|
"""Строит промпт с управлением токенами""" |
|
|
examples = self.retrieve(short_note) |
|
|
|
|
|
system_msg = ( |
|
|
"На основе приведённых клинических примеров напиши развёрнуто жалобы пациента, грамотно с медицинской точки зрения. " |
|
|
"Напиши жалобы в одно предложение, одной строкой. " |
|
|
"Не пиши вводных слов и фраз. Только жалобы пациента. " |
|
|
"Неуместно писать диагнозы и план лечения. " |
|
|
"Расшифруй все сокращения. " |
|
|
"Отвечай сразу без размышлений." |
|
|
) |
|
|
|
|
|
system_tokens = self.count_tokens(system_msg) |
|
|
note_tokens = self.count_tokens(short_note) |
|
|
|
|
|
available_tokens = max_context_tokens - system_tokens - note_tokens - 100 |
|
|
|
|
|
selected_examples = [] |
|
|
current_tokens = 0 |
|
|
|
|
|
for example in examples: |
|
|
example_tokens = self.count_tokens(example) |
|
|
if current_tokens + example_tokens <= available_tokens: |
|
|
selected_examples.append(example) |
|
|
current_tokens += example_tokens |
|
|
else: |
|
|
print("⛔️ RAG: Token limit exceeded.") |
|
|
break |
|
|
|
|
|
context = "\n\n".join([f"Пример: {ex}" for ex in selected_examples]) |
|
|
|
|
|
user_msg = f"""Примеры развёрнутых описаний: |
|
|
{context} |
|
|
|
|
|
Жалобы пациента: "{short_note}" |
|
|
""" |
|
|
|
|
|
prompt = ( |
|
|
f"<|im_start|>system\n{system_msg}<|im_end|>\n" |
|
|
f"<|im_start|>user\n{user_msg}<|im_end|>\n" |
|
|
"<|im_start|>assistant\n" |
|
|
) |
|
|
|
|
|
prompt_tokens = self.count_tokens(prompt) |
|
|
return prompt, prompt_tokens |
|
|
|
|
|
def retrieve(self, query: str, n: int = None) -> List[str]: |
|
|
n = n or self.top_k |
|
|
results = self.collection.query(query_texts=[query], n_results=n) |
|
|
return results["documents"][0] |
|
|
|
|
|
def generate(self, short_note: str) -> str: |
|
|
prompt, prompt_tokens = self.build_prompt_with_token_management(short_note) |
|
|
|
|
|
available_tokens = 4096 - prompt_tokens - 50 |
|
|
max_tokens = min(prompt_tokens * self.token_multiplier, available_tokens) |
|
|
|
|
|
print(f"📊 Токены: промпт={prompt_tokens}, макс.ответ={max_tokens}") |
|
|
print(f"⚡️ Устройство: {'GPU' if self.has_gpu else 'CPU'}") |
|
|
|
|
|
output = self.llm( |
|
|
prompt, |
|
|
max_tokens=max_tokens, |
|
|
temperature=0.1, |
|
|
stop=["<|im_end|>"], |
|
|
echo=False |
|
|
) |
|
|
|
|
|
result = output["choices"][0]["text"].strip() |
|
|
return result |
|
|
|
|
|
def __call__(self, short_note: str) -> str: |
|
|
return self.generate(short_note) |
|
|
|
|
|
|
|
|
# === Отключаем телеметрию Chroma === |
|
|
os.environ["CHROMA_TELEMETRY"] = "false" |
|
|
|
|
|
import time |
|
|
|
|
|
# === Запуск === |
|
|
if __name__ == "__main__": |
|
|
rag = MedicalRAG( |
|
|
model_path="./models/YandexGPT-5-Lite-8B-instruct-Q4_K_M.gguf", |
|
|
n_ctx=8192, |
|
|
n_gpu_layers=35, # Количество слоев для GPU (можно настроить) |
|
|
use_gpu_for_embeddings=True |
|
|
) |
|
|
|
|
|
# Промты для тестирования |
|
|
test_notes = [ |
|
|
"Кашель сухой, температура 38", |
|
|
"А.д. 140, заложенность ушей, частичная потеря слуха", |
|
|
"а.д. 140/80, т.36.6, ОСГО л. уха", |
|
|
"ушные палочки 5 лет, снижение слуха 2 года" |
|
|
] |
|
|
|
|
|
for note in test_notes: |
|
|
print(f"\n📥 Кратко: {note}") |
|
|
t1 = time.time() |
|
|
result = rag(note) |
|
|
elapsed_time = time.time() - t1 |
|
|
print(f"⏱ Время выполнения: {elapsed_time:.2f} сек") |
|
|
if result: |
|
|
print(f"📤 Развёрнуто:\n{result}") |
|
|
else: |
|
|
print("❌ Пустой ответ от модели.") |
|
|
print("─" * 60) |