import logging
from datetime import datetime, timezone

import chromadb

from services.embeddings import EmbeddingService

logger = logging.getLogger(__name__)

COLLECTION_NAME = "operators_wiki"


class VectorStoreService:
    """Store and retrieve document chunks in a persistent ChromaDB collection.

    Every chunk is stored under the id ``"{document_id}_chunk_{index}"`` and
    carries metadata (document id/name, file type, section, page number,
    chunk index, creation timestamp) that the other methods filter and
    group on.
    """

    def __init__(self, persist_dir: str, embedding_service: EmbeddingService):
        """Open (or create) the collection stored under *persist_dir*.

        Args:
            persist_dir: Path handed to ``chromadb.PersistentClient``.
            embedding_service: Used to embed chunk texts and query strings.
        """
        self.client = chromadb.PersistentClient(path=persist_dir)
        self.embedding_service = embedding_service
        self.collection = self.client.get_or_create_collection(
            name=COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )
        logger.info(
            "ChromaDB collection '%s': %d items",
            COLLECTION_NAME,
            self.collection.count(),
        )

    def add_document(
        self,
        document_id: str,
        document_name: str,
        file_type: str,
        chunks: list[dict],
    ) -> int:
        """Embed and store the chunks of one document.

        Re-ingesting an existing ``document_id`` replaces its chunks: chunks
        are upserted by id, and ids left over from a previous, longer version
        of the document are removed afterwards.

        Args:
            document_id: Identifier used to build chunk ids and to filter.
            document_name: Human-readable name stored in chunk metadata.
            file_type: File type label stored in chunk metadata.
            chunks: Dicts with a required ``"text"`` key and optional
                ``"section"`` and ``"page_number"`` keys.

        Returns:
            The number of chunks written (0 when *chunks* is empty).
        """
        if not chunks:
            return 0

        texts = [c["text"] for c in chunks]
        embeddings = self.embedding_service.embed_documents(texts)

        now = datetime.now(timezone.utc).isoformat()
        ids = [f"{document_id}_chunk_{i}" for i in range(len(chunks))]
        metadatas = [
            {
                "document_id": document_id,
                "document_name": document_name,
                "file_type": file_type,
                # `or` also maps an explicit None to the default.
                # NOTE(review): some Chroma versions reject None metadata
                # values - confirm against the pinned version.
                "section": chunk.get("section") or "",
                "page_number": chunk.get("page_number") or 0,
                "chunk_index": i,
                "created_at": now,
            }
            for i, chunk in enumerate(chunks)
        ]

        # Chunk ids are deterministic, and per the Chroma docs add() ignores
        # records whose id already exists, so a re-ingest would be dropped
        # while this method still reported success. upsert() overwrites.
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
        )

        # Prune chunks of an earlier, longer version only after the new data
        # is in place, so a failure above never loses the old content.
        existing = self.collection.get(
            where={"document_id": document_id}, include=[]
        )
        stale_ids = sorted(set(existing["ids"]) - set(ids))
        if stale_ids:
            self.collection.delete(ids=stale_ids)
            logger.info(
                "Removed %d stale chunks for document '%s'",
                len(stale_ids),
                document_name,
            )

        logger.info(
            "Added %d chunks for document '%s'", len(chunks), document_name
        )
        return len(chunks)

    def query(
        self,
        query_text: str,
        top_k: int = 5,
        document_ids: list[str] | None = None,
    ) -> list[dict]:
        """Return the chunks closest to *query_text*.

        Args:
            query_text: Text embedded with the embedding service's
                ``embed_query``.
            top_k: Number of results requested from the collection.
            document_ids: Optional restriction to these documents. ``None``
                and an empty list both mean "no restriction".

        Returns:
            One dict per hit, in the order the collection returned them, with
            keys ``chunk_id``, ``text``, ``metadata``, ``distance`` and
            ``relevance_score`` (computed as ``1 - distance``).
        """
        query_embedding = self.embedding_service.embed_query(query_text)

        where_filter = None
        if document_ids:
            if len(document_ids) == 1:
                where_filter = {"document_id": document_ids[0]}
            else:
                where_filter = {"document_id": {"$in": document_ids}}

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

        # Results are nested one list per query embedding; only one was sent.
        if not results["ids"] or not results["ids"][0]:
            return []

        hits = zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
        return [
            {
                "chunk_id": chunk_id,
                "text": text,
                "metadata": metadata,
                "distance": distance,
                "relevance_score": 1 - distance,
            }
            for chunk_id, text, metadata, distance in hits
        ]

    def delete_document(self, document_id: str) -> int:
        """Delete every chunk of *document_id* and return how many were removed."""
        # include=[] fetches ids only; documents and metadata are not needed.
        existing = self.collection.get(
            where={"document_id": document_id}, include=[]
        )
        count = len(existing["ids"])
        if count > 0:
            self.collection.delete(ids=existing["ids"])
        logger.info("Deleted %d chunks for document_id=%s", count, document_id)
        return count

    def list_documents(self) -> list[dict]:
        """Return one summary dict per stored document.

        Each summary has ``document_id``, ``name``, ``file_type``,
        ``created_at`` (all taken from the first chunk seen for that
        document), ``chunks_count`` and an empty ``metadata`` dict. Chunks
        whose metadata has no ``document_id`` are skipped.
        """
        all_items = self.collection.get(include=["metadatas"])
        docs: dict[str, dict] = {}
        for meta in all_items["metadatas"] or []:
            doc_id = (meta or {}).get("document_id")
            if doc_id is None:
                # Not attributable to any document; do not raise.
                continue
            if doc_id not in docs:
                docs[doc_id] = {
                    "document_id": doc_id,
                    "name": meta.get("document_name", ""),
                    "file_type": meta.get("file_type", ""),
                    "created_at": meta.get("created_at", ""),
                    "chunks_count": 0,
                    "metadata": {},
                }
            docs[doc_id]["chunks_count"] += 1
        return list(docs.values())

    def get_document_chunks(self, document_id: str) -> list[dict]:
        """Return all chunks for a document, sorted by chunk_index."""
        results = self.collection.get(
            where={"document_id": document_id},
            include=["documents", "metadatas"],
        )
        if not results["ids"]:
            return []

        items = [
            {"chunk_id": chunk_id, "text": text, "metadata": metadata}
            for chunk_id, text, metadata in zip(
                results["ids"], results["documents"], results["metadatas"]
            )
        ]
        items.sort(key=lambda item: item["metadata"].get("chunk_index", 0))
        return items

    def get_stats(self) -> dict:
        """Return ``documents_count`` and ``chunks_count`` for the collection.

        Both numbers come from a single fetch. ``documents_count`` counts
        distinct ``document_id`` values and ignores chunks without one, so it
        matches ``len(list_documents())``; ``chunks_count`` counts every
        stored chunk.
        """
        all_items = self.collection.get(include=["metadatas"])
        doc_ids = {
            meta["document_id"]
            for meta in all_items["metadatas"] or []
            if meta and meta.get("document_id") is not None
        }
        return {
            "documents_count": len(doc_ids),
            "chunks_count": len(all_items["ids"]),
        }