You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

232 lines
9.2 KiB

# medical_rag.py
import os
import json
import tiktoken
from typing import List, Tuple
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from llama_cpp import Llama
import torch # Добавляем импорт torch
class MedicalRAG:
def __init__(
self,
model_path: str,
corpus_path: str = "rag_corpus.json",
db_path: str = "./chroma_db",
embedding_model_name: str = "cointegrated/rubert-tiny2",
top_k: int = 3,
n_ctx: int = 4096,
n_threads: int = 4,
token_multiplier: int = 5,
n_gpu_layers: int = -1, # Автоматическое определение слоев для GPU
main_gpu: int = 0, # Основной GPU
tensor_split: List[float] = None, # Разделение тензоров между GPU
use_gpu_for_embeddings: bool = True # Использовать GPU для эмбеддингов
):
self.corpus_path = corpus_path
self.top_k = top_k
self.token_multiplier = token_multiplier
# === Проверка доступности GPU ===
self.has_gpu = torch.cuda.is_available()
if self.has_gpu:
print(f"✅ GPU доступен: {torch.cuda.get_device_name()}")
print(f"✅ Количество GPU: {torch.cuda.device_count()}")
print(f"✅ Память GPU: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB")
else:
print(" GPU не доступен, используется CPU")
# === Инициализация токенизатора ===
print("Инициализация токенизатора...")
try:
self.encoding = tiktoken.get_encoding("cl100k_base")
except:
print(" Не удалось загрузить токенизатор, используется базовый подсчет символов")
self.encoding = None
# === Эмбеддинги с GPU ===
print("Загрузка эмбеддинг-модели...")
device = "cuda" if self.has_gpu and use_gpu_for_embeddings else "cpu"
print(f"Эмбеддинг-модель будет использовать: {device.upper()}")
self.embedding_function = SentenceTransformerEmbeddingFunction(
model_name=embedding_model_name,
device=device
)
# === ChromaDB ===
print("Инициализация ChromaDB...")
self.client = chromadb.PersistentClient(path=db_path)
self.collection = self.client.get_or_create_collection(
name="medical_anamnesis",
embedding_function=self.embedding_function,
metadata={"hnsw:space": "cosine"}
)
if self.collection.count() == 0:
print("Коллекция пуста. Загрузка данных...")
self._load_corpus()
else:
print(f"Коллекция уже содержит {self.collection.count()} записей.")
# === LLM (YandexGPT) с GPU ===
if not os.path.exists(model_path):
raise FileNotFoundError(
f"Модель не найдена: {model_path}\n"
"Скачайте YandexGPT-5-Lite-8B-instruct-Q4_K_M.gguf и поместите в ./models/"
)
print("Загрузка языковой модели (YandexGPT)...")
# Параметры для GPU
gpu_params = {}
if self.has_gpu:
gpu_params.update({
"n_gpu_layers": n_gpu_layers, # -1 = все слои на GPU
"main_gpu": main_gpu,
"tensor_split": tensor_split,
"low_vram": False, # Отключаем для лучшей производительности, если достаточно памяти
"flash_attn": True # Включаем flash attention для ускорения
})
print(f"Используется GPU с {n_gpu_layers} слоями")
else:
print("Используется CPU")
self.llm = Llama(
model_path=model_path,
n_ctx=n_ctx,
n_threads=n_threads,
verbose=False,
**gpu_params
)
print("✅ Система готова к работе!")
def _load_corpus(self):
with open(self.corpus_path, "r", encoding="utf-8") as f:
data = json.load(f)
self.collection.add(
documents=[item["full"] for item in data],
metadatas=[{"short": item["short"]} for item in data],
ids=[f"id_{i}" for i in range(len(data))]
)
print(f"✅ Загружено {len(data)} записей.")
def count_tokens(self, text: str) -> int:
"""Подсчет токенов в тексте"""
if self.encoding:
return len(self.encoding.encode(text))
else:
return len(text) // 4
def build_prompt_with_token_management(self, short_note: str, max_context_tokens: int = 3000) -> Tuple[str, int]:
"""Строит промпт с управлением токенами"""
examples = self.retrieve(short_note)
system_msg = (
"На основе приведённых клинических примеров напиши развёрнуто жалобы пациента, грамотно с медицинской точки зрения. "
"Напиши жалобы в одно предложение, одной строкой. "
"Не пиши вводных слов и фраз. Только жалобы пациента. "
"Неуместно писать диагнозы и план лечения. "
"Расшифруй все сокращения. "
"Отвечай сразу без размышлений."
)
system_tokens = self.count_tokens(system_msg)
note_tokens = self.count_tokens(short_note)
available_tokens = max_context_tokens - system_tokens - note_tokens - 100
selected_examples = []
current_tokens = 0
for example in examples:
example_tokens = self.count_tokens(example)
if current_tokens + example_tokens <= available_tokens:
selected_examples.append(example)
current_tokens += example_tokens
else:
print(" RAG: Token limit exceeded.")
break
context = "\n\n".join([f"Пример: {ex}" for ex in selected_examples])
user_msg = f"""Примеры развёрнутых описаний:
{context}
Жалобы пациента: "{short_note}"
"""
prompt = (
f"<|im_start|>system\n{system_msg}<|im_end|>\n"
f"<|im_start|>user\n{user_msg}<|im_end|>\n"
"<|im_start|>assistant\n"
)
prompt_tokens = self.count_tokens(prompt)
return prompt, prompt_tokens
def retrieve(self, query: str, n: int = None) -> List[str]:
n = n or self.top_k
results = self.collection.query(query_texts=[query], n_results=n)
return results["documents"][0]
def generate(self, short_note: str) -> str:
prompt, prompt_tokens = self.build_prompt_with_token_management(short_note)
available_tokens = 4096 - prompt_tokens - 50
max_tokens = min(prompt_tokens * self.token_multiplier, available_tokens)
print(f"📊 Токены: промпт={prompt_tokens}, макс.ответ={max_tokens}")
print(f" Устройство: {'GPU' if self.has_gpu else 'CPU'}")
output = self.llm(
prompt,
max_tokens=max_tokens,
temperature=0.1,
stop=["<|im_end|>"],
echo=False
)
result = output["choices"][0]["text"].strip()
return result
def __call__(self, short_note: str) -> str:
return self.generate(short_note)
# === Отключаем телеметрию Chroma ===
os.environ["CHROMA_TELEMETRY"] = "false"
import time
# === Запуск ===
if __name__ == "__main__":
rag = MedicalRAG(
model_path="./models/YandexGPT-5-Lite-8B-instruct-Q4_K_M.gguf",
n_ctx=8192,
n_gpu_layers=35, # Количество слоев для GPU (можно настроить)
use_gpu_for_embeddings=True
)
# Промты для тестирования
test_notes = [
"Кашель сухой, температура 38",
"А.д. 140, заложенность ушей, частичная потеря слуха",
"а.д. 140/80, т.36.6, ОСГО л. уха",
"ушные палочки 5 лет, снижение слуха 2 года"
]
for note in test_notes:
print(f"\n📥 Кратко: {note}")
t1 = time.time()
result = rag(note)
elapsed_time = time.time() - t1
print(f"⏱ Время выполнения: {elapsed_time:.2f} сек")
if result:
print(f"📤 Развёрнуто:\n{result}")
else:
print("❌ Пустой ответ от модели.")
print("" * 60)