From 772186b6e5461e73045df87ab4cc7287b4ef35e6 Mon Sep 17 00:00:00 2001 From: ChengZi Date: Wed, 4 Sep 2024 21:22:50 +0800 Subject: [PATCH] feat: support milvus vector db (#188) #none Signed-off-by: ChengZi --- flowsettings.py | 1 + libs/kotaemon/kotaemon/storages/__init__.py | 2 + .../storages/vectorstores/__init__.py | 2 + .../kotaemon/storages/vectorstores/milvus.py | 100 ++++++++++++++++++ libs/kotaemon/pyproject.toml | 1 + libs/kotaemon/tests/test_vectorstore.py | 95 +++++++++++++++++ 6 files changed, 201 insertions(+) create mode 100644 libs/kotaemon/kotaemon/storages/vectorstores/milvus.py diff --git a/flowsettings.py b/flowsettings.py index cae5c68..99881a1 100644 --- a/flowsettings.py +++ b/flowsettings.py @@ -80,6 +80,7 @@ KH_DOCSTORE = { KH_VECTORSTORE = { # "__type__": "kotaemon.storages.LanceDBVectorStore", "__type__": "kotaemon.storages.ChromaVectorStore", + # "__type__": "kotaemon.storages.MilvusVectorStore", "path": str(KH_USER_DATA_DIR / "vectorstore"), } KH_LLMS = {} diff --git a/libs/kotaemon/kotaemon/storages/__init__.py b/libs/kotaemon/kotaemon/storages/__init__.py index d5f5c94..3fac634 100644 --- a/libs/kotaemon/kotaemon/storages/__init__.py +++ b/libs/kotaemon/kotaemon/storages/__init__.py @@ -10,6 +10,7 @@ from .vectorstores import ( ChromaVectorStore, InMemoryVectorStore, LanceDBVectorStore, + MilvusVectorStore, SimpleFileVectorStore, ) @@ -26,4 +27,5 @@ __all__ = [ "InMemoryVectorStore", "SimpleFileVectorStore", "LanceDBVectorStore", + "MilvusVectorStore", ] diff --git a/libs/kotaemon/kotaemon/storages/vectorstores/__init__.py b/libs/kotaemon/kotaemon/storages/vectorstores/__init__.py index befc1e8..2087ee6 100644 --- a/libs/kotaemon/kotaemon/storages/vectorstores/__init__.py +++ b/libs/kotaemon/kotaemon/storages/vectorstores/__init__.py @@ -2,6 +2,7 @@ from .base import BaseVectorStore from .chroma import ChromaVectorStore from .in_memory import InMemoryVectorStore from .lancedb import LanceDBVectorStore +from .milvus import 
import os
from typing import Any, Optional, Type, cast

from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore

from kotaemon.base import DocumentWithEmbedding

from .base import LlamaIndexVectorStore


class MilvusVectorStore(LlamaIndexVectorStore):
    """Vector store backed by Milvus through the LlamaIndex Milvus integration.

    The underlying LlamaIndex client is created lazily: its constructor needs
    the vector dimension, which is normally only known once the first
    embedding is added or queried.
    """

    _li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore

    def __init__(
        self,
        uri: str = "./milvus.db",  # or "http://localhost:19530"
        collection_name: str = "default",
        token: Optional[str] = None,
        **kwargs: Any,
    ):
        """Record the connection settings without creating the client.

        Args:
            uri: Local database file name/path, or an ``http(s)://`` address.
            collection_name: Name of the collection to operate on.
            token: Optional token forwarded to the LlamaIndex client.
            **kwargs: Extra options forwarded to the LlamaIndex client. An
                optional ``path`` entry names a directory that a local ``uri``
                is resolved against; an optional ``dim`` entry is used as the
                dimension whenever a call cannot infer one.
        """
        self._uri = uri
        self._collection_name = collection_name
        self._token = token
        self._kwargs = kwargs
        self._path = kwargs.get("path", None)
        self._inited = False

    def _resolve_uri(self) -> str:
        """Return the uri handed to the client.

        A non-http uri is joined onto ``path`` when ``path`` is an existing
        directory. ``path`` may be absent (None), in which case the uri is
        used unchanged; ``os.path.isdir(None)`` would raise ``TypeError``.
        """
        if (
            self._path
            and os.path.isdir(self._path)
            and not self._uri.startswith("http")
        ):
            return os.path.join(self._path, self._uri)
        return self._uri

    def _lazy_init(self, dim: Optional[int] = None):
        """Create the LlamaIndex client on first use.

        Because the LlamaIndex init method requires the dim parameter,
        callers pass the dimension of the first embedding they see.

        Args:
            dim: Dimension of the vectors. When None, a ``dim`` supplied via
                the constructor kwargs (if any) is used instead.
        """
        if self._inited:
            return

        # Copy so the persisted kwargs stay untouched, and pull out "dim" so it
        # cannot collide with the explicit dim= keyword below.
        init_kwargs = dict(self._kwargs)
        configured_dim = init_kwargs.pop("dim", None)
        if dim is None:
            dim = configured_dim

        super().__init__(
            uri=self._resolve_uri(),
            token=self._token,
            collection_name=self._collection_name,
            dim=dim,
            **init_kwargs,
        )
        self._client = cast(LIMilvusVectorStore, self._client)
        self._inited = True

    def _try_lazy_init(self) -> bool:
        """Best-effort initialisation without a known dimension.

        Returns:
            True when the client is ready, False when initialisation failed.
            NOTE(review): failure presumably means the collection does not
            exist yet (no dim to create it with) -- confirm against the
            LlamaIndex Milvus constructor.
        """
        try:
            self._lazy_init()
        except Exception:  # deliberate best-effort; lets Ctrl-C/SystemExit through
            return False
        return True

    def add(
        self,
        embeddings: list[list[float]] | list[DocumentWithEmbedding],
        metadatas: Optional[list[dict]] = None,
        ids: Optional[list[str]] = None,
    ):
        """Add embeddings, initialising the client from the first vector's size.

        Args:
            embeddings: Raw vectors or documents carrying an ``embedding``.
            metadatas: Optional metadata, forwarded to the parent ``add``.
            ids: Optional ids, forwarded to the parent ``add``.

        Returns:
            Whatever the parent ``add`` returns; an empty list when there is
            nothing to add.
        """
        if not embeddings:
            # Nothing to insert, and no vector to infer the dimension from.
            return []

        if not self._inited:
            first = embeddings[0]
            vector = (
                first.embedding if isinstance(first, DocumentWithEmbedding) else first
            )
            self._lazy_init(len(vector))

        return super().add(embeddings=embeddings, metadatas=metadatas, ids=ids)

    def query(
        self,
        embedding: list[float],
        top_k: int = 1,
        ids: Optional[list[str]] = None,
        **kwargs,
    ) -> tuple[list[list[float]], list[float], list[str]]:
        """Run a similarity query, initialising the client if needed.

        Args:
            embedding: Query vector; its length is used as the dimension.
            top_k: Number of results requested.
            ids: Optional ids, forwarded to the parent ``query``.
            **kwargs: Extra options forwarded to the parent ``query``.
        """
        self._lazy_init(len(embedding))

        return super().query(embedding=embedding, top_k=top_k, ids=ids, **kwargs)

    def delete(self, ids: list[str], **kwargs):
        """Delete the entries with the given ids via the parent ``delete``."""
        self._lazy_init()
        super().delete(ids=ids, **kwargs)

    def drop(self):
        """Drop the whole collection.

        Does nothing when the client cannot be initialised. Afterwards the
        store is marked uninitialised so the next operation builds a fresh
        client instead of reusing one bound to the dropped collection.
        """
        if not self._try_lazy_init():
            return
        self._client.client.drop_collection(self._collection_name)
        self._inited = False

    def count(self) -> int:
        """Return the number of stored entries, or 0 if the client cannot init."""
        if not self._try_lazy_init():
            return 0
        rows = self._client.client.query(
            collection_name=self._collection_name, output_fields=["count(*)"]
        )
        return rows[0]["count(*)"]

    def __persist_flow__(self):
        """Return the constructor arguments needed to recreate this store."""
        return {
            "uri": self._uri,
            "collection_name": self._collection_name,
            "token": self._token,
            **self._kwargs,
        }
"openpyxl>=3.1.2,<3.2", "pandas>=2.2.2,<2.3", diff --git a/libs/kotaemon/tests/test_vectorstore.py b/libs/kotaemon/tests/test_vectorstore.py index 57e033d..edf4990 100644 --- a/libs/kotaemon/tests/test_vectorstore.py +++ b/libs/kotaemon/tests/test_vectorstore.py @@ -5,6 +5,7 @@ from kotaemon.base import DocumentWithEmbedding from kotaemon.storages import ( ChromaVectorStore, InMemoryVectorStore, + MilvusVectorStore, SimpleFileVectorStore, ) @@ -153,3 +154,97 @@ class TestSimpleFileVectorStore: ], "load function does not load data completely" os.remove(tmp_path / collection_name) + + +class TestMilvusVectorStore: + def test_add(self, tmp_path): + """Test that the DB add correctly""" + db = MilvusVectorStore( + path=str(tmp_path), + overwrite=True, + ) + + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + ids = ["1", "2"] + + assert db.count() == 0, "Expected empty collection" + output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) + assert output == ids, "Expected output to be the same as ids" + assert db.count() == 2, "Expected 2 added entries" + + def test_add_from_docs(self, tmp_path): + db = MilvusVectorStore( + path=str(tmp_path), + overwrite=True, + ) + + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + documents = [ + DocumentWithEmbedding(embedding=embedding, metadata=metadata) + for embedding, metadata in zip(embeddings, metadatas) + ] + assert db.count() == 0, "Expected empty collection" + output = db.add(documents) + assert len(output) == 2, "Expected outputting 2 ids" + assert db.count() == 2, "Expected 2 added entries" + + def test_delete(self, tmp_path): + db = MilvusVectorStore( + path=str(tmp_path), + overwrite=True, + ) + + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}] + ids = ["a", "b", "c"] + + db.add(embeddings=embeddings, metadatas=metadatas, 
ids=ids) + assert db.count() == 3, "Expected 3 added entries" + db.delete(ids=["a", "b"]) + assert db.count() == 1, "Expected 1 remaining entry" + db.delete(ids=["c"]) + assert db.count() == 0, "Expected 0 remaining entry" + + def test_query(self, tmp_path): + db = MilvusVectorStore(path=str(tmp_path), overwrite=True) + import numpy as np + + embeddings = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) + norms = np.linalg.norm(embeddings, axis=1) + normalized_embeddings = (embeddings / norms[:, np.newaxis]).tolist() + + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}] + ids = ["a", "b", "c"] + + db.add(embeddings=normalized_embeddings, metadatas=metadatas, ids=ids) + + _, sim, out_ids = db.query(embedding=normalized_embeddings[0], top_k=1) + assert sim == [1.0] + assert out_ids == ["a"] + + query_embedding = [ + normalized_embeddings[1][0] + 0.02, + normalized_embeddings[1][1] + 0.02, + normalized_embeddings[1][2] + 0.02, + ] + _, _, out_ids = db.query(embedding=query_embedding, top_k=1) + assert out_ids == ["b"] + + def test_save_load_delete(self, tmp_path): + """Test that save/load func behave correctly.""" + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}] + ids = ["1", "2", "3"] + db = MilvusVectorStore(path=str(tmp_path), overwrite=True) + db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) + + db2 = MilvusVectorStore(path=str(tmp_path), overrides=False) + assert db2.count() == 3, "load function does not load data completely" + + # test delete collection function + db2.drop() + # reinit the milvus with the same collection name + db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False) + assert db2.count() == 0, "delete collection function does not work correctly"