From e2bd78e9c44ad86acbf06ea8c8b243e22fb75cbe Mon Sep 17 00:00:00 2001 From: Anush Date: Mon, 16 Sep 2024 02:47:36 +0530 Subject: [PATCH] feat: Qdrant vectorstore support (#260) * feat: Qdrant vectorstore support * chore: review changes * docs: Updated README.md --- README.md | 2 +- flowsettings.py | 1 + libs/kotaemon/kotaemon/storages/__init__.py | 2 + .../storages/vectorstores/__init__.py | 2 + .../kotaemon/storages/vectorstores/qdrant.py | 67 ++++++++++ libs/kotaemon/pyproject.toml | 3 +- libs/kotaemon/tests/conftest.py | 13 ++ libs/kotaemon/tests/test_embedding_models.py | 2 + libs/kotaemon/tests/test_vectorstore.py | 118 ++++++++++++++++++ 9 files changed, 208 insertions(+), 2 deletions(-) create mode 100644 libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py diff --git a/README.md b/README.md index d7d58c1..922d50a 100644 --- a/README.md +++ b/README.md @@ -189,7 +189,7 @@ starting point. KH_DOCSTORE=(Elasticsearch | LanceDB | SimpleFileDocumentStore) # setup your preferred vectorstore (for vector-based search) -KH_VECTORSTORE=(ChromaDB | LanceDB | InMemory) +KH_VECTORSTORE=(ChromaDB | LanceDB | InMemory | Qdrant) # Enable / disable multimodal QA KH_REASONINGS_USE_MULTIMODAL=True diff --git a/flowsettings.py b/flowsettings.py index cd487ce..b4f9d76 100644 --- a/flowsettings.py +++ b/flowsettings.py @@ -81,6 +81,7 @@ KH_VECTORSTORE = { # "__type__": "kotaemon.storages.LanceDBVectorStore", "__type__": "kotaemon.storages.ChromaVectorStore", # "__type__": "kotaemon.storages.MilvusVectorStore", + # "__type__": "kotaemon.storages.QdrantVectorStore", "path": str(KH_USER_DATA_DIR / "vectorstore"), } KH_LLMS = {} diff --git a/libs/kotaemon/kotaemon/storages/__init__.py b/libs/kotaemon/kotaemon/storages/__init__.py index 3fac634..5638db8 100644 --- a/libs/kotaemon/kotaemon/storages/__init__.py +++ b/libs/kotaemon/kotaemon/storages/__init__.py @@ -11,6 +11,7 @@ from .vectorstores import ( InMemoryVectorStore, LanceDBVectorStore, MilvusVectorStore, + 
QdrantVectorStore, SimpleFileVectorStore, ) @@ -28,4 +29,5 @@ __all__ = [ "SimpleFileVectorStore", "LanceDBVectorStore", "MilvusVectorStore", + "QdrantVectorStore", ] diff --git a/libs/kotaemon/kotaemon/storages/vectorstores/__init__.py b/libs/kotaemon/kotaemon/storages/vectorstores/__init__.py index 2087ee6..1cb8541 100644 --- a/libs/kotaemon/kotaemon/storages/vectorstores/__init__.py +++ b/libs/kotaemon/kotaemon/storages/vectorstores/__init__.py @@ -3,6 +3,7 @@ from .chroma import ChromaVectorStore from .in_memory import InMemoryVectorStore from .lancedb import LanceDBVectorStore from .milvus import MilvusVectorStore +from .qdrant import QdrantVectorStore from .simple_file import SimpleFileVectorStore __all__ = [ @@ -12,4 +13,5 @@ __all__ = [ "SimpleFileVectorStore", "LanceDBVectorStore", "MilvusVectorStore", + "QdrantVectorStore", ] diff --git a/libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py b/libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py new file mode 100644 index 0000000..f3b421c --- /dev/null +++ b/libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py @@ -0,0 +1,67 @@ +from typing import Any, List, Optional, Type, cast + +from llama_index.vector_stores.qdrant import QdrantVectorStore as LIQdrantVectorStore + +from .base import LlamaIndexVectorStore + + +class QdrantVectorStore(LlamaIndexVectorStore): + _li_class: Type[LIQdrantVectorStore] = LIQdrantVectorStore + + def __init__( + self, + collection_name, + url: Optional[str] = None, + api_key: Optional[str] = None, + client_kwargs: Optional[dict] = None, + **kwargs: Any, + ): + self._collection_name = collection_name + self._url = url + self._api_key = api_key + self._client_kwargs = client_kwargs + self._kwargs = kwargs + + super().__init__( + collection_name=collection_name, + url=url, + api_key=api_key, + client_kwargs=client_kwargs, + **kwargs, + ) + self._client = cast(LIQdrantVectorStore, self._client) + + def delete(self, ids: List[str], **kwargs): + """Delete vector embeddings from 
vector stores + + Args: + ids: List of ids of the embeddings to be deleted + kwargs: meant for vectorstore-specific parameters + """ + from qdrant_client import models + + self._client.client.delete( + collection_name=self._collection_name, + points_selector=models.PointIdsList( + points=ids, + ), + **kwargs, + ) + + def drop(self): + """Delete entire collection from vector stores""" + self._client.client.delete_collection(self._collection_name) + + def count(self) -> int: + return self._client.client.count( + collection_name=self._collection_name, exact=True + ).count + + def __persist_flow__(self): + return { + "collection_name": self._collection_name, + "url": self._url, + "api_key": self._api_key, + "client_kwargs": self._client_kwargs, + **self._kwargs, + } diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index dacba75..78567c7 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -22,7 +22,7 @@ requires-python = ">= 3.10" description = "Kotaemon core library for AI development." 
dependencies = [ "click>=8.1.7,<9", - "cohere>=5.3.2,<5.4", + "cohere>=5.3.2,<6", "cookiecutter>=2.6.0,<2.7", "fast_langdetect", "gradio>=4.31.0,<4.40", @@ -73,6 +73,7 @@ adv = [ "sentence-transformers", "llama-cpp-python<0.2.8", "fastembed", + "llama-index-vector-stores-qdrant", ] dev = [ "black", diff --git a/libs/kotaemon/tests/conftest.py b/libs/kotaemon/tests/conftest.py index c0650ae..c76114c 100644 --- a/libs/kotaemon/tests/conftest.py +++ b/libs/kotaemon/tests/conftest.py @@ -51,6 +51,15 @@ def if_unstructured_not_installed(): return False +def if_cohere_not_installed(): + try: + import cohere # noqa: F401 + except ImportError: + return True + else: + return False + + def if_llama_cpp_not_installed(): try: import llama_cpp # noqa: F401 @@ -76,6 +85,10 @@ skip_when_unstructured_not_installed = pytest.mark.skipif( if_unstructured_not_installed(), reason="unstructured is not installed" ) +skip_when_cohere_not_installed = pytest.mark.skipif( + if_cohere_not_installed(), reason="cohere is not installed" +) + skip_openai_lc_wrapper_test = pytest.mark.skipif( True, reason="OpenAI LC wrapper test is skipped" ) diff --git a/libs/kotaemon/tests/test_embedding_models.py b/libs/kotaemon/tests/test_embedding_models.py index fd13540..93d3cc5 100644 --- a/libs/kotaemon/tests/test_embedding_models.py +++ b/libs/kotaemon/tests/test_embedding_models.py @@ -14,6 +14,7 @@ from kotaemon.embeddings import ( ) from .conftest import ( + skip_when_cohere_not_installed, skip_when_fastembed_not_installed, skip_when_sentence_bert_not_installed, ) @@ -132,6 +133,7 @@ def test_lchuggingface_embeddings( langchain_huggingface_embedding_call.assert_called() +@skip_when_cohere_not_installed @patch( "langchain.embeddings.cohere.CohereEmbeddings.embed_documents", side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]], diff --git a/libs/kotaemon/tests/test_vectorstore.py b/libs/kotaemon/tests/test_vectorstore.py index edf4990..2bf22ed 100644 --- a/libs/kotaemon/tests/test_vectorstore.py +++ 
b/libs/kotaemon/tests/test_vectorstore.py @@ -1,11 +1,14 @@ import json import os +import pytest + from kotaemon.base import DocumentWithEmbedding from kotaemon.storages import ( ChromaVectorStore, InMemoryVectorStore, MilvusVectorStore, + QdrantVectorStore, SimpleFileVectorStore, ) @@ -248,3 +251,118 @@ class TestMilvusVectorStore: # reinit the milvus with the same collection name db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False) assert db2.count() == 0, "delete collection function does not work correctly" + + +class TestQdrantVectorStore: + def test_add(self): + from qdrant_client import QdrantClient + + db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:")) + + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + ids = [ + "0f0611b3-2d9c-4818-ab69-1f1c4cf66693", + "90aba5d3-f4f8-47c6-bad9-5ea457442e07", + ] + + output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) + assert output == ids, "Expected output to be the same as ids" + assert db.count() == 2, "Expected 2 added entries" + + def test_add_from_docs(self, tmp_path): + from qdrant_client import QdrantClient + + db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:")) + + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + documents = [ + DocumentWithEmbedding(embedding=embedding, metadata=metadata) + for embedding, metadata in zip(embeddings, metadatas) + ] + + output = db.add(documents) + assert len(output) == 2, "Expected outputting 2 ids" + assert db.count() == 2, "Expected 2 added entries" + + def test_delete(self, tmp_path): + from qdrant_client import QdrantClient + + db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:")) + + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}] + ids = [ + "0f0611b3-2d9c-4818-ab69-1f1c4cf66693", 
+ "90aba5d3-f4f8-47c6-bad9-5ea457442e07", + "6bed07c3-d284-47a3-a711-c3f9186755b8", + ] + + db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) + assert db.count() == 3, "Expected 3 added entries" + db.delete( + ids=[ + "0f0611b3-2d9c-4818-ab69-1f1c4cf66693", + "90aba5d3-f4f8-47c6-bad9-5ea457442e07", + ] + ) + assert db.count() == 1, "Expected 1 remaining entry" + db.delete(ids=["6bed07c3-d284-47a3-a711-c3f9186755b8"]) + assert db.count() == 0, "Expected 0 remaining entry" + + def test_query(self, tmp_path): + from qdrant_client import QdrantClient + + db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:")) + + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}] + ids = [ + "0f0611b3-2d9c-4818-ab69-1f1c4cf66693", + "90aba5d3-f4f8-47c6-bad9-5ea457442e07", + "6bed07c3-d284-47a3-a711-c3f9186755b8", + ] + + db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) + + _, sim, out_ids = db.query(embedding=[0.1, 0.2, 0.3], top_k=1) + assert abs(sim[0] - 1.0) < 1e-6 + assert out_ids == ["0f0611b3-2d9c-4818-ab69-1f1c4cf66693"] + + _, _, out_ids = db.query(embedding=[0.4, 0.5, 0.6], top_k=1) + assert out_ids == ["90aba5d3-f4f8-47c6-bad9-5ea457442e07"] + + def test_save_load_delete(self, tmp_path): + """Test that save/load func behave correctly.""" + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}] + ids = [ + "0f0611b3-2d9c-4818-ab69-1f1c4cf66693", + "90aba5d3-f4f8-47c6-bad9-5ea457442e07", + "6bed07c3-d284-47a3-a711-c3f9186755b8", + ] + from qdrant_client import QdrantClient + + db = QdrantVectorStore( + collection_name="test", client=QdrantClient(path=tmp_path) + ) + db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) + del db + + db2 = QdrantVectorStore( + collection_name="test", client=QdrantClient(path=tmp_path) + ) + assert db2.count() == 3 + + db2.drop() + del db2 + + 
db2 = QdrantVectorStore( + collection_name="test", client=QdrantClient(path=tmp_path) + ) + + with pytest.raises(Exception): + # Since no docs were added, the collection should not exist yet + # and thus the count function should raise an exception + db2.count()