[AUR-392, AUR-413, AUR-414] Define base vector store, and make use of ChromaVectorStore from llama_index. Indexing and retrieving vectors with vector store (#18)
Design the base interface of the vector store and apply it to the Chroma vector store (a wrapper around llama_index's implementation). Provide pipelines to populate and retrieve from the vector store.
parent c339912312
commit 620b2b03ca
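For context, the pieces below combine as follows. This is an illustrative sketch, not part of the diff; the Azure values are placeholders taken from the tests at the end of this commit:

    from kotaemon.documents.base import Document
    from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
    from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
    from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
    from kotaemon.vectorstores import ChromaVectorStore

    # one vector store shared by the indexing and retrieval pipelines
    db = ChromaVectorStore(path="./chroma")
    embedding = AzureOpenAIEmbeddings(
        model="text-embedding-ada-002",                    # placeholder credentials
        deployment="embedding-deployment",
        openai_api_base="https://example.openai.azure.com/",
        openai_api_key="...",
    )

    # populate the store, then retrieve from it
    index = IndexVectorStoreFromDocumentPipeline(vector_store=db, embedding=embedding)
    index(text=Document(text="Hello world"))

    retrieve = RetrieveDocumentFromVectorStorePipeline(vector_store=db, embedding=embedding)
    ids = retrieve(text="Hello world")  # ids of the closest stored embeddings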
knowledgehub/embeddings/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+from .base import BaseEmbeddings
+from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+
+__all__ = ["BaseEmbeddings", "OpenAIEmbeddings", "AzureOpenAIEmbeddings"]
knowledgehub/embeddings/base.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 from typing import List, Type

 from langchain.embeddings.base import Embeddings as LCEmbeddings
@@ -7,11 +8,37 @@ from ..components import BaseComponent
 from ..documents.base import Document


-class Embeddings(BaseComponent):
-    ...
+class BaseEmbeddings(BaseComponent):
+    @abstractmethod
+    def run_raw(self, text: str) -> List[float]:
+        ...
+
+    @abstractmethod
+    def run_batch_raw(self, text: List[str]) -> List[List[float]]:
+        ...
+
+    @abstractmethod
+    def run_document(self, text: Document) -> List[float]:
+        ...
+
+    @abstractmethod
+    def run_batch_document(self, text: List[Document]) -> List[List[float]]:
+        ...
+
+    def is_document(self, text) -> bool:
+        if isinstance(text, Document):
+            return True
+        elif isinstance(text, List) and isinstance(text[0], Document):
+            return True
+        return False
+
+    def is_batch(self, text) -> bool:
+        if isinstance(text, list):
+            return True
+        return False


-class LangchainEmbeddings(Embeddings):
+class LangchainEmbeddings(BaseEmbeddings):
     _lc_class: Type[LCEmbeddings]

     def __init__(self, **params):
@@ -46,17 +73,5 @@ class LangchainEmbeddings(Embeddings):
     def run_document(self, text: Document) -> List[float]:
         return self.agent.embed_query(text.text)  # type: ignore

-    def run_batch_document(self, text: List[Document]):
+    def run_batch_document(self, text: List[Document]) -> List[List[float]]:
         return self.agent.embed_documents([each.text for each in text])  # type: ignore
-
-    def is_document(self, text) -> bool:
-        if isinstance(text, Document):
-            return True
-        elif isinstance(text, List) and isinstance(text[0], Document):
-            return True
-        return False
-
-    def is_batch(self, text) -> bool:
-        if isinstance(text, list):
-            return True
-        return False
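The hunks above turn BaseEmbeddings into an abstract interface: subclasses implement the four run_* variants, and the is_document/is_batch checks move to the base class for dispatch. A minimal sketch of a custom implementation (illustrative only; FakeEmbeddings and its constant vector are hypothetical, and it assumes BaseComponent requires nothing beyond these methods):

    from typing import List

    from kotaemon.documents.base import Document
    from kotaemon.embeddings import BaseEmbeddings


    class FakeEmbeddings(BaseEmbeddings):
        """Return a constant vector; a stand-in for a real model during tests."""

        def run_raw(self, text: str) -> List[float]:
            return [0.0, 0.0, 0.0]

        def run_batch_raw(self, text: List[str]) -> List[List[float]]:
            return [self.run_raw(each) for each in text]

        def run_document(self, text: Document) -> List[float]:
            return self.run_raw(text.text)

        def run_batch_document(self, text: List[Document]) -> List[List[float]]:
            return [self.run_document(each) for each in text]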
knowledgehub/pipelines/indexing.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+from typing import List
+
+from theflow import Node, Param
+
+from ..components import BaseComponent
+from ..documents.base import Document
+from ..embeddings import BaseEmbeddings
+from ..vectorstores import BaseVectorStore
+
+
+class IndexVectorStoreFromDocumentPipeline(BaseComponent):
+    """Ingest the document, run through the embedding, and store the embedding in a
+    vector store.
+
+    This pipeline supports the following set of inputs:
+        - List of documents
+        - List of texts
+    """
+
+    vector_store: Param[BaseVectorStore] = Param()
+    embedding: Node[BaseEmbeddings] = Node()
+    # TODO: populate to document store as well when it's finished
+    # TODO: refer to llama_index's storage as well
+
+    def run_raw(self, text: str) -> None:
+        self.vector_store.add([self.embedding(text)])
+
+    def run_batch_raw(self, text: List[str]) -> None:
+        self.vector_store.add(self.embedding(text))
+
+    def run_document(self, text: Document) -> None:
+        self.vector_store.add([self.embedding(text)])
+
+    def run_batch_document(self, text: List[Document]) -> None:
+        self.vector_store.add(self.embedding(text))
+
+    def is_document(self, text) -> bool:
+        if isinstance(text, Document):
+            return True
+        elif isinstance(text, List) and isinstance(text[0], Document):
+            return True
+        return False
+
+    def is_batch(self, text) -> bool:
+        if isinstance(text, list):
+            return True
+        return False
+
+    def persist(self, path: str):
+        """Save the whole state of the indexing pipeline vector store and all
+        necessary information to disk
+
+        Args:
+            path (str): path to save the state
+        """
+
+    def load(self, path: str):
+        """Load all information from disk to an object"""
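The pipeline accepts a single string, a list of strings, a Document, or a list of Documents; is_document/is_batch let the component dispatch each call to the matching run_* method. Illustrative calls, assuming the db and embedding objects from the sketch near the top:

    index = IndexVectorStoreFromDocumentPipeline(vector_store=db, embedding=embedding)
    index(text="a single string")                         # -> run_raw
    index(text=["several", "strings"])                    # -> run_batch_raw
    index(text=Document(text="one document"))             # -> run_document
    index(text=[Document(text="a"), Document(text="b")])  # -> run_batch_document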
knowledgehub/pipelines/retrieving.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+from typing import List
+
+from theflow import Node, Param
+
+from ..components import BaseComponent
+from ..documents.base import Document
+from ..embeddings import BaseEmbeddings
+from ..vectorstores import BaseVectorStore
+
+
+class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
+    """Retrieve list of documents from vector store"""
+
+    vector_store: Param[BaseVectorStore] = Param()
+    embedding: Node[BaseEmbeddings] = Node()
+    # TODO: populate to document store as well when it's finished
+    # TODO: refer to llama_index's storage as well
+
+    def run_raw(self, text: str) -> List[str]:
+        emb = self.embedding(text)
+        return self.vector_store.query(embedding=emb)[2]
+
+    def run_batch_raw(self, text: List[str]) -> List[List[str]]:
+        result = []
+        for each_text in text:
+            emb = self.embedding(each_text)
+            result.append(self.vector_store.query(embedding=emb)[2])
+        return result
+
+    def run_document(self, text: Document) -> List[str]:
+        return self.run_raw(text.text)
+
+    def run_batch_document(self, text: List[Document]) -> List[List[str]]:
+        input_text = [each.text for each in text]
+        return self.run_batch_raw(input_text)
+
+    def is_document(self, text) -> bool:
+        if isinstance(text, Document):
+            return True
+        elif isinstance(text, List) and isinstance(text[0], Document):
+            return True
+        return False
+
+    def is_batch(self, text) -> bool:
+        if isinstance(text, list):
+            return True
+        return False
+
+    def persist(self, path: str):
+        """Save the whole state of the indexing pipeline vector store and all
+        necessary information to disk
+
+        Args:
+            path (str): path to save the state
+        """
+
+    def load(self, path: str):
+        """Load all information from disk to an object"""
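Since BaseVectorStore.query returns the tuple (embeddings, similarities, ids), the [2] index above means this pipeline returns the ids of the nearest stored vectors, with top_k defaulting to 1. Illustrative calls, same setup as before:

    retrieve = RetrieveDocumentFromVectorStorePipeline(vector_store=db, embedding=embedding)
    ids = retrieve(text="Hello world")                   # one id list for one query
    batched = retrieve(text=["query a", "query b"])      # one id list per query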
knowledgehub/vectorstores/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+from .base import BaseVectorStore
+from .chroma import ChromaVectorStore
+
+__all__ = ["BaseVectorStore", "ChromaVectorStore"]
knowledgehub/vectorstores/base.py (new file, 154 lines)
@@ -0,0 +1,154 @@
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional, Tuple, Type, Union
+
+from llama_index.vector_stores.types import BasePydanticVectorStore
+from llama_index.vector_stores.types import VectorStore as LIVectorStore
+from llama_index.vector_stores.types import VectorStoreQuery
+
+from ..documents.base import Document
+
+
+class BaseVectorStore(ABC):
+    @abstractmethod
+    def __init__(self, *args, **kwargs):
+        ...
+
+    @abstractmethod
+    def add(
+        self,
+        embeddings: List[List[float]],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[str]:
+        """Add vector embeddings to vector stores
+
+        Args:
+            embeddings: List of embeddings
+            metadatas: List of metadata of the embeddings
+            ids: List of ids of the embeddings
+            kwargs: meant for vectorstore-specific parameters
+
+        Returns:
+            List of ids of the embeddings
+        """
+        ...
+
+    @abstractmethod
+    def add_from_docs(self, docs: List[Document]):
+        """Add vector embeddings to vector stores
+
+        Args:
+            docs: List of Document objects
+        """
+        ...
+
+    @abstractmethod
+    def delete(self, ids: List[str], **kwargs):
+        """Delete vector embeddings from vector stores
+
+        Args:
+            ids: List of ids of the embeddings to be deleted
+            kwargs: meant for vectorstore-specific parameters
+        """
+        ...
+
+    # @abstractmethod
+    # def update(self, *args, **kwargs):
+    #     ...
+
+    # @abstractmethod
+    # def persist(self, *args, **kwargs):
+    #     ...
+
+    # @classmethod
+    # @abstractmethod
+    # def load(self, *args, **kwargs):
+    #     ...
+
+    @abstractmethod
+    def query(
+        self,
+        embedding: List[float],
+        top_k: int = 1,
+        ids: Optional[List[str]] = None,
+    ) -> Tuple[List[List[float]], List[float], List[str]]:
+        """Return the top k most similar vector embeddings
+
+        Args:
+            embedding: List of embeddings
+            top_k: Number of most similar embeddings to return
+            ids: List of ids of the embeddings to be queried
+
+        Returns:
+            the matched embeddings, the similarity scores, and the ids
+        """
+        ...
+
+
+class LlamaIndexVectorStore(BaseVectorStore):
+    _li_class: Type[Union[LIVectorStore, BasePydanticVectorStore]]
+
+    def __init__(self, *args, **kwargs):
+        if self._li_class is None:
+            raise AttributeError(
+                "Require `_li_class` to set a VectorStore class from LlamaIndex"
+            )
+
+        self._client = self._li_class(*args, **kwargs)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name.startswith("_"):
+            return super().__setattr__(name, value)
+
+        return setattr(self._client, name, value)
+
+    def __getattr__(self, name: str) -> Any:
+        return getattr(self._client, name)
+
+    def add(
+        self,
+        embeddings: List[List[float]],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+    ):
+        nodes = [Document(embedding=embedding) for embedding in embeddings]
+        if metadatas is not None:
+            for node, metadata in zip(nodes, metadatas):
+                node.metadata = metadata
+        if ids is not None:
+            for node, id in zip(nodes, ids):
+                node.id_ = id
+
+        return self._client.add(nodes=nodes)  # type: ignore
+
+    def add_from_docs(self, docs: List[Document]):
+        return self._client.add(nodes=docs)  # type: ignore
+
+    def delete(self, ids: List[str], **kwargs):
+        for id_ in ids:
+            self._client.delete(ref_doc_id=id_, **kwargs)
+
+    def query(
+        self,
+        embedding: List[float],
+        top_k: int = 1,
+        ids: Optional[List[str]] = None,
+        **kwargs,
+    ) -> Tuple[List[List[float]], List[float], List[str]]:
+        output = self._client.query(
+            query=VectorStoreQuery(
+                query_embedding=embedding,
+                similarity_top_k=top_k,
+                node_ids=ids,
+            ),
+            **kwargs,
+        )
+
+        embeddings = []
+        if output.nodes:
+            for node in output.nodes:
+                embeddings.append(node.embedding)
+        similarities = output.similarities if output.similarities else []
+        out_ids = output.ids if output.ids else []
+
+        return embeddings, similarities, out_ids
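BaseVectorStore is backend-agnostic: LlamaIndexVectorStore covers llama_index backends by delegating unknown attributes to the wrapped client, but the interface can also be implemented directly. A toy brute-force in-memory sketch (illustrative only, not in this commit; note it reports distances in the "similarity" slot, matching Chroma's behaviour in the tests where an exact match scores 0.0):

    import math
    import uuid
    from typing import Dict, List, Optional, Tuple

    from kotaemon.documents.base import Document
    from kotaemon.vectorstores import BaseVectorStore


    class InMemoryVectorStore(BaseVectorStore):
        """Hypothetical brute-force store for tests and small corpora."""

        def __init__(self):
            self._vectors: Dict[str, List[float]] = {}

        def add(
            self,
            embeddings: List[List[float]],
            metadatas: Optional[List[dict]] = None,
            ids: Optional[List[str]] = None,
        ) -> List[str]:
            # generate ids when the caller does not supply them
            ids = ids if ids is not None else [str(uuid.uuid4()) for _ in embeddings]
            for id_, emb in zip(ids, embeddings):
                self._vectors[id_] = emb
            return ids

        def add_from_docs(self, docs: List[Document]):
            return self.add(embeddings=[doc.embedding for doc in docs])

        def delete(self, ids: List[str], **kwargs):
            for id_ in ids:
                self._vectors.pop(id_, None)

        def query(
            self,
            embedding: List[float],
            top_k: int = 1,
            ids: Optional[List[str]] = None,
        ) -> Tuple[List[List[float]], List[float], List[str]]:
            candidates = ids if ids is not None else list(self._vectors)
            # rank every candidate by Euclidean distance (0.0 == exact match)
            ranked = sorted(candidates, key=lambda i: math.dist(self._vectors[i], embedding))
            top = ranked[:top_k]
            return (
                [self._vectors[i] for i in top],
                [math.dist(self._vectors[i], embedding) for i in top],
                top,
            )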
knowledgehub/vectorstores/chroma.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+from typing import Any, Dict, List, Optional, Type, cast
+
+from llama_index.vector_stores.chroma import ChromaVectorStore as LIChromaVectorStore
+
+from .base import LlamaIndexVectorStore
+
+
+class ChromaVectorStore(LlamaIndexVectorStore):
+    _li_class: Type[LIChromaVectorStore] = LIChromaVectorStore
+
+    def __init__(
+        self,
+        path: str = "./chroma",
+        collection_name: str = "default",
+        host: str = "localhost",
+        port: str = "8000",
+        ssl: bool = False,
+        headers: Optional[Dict[str, str]] = None,
+        collection_kwargs: Optional[dict] = None,
+        stores_text: bool = True,
+        flat_metadata: bool = True,
+        **kwargs: Any,
+    ):
+        try:
+            import chromadb
+        except ImportError:
+            raise ImportError(
+                "ChromaVectorStore requires chromadb. "
+                "Please install chromadb first `pip install chromadb`"
+            )
+
+        client = chromadb.PersistentClient(path=path)
+        collection = client.get_or_create_collection(collection_name)
+
+        # pass through for nice IDE support
+        super().__init__(
+            chroma_collection=collection,
+            host=host,
+            port=port,
+            ssl=ssl,
+            headers=headers or {},
+            collection_kwargs=collection_kwargs or {},
+            stores_text=stores_text,
+            flat_metadata=flat_metadata,
+            **kwargs,
+        )
+        self._client = cast(LIChromaVectorStore, self._client)
+
+    def delete(self, ids: List[str], **kwargs):
+        """Delete vector embeddings from vector stores
+
+        Args:
+            ids: List of ids of the embeddings to be deleted
+            kwargs: meant for vectorstore-specific parameters
+        """
+        self._client._collection.delete(ids=ids)
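Used on its own, outside the pipelines, the wrapper mirrors the base interface. A short sketch of direct usage (values are illustrative, mirroring tests/test_vectorstore.py below):

    from kotaemon.vectorstores import ChromaVectorStore

    db = ChromaVectorStore(path="./chroma", collection_name="default")  # persisted on disk
    db.add(
        embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        metadatas=[{"a": 1}, {"a": 3}],
        ids=["1", "2"],
    )
    _, distances, ids = db.query(embedding=[0.1, 0.2, 0.3], top_k=1)
    # Chroma reports distances, so an exact match comes back as 0.0 with its id
    assert distances == [0.0] and ids == ["1"]
    db.delete(ids=["1", "2"])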
setup.py (1 line added)
@@ -46,6 +46,7 @@ setuptools.setup(
             "coverage",
             # optional dependency needed for test
             "openai",
+            "chromadb",
         ],
     },
     entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},
tests/test_indexing_retrieval.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+import json
+from pathlib import Path
+
+import pytest
+from openai.api_resources.embedding import Embedding
+
+from kotaemon.documents.base import Document
+from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
+from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
+from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
+from kotaemon.vectorstores import ChromaVectorStore
+
+with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
+    openai_embedding = json.load(f)
+
+
+@pytest.fixture(scope="function")
+def mock_openai_embedding(monkeypatch):
+    monkeypatch.setattr(Embedding, "create", lambda *args, **kwargs: openai_embedding)
+
+
+def test_indexing(mock_openai_embedding, tmp_path):
+    db = ChromaVectorStore(path=str(tmp_path))
+    embedding = AzureOpenAIEmbeddings(
+        model="text-embedding-ada-002",
+        deployment="embedding-deployment",
+        openai_api_base="https://test.openai.azure.com/",
+        openai_api_key="some-key",
+    )
+
+    pipeline = IndexVectorStoreFromDocumentPipeline(
+        vector_store=db, embedding=embedding
+    )
+    assert pipeline.vector_store._collection.count() == 0, "Expected empty collection"
+    pipeline(text=Document(text="Hello world"))
+    assert pipeline.vector_store._collection.count() == 1, "Index 1 item"
+
+
+def test_retrieving(mock_openai_embedding, tmp_path):
+    db = ChromaVectorStore(path=str(tmp_path))
+    embedding = AzureOpenAIEmbeddings(
+        model="text-embedding-ada-002",
+        deployment="embedding-deployment",
+        openai_api_base="https://test.openai.azure.com/",
+        openai_api_key="some-key",
+    )
+
+    index_pipeline = IndexVectorStoreFromDocumentPipeline(
+        vector_store=db, embedding=embedding
+    )
+    retrieval_pipeline = RetrieveDocumentFromVectorStorePipeline(
+        vector_store=db, embedding=embedding
+    )
+
+    index_pipeline(text=Document(text="Hello world"))
+    output = retrieval_pipeline(text=["Hello world", "Hello world"])
+
+    assert len(output) == 2, "Expected 2 results"
+    assert output[0] == output[1], "Expected identical results"
tests/test_vectorstore.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+from kotaemon.documents.base import Document
+from kotaemon.vectorstores import ChromaVectorStore
+
+
+class TestChromaVectorStore:
+    def test_add(self, tmp_path):
+        """Test that the DB add correctly"""
+        db = ChromaVectorStore(path=str(tmp_path))
+
+        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
+        ids = ["1", "2"]
+
+        assert db._collection.count() == 0, "Expected empty collection"
+        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
+        assert output == ids, "Expected output to be the same as ids"
+        assert db._collection.count() == 2, "Expected 2 added entries"
+
+    def test_add_from_docs(self, tmp_path):
+        db = ChromaVectorStore(path=str(tmp_path))
+
+        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
+        documents = [
+            Document(embedding=embedding, metadata=metadata)
+            for embedding, metadata in zip(embeddings, metadatas)
+        ]
+        assert db._collection.count() == 0, "Expected empty collection"
+        output = db.add_from_docs(documents)
+        assert len(output) == 2, "Expected outputting 2 ids"
+        assert db._collection.count() == 2, "Expected 2 added entries"
+
+    def test_delete(self, tmp_path):
+        db = ChromaVectorStore(path=str(tmp_path))
+
+        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
+        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
+        ids = ["a", "b", "c"]
+
+        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
+        assert db._collection.count() == 3, "Expected 3 added entries"
+        db.delete(ids=["a", "b"])
+        assert db._collection.count() == 1, "Expected 1 remaining entry"
+        db.delete(ids=["c"])
+        assert db._collection.count() == 0, "Expected 0 remaining entries"
+
+    def test_query(self, tmp_path):
+        db = ChromaVectorStore(path=str(tmp_path))
+
+        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
+        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
+        ids = ["a", "b", "c"]
+
+        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
+
+        _, sim, out_ids = db.query(embedding=[0.1, 0.2, 0.3], top_k=1)
+        assert sim == [0.0]
+        assert out_ids == ["a"]
+
+        _, _, out_ids = db.query(embedding=[0.42, 0.52, 0.53], top_k=1)
+        assert out_ids == ["b"]