feat: Qdrant vectorstore support (#260)
* feat: Qdrant vectorstore support * chore: review changes * docs: Updated README.md
This commit is contained in:
parent
cbe45a4395
commit
e2bd78e9c4
|
@ -189,7 +189,7 @@ starting point.
|
|||
KH_DOCSTORE=(Elasticsearch | LanceDB | SimpleFileDocumentStore)
|
||||
|
||||
# setup your preferred vectorstore (for vector-based search)
|
||||
KH_VECTORSTORE=(ChromaDB | LanceDB | InMemory)
|
||||
KH_VECTORSTORE=(ChromaDB | LanceDB | InMemory | Qdrant)
|
||||
|
||||
# Enable / disable multimodal QA
|
||||
KH_REASONINGS_USE_MULTIMODAL=True
|
||||
|
|
|
@ -81,6 +81,7 @@ KH_VECTORSTORE = {
|
|||
# "__type__": "kotaemon.storages.LanceDBVectorStore",
|
||||
"__type__": "kotaemon.storages.ChromaVectorStore",
|
||||
# "__type__": "kotaemon.storages.MilvusVectorStore",
|
||||
# "__type__": "kotaemon.storages.QdrantVectorStore",
|
||||
"path": str(KH_USER_DATA_DIR / "vectorstore"),
|
||||
}
|
||||
KH_LLMS = {}
|
||||
|
|
|
@ -11,6 +11,7 @@ from .vectorstores import (
|
|||
InMemoryVectorStore,
|
||||
LanceDBVectorStore,
|
||||
MilvusVectorStore,
|
||||
QdrantVectorStore,
|
||||
SimpleFileVectorStore,
|
||||
)
|
||||
|
||||
|
@ -28,4 +29,5 @@ __all__ = [
|
|||
"SimpleFileVectorStore",
|
||||
"LanceDBVectorStore",
|
||||
"MilvusVectorStore",
|
||||
"QdrantVectorStore",
|
||||
]
|
||||
|
|
|
@ -3,6 +3,7 @@ from .chroma import ChromaVectorStore
|
|||
from .in_memory import InMemoryVectorStore
|
||||
from .lancedb import LanceDBVectorStore
|
||||
from .milvus import MilvusVectorStore
|
||||
from .qdrant import QdrantVectorStore
|
||||
from .simple_file import SimpleFileVectorStore
|
||||
|
||||
__all__ = [
|
||||
|
@ -12,4 +13,5 @@ __all__ = [
|
|||
"SimpleFileVectorStore",
|
||||
"LanceDBVectorStore",
|
||||
"MilvusVectorStore",
|
||||
"QdrantVectorStore",
|
||||
]
|
||||
|
|
67
libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py
Normal file
67
libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
from typing import Any, List, Optional, Type, cast
|
||||
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore as LIQdrantVectorStore
|
||||
|
||||
from .base import LlamaIndexVectorStore
|
||||
|
||||
|
||||
class QdrantVectorStore(LlamaIndexVectorStore):
|
||||
_li_class: Type[LIQdrantVectorStore] = LIQdrantVectorStore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name,
|
||||
url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
client_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._collection_name = collection_name
|
||||
self._url = url
|
||||
self._api_key = api_key
|
||||
self._client_kwargs = client_kwargs
|
||||
self._kwargs = kwargs
|
||||
|
||||
super().__init__(
|
||||
collection_name=collection_name,
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
self._client = cast(LIQdrantVectorStore, self._client)
|
||||
|
||||
def delete(self, ids: List[str], **kwargs):
|
||||
"""Delete vector embeddings from vector stores
|
||||
|
||||
Args:
|
||||
ids: List of ids of the embeddings to be deleted
|
||||
kwargs: meant for vectorstore-specific parameters
|
||||
"""
|
||||
from qdrant_client import models
|
||||
|
||||
self._client.client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=models.PointIdsList(
|
||||
points=ids,
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def drop(self):
|
||||
"""Delete entire collection from vector stores"""
|
||||
self._client.client.delete_collection(self._collection_name)
|
||||
|
||||
def count(self) -> int:
|
||||
return self._client.client.count(
|
||||
collection_name=self._collection_name, exact=True
|
||||
).count
|
||||
|
||||
def __persist_flow__(self):
|
||||
return {
|
||||
"collection_name": self._collection_name,
|
||||
"url": self._url,
|
||||
"api_key": self._api_key,
|
||||
"client_kwargs": self._client_kwargs,
|
||||
**self._kwargs,
|
||||
}
|
|
@ -22,7 +22,7 @@ requires-python = ">= 3.10"
|
|||
description = "Kotaemon core library for AI development."
|
||||
dependencies = [
|
||||
"click>=8.1.7,<9",
|
||||
"cohere>=5.3.2,<5.4",
|
||||
"cohere>=5.3.2,<6",
|
||||
"cookiecutter>=2.6.0,<2.7",
|
||||
"fast_langdetect",
|
||||
"gradio>=4.31.0,<4.40",
|
||||
|
@ -73,6 +73,7 @@ adv = [
|
|||
"sentence-transformers",
|
||||
"llama-cpp-python<0.2.8",
|
||||
"fastembed",
|
||||
"llama-index-vector-stores-qdrant",
|
||||
]
|
||||
dev = [
|
||||
"black",
|
||||
|
|
|
@ -51,6 +51,15 @@ def if_unstructured_not_installed():
|
|||
return False
|
||||
|
||||
|
||||
def if_cohere_not_installed():
|
||||
try:
|
||||
import cohere # noqa: F401
|
||||
except ImportError:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def if_llama_cpp_not_installed():
|
||||
try:
|
||||
import llama_cpp # noqa: F401
|
||||
|
@ -76,6 +85,10 @@ skip_when_unstructured_not_installed = pytest.mark.skipif(
|
|||
if_unstructured_not_installed(), reason="unstructured is not installed"
|
||||
)
|
||||
|
||||
skip_when_cohere_not_installed = pytest.mark.skipif(
|
||||
if_cohere_not_installed(), reason="cohere is not installed"
|
||||
)
|
||||
|
||||
skip_openai_lc_wrapper_test = pytest.mark.skipif(
|
||||
True, reason="OpenAI LC wrapper test is skipped"
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ from kotaemon.embeddings import (
|
|||
)
|
||||
|
||||
from .conftest import (
|
||||
skip_when_cohere_not_installed,
|
||||
skip_when_fastembed_not_installed,
|
||||
skip_when_sentence_bert_not_installed,
|
||||
)
|
||||
|
@ -132,6 +133,7 @@ def test_lchuggingface_embeddings(
|
|||
langchain_huggingface_embedding_call.assert_called()
|
||||
|
||||
|
||||
@skip_when_cohere_not_installed
|
||||
@patch(
|
||||
"langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
|
||||
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from kotaemon.base import DocumentWithEmbedding
|
||||
from kotaemon.storages import (
|
||||
ChromaVectorStore,
|
||||
InMemoryVectorStore,
|
||||
MilvusVectorStore,
|
||||
QdrantVectorStore,
|
||||
SimpleFileVectorStore,
|
||||
)
|
||||
|
||||
|
@ -248,3 +251,118 @@ class TestMilvusVectorStore:
|
|||
# reinit the milvus with the same collection name
|
||||
db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
|
||||
assert db2.count() == 0, "delete collection function does not work correctly"
|
||||
|
||||
|
||||
class TestQdrantVectorStore:
|
||||
def test_add(self):
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))
|
||||
|
||||
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
|
||||
ids = [
|
||||
"0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
|
||||
"90aba5d3-f4f8-47c6-bad9-5ea457442e07",
|
||||
]
|
||||
|
||||
output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||
assert output == ids, "Expected output to be the same as ids"
|
||||
assert db.count() == 2, "Expected 2 added entries"
|
||||
|
||||
def test_add_from_docs(self, tmp_path):
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))
|
||||
|
||||
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
|
||||
documents = [
|
||||
DocumentWithEmbedding(embedding=embedding, metadata=metadata)
|
||||
for embedding, metadata in zip(embeddings, metadatas)
|
||||
]
|
||||
|
||||
output = db.add(documents)
|
||||
assert len(output) == 2, "Expected outputting 2 ids"
|
||||
assert db.count() == 2, "Expected 2 added entries"
|
||||
|
||||
def test_delete(self, tmp_path):
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))
|
||||
|
||||
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
|
||||
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
|
||||
ids = [
|
||||
"0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
|
||||
"90aba5d3-f4f8-47c6-bad9-5ea457442e07",
|
||||
"6bed07c3-d284-47a3-a711-c3f9186755b8",
|
||||
]
|
||||
|
||||
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||
assert db.count() == 3, "Expected 3 added entries"
|
||||
db.delete(
|
||||
ids=[
|
||||
"0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
|
||||
"90aba5d3-f4f8-47c6-bad9-5ea457442e07",
|
||||
]
|
||||
)
|
||||
assert db.count() == 1, "Expected 1 remaining entry"
|
||||
db.delete(ids=["6bed07c3-d284-47a3-a711-c3f9186755b8"])
|
||||
assert db.count() == 0, "Expected 0 remaining entry"
|
||||
|
||||
def test_query(self, tmp_path):
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))
|
||||
|
||||
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
|
||||
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
|
||||
ids = [
|
||||
"0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
|
||||
"90aba5d3-f4f8-47c6-bad9-5ea457442e07",
|
||||
"6bed07c3-d284-47a3-a711-c3f9186755b8",
|
||||
]
|
||||
|
||||
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||
|
||||
_, sim, out_ids = db.query(embedding=[0.1, 0.2, 0.3], top_k=1)
|
||||
assert sim[0] - 1.0 < 1e-6
|
||||
assert out_ids == ["0f0611b3-2d9c-4818-ab69-1f1c4cf66693"]
|
||||
|
||||
_, _, out_ids = db.query(embedding=[0.4, 0.5, 0.6], top_k=1)
|
||||
assert out_ids == ["90aba5d3-f4f8-47c6-bad9-5ea457442e07"]
|
||||
|
||||
def test_save_load_delete(self, tmp_path):
|
||||
"""Test that save/load func behave correctly."""
|
||||
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
|
||||
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
|
||||
ids = [
|
||||
"0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
|
||||
"90aba5d3-f4f8-47c6-bad9-5ea457442e07",
|
||||
"6bed07c3-d284-47a3-a711-c3f9186755b8",
|
||||
]
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
db = QdrantVectorStore(
|
||||
collection_name="test", client=QdrantClient(path=tmp_path)
|
||||
)
|
||||
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||
del db
|
||||
|
||||
db2 = QdrantVectorStore(
|
||||
collection_name="test", client=QdrantClient(path=tmp_path)
|
||||
)
|
||||
assert db2.count() == 3
|
||||
|
||||
db2.drop()
|
||||
del db2
|
||||
|
||||
db2 = QdrantVectorStore(
|
||||
collection_name="test", client=QdrantClient(path=tmp_path)
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# Since no docs were added, the collection should not exist yet
|
||||
# and thus the count function should raise an exception
|
||||
db2.count()
|
||||
|
|
Loading…
Reference in New Issue
Block a user