Migrate the MVP into kotaemon (#108)
- Migrate the MVP into kotaemon.
- Preliminarily include the pipeline within the chatbot interface.
- Organize the MVP as an application.

Todo:
- Add an info panel to view the agents' planning -> fix streaming of the agents' output.

Resolve: #60
Resolve: #61
Resolve: #62
commit 5a9d6f75be (parent 230328c62f), committed by GitHub
@@ -269,4 +269,5 @@ class RewooAgent(BaseAgent):
             total_tokens=total_token,
             total_cost=total_cost,
             citation=citation,
+            metadata={"citation": citation},
         )
@@ -41,7 +41,7 @@ class BaseTool(BaseComponent):
         args_schema = self.args_schema
         if isinstance(tool_input, str):
             if args_schema is not None:
-                key_ = next(iter(args_schema.__fields__.keys()))
+                key_ = next(iter(args_schema.model_fields.keys()))
                 args_schema.validate({key_: tool_input})
             return tool_input
         else:
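The `__fields__` -> `model_fields` change tracks the pydantic v1 -> v2 rename. A minimal compatibility sketch that works on either major version (the `SearchArgs` model here is illustrative, not from this commit):

```python
from pydantic import BaseModel

class SearchArgs(BaseModel):
    query: str

# pydantic v2 exposes `model_fields`; v1 only has `__fields__`
fields = getattr(SearchArgs, "model_fields", None) or SearchArgs.__fields__
first_key = next(iter(fields))  # -> "query"
```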
@@ -121,9 +121,11 @@ class BaseTool(BaseComponent):


 class ComponentTool(BaseTool):
-    """
-    A Tool based on another pipeline / BaseComponent to be used
-    as its main entry point
-    """
+    """Wrapper around other BaseComponent to use it as a tool
+
+    Args:
+        component: BaseComponent-based component to wrap
+        postprocessor: Optional postprocessor for the component output
+    """

     component: BaseComponent
@@ -1,13 +1,11 @@
-from typing import AnyStr, Optional, Type, Union
+from typing import AnyStr, Optional, Type

 from pydantic import BaseModel, Field

-from kotaemon.llms import LLM, AzureChatOpenAI, ChatLLM
+from kotaemon.llms import BaseLLM

 from .base import BaseTool, ToolException

-BaseLLM = Union[ChatLLM, LLM]
-

 class LLMArgs(BaseModel):
     query: str = Field(..., description="a search question or prompt")
@@ -21,7 +19,7 @@ class LLMTool(BaseTool):
         "are confident in solving the problem "
         "yourself. Input can be any instruction."
     )
-    llm: BaseLLM = AzureChatOpenAI.withx()
+    llm: BaseLLM
     args_schema: Optional[Type[BaseModel]] = LLMArgs

     def _run_tool(self, query: AnyStr) -> str:
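`LLMTool` no longer hard-codes `AzureChatOpenAI.withx()` as its default; the LLM is now a required field the caller injects. A generic, runnable analogue of this dependency-injection style (illustrative code, not the kotaemon API):

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class Tool:
    llm: Callable[[str], str]  # required dependency; no hidden vendor default

    def run(self, query: str) -> str:
        return self.llm(query)

# the caller decides which backend to plug in
echo_tool = Tool(llm=lambda q: f"echo: {q}")
print(echo_tool.run("hello"))  # echo: hello
```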
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from typing import Iterator

 from theflow import Function, Node, Param, lazy

@@ -32,7 +33,9 @@ class BaseComponent(Function):
         return self.__call__(self.inflow.flow())

     @abstractmethod
-    def run(self, *args, **kwargs) -> Document | list[Document] | None:
+    def run(
+        self, *args, **kwargs
+    ) -> Document | list[Document] | Iterator[Document] | None:
         """Run the component."""
         ...
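Widening `run`'s return annotation to include `Iterator[Document]` is what lets components stream partial output instead of returning one finished value. A toy illustration of the consumer side (plain strings stand in for `Document`):

```python
from typing import Iterator

def run_streaming(text: str, size: int = 8) -> Iterator[str]:
    # yield output piece by piece; the caller can render each piece immediately
    for start in range(0, len(text), size):
        yield text[start : start + size]

for piece in run_streaming("streaming keeps the chat UI responsive"):
    print(piece, end="", flush=True)
print()
```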
@@ -23,11 +23,13 @@ class Document(BaseDocument):
     store the raw content of the document. If specified, the class will use
     `content` to initialize the base llama_index class.

-    Args:
-        content: the raw content of the document.
+    Attributes:
+        content: raw content of the document, can be anything
+        source: id of the source of the Document. Optional.
     """

     content: Any
+    source: Optional[str] = None

     def __init__(self, content: Optional[Any] = None, *args, **kwargs):
         if content is None:
@@ -121,9 +121,12 @@ class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
         )

     def _get_lc_class(self):
-        import langchain.embeddings
+        try:
+            from langchain_community.embeddings import OpenAIEmbeddings
+        except ImportError:
+            from langchain.embeddings import OpenAIEmbeddings

-        return langchain.emebddings.OpenAIEmbeddings
+        return OpenAIEmbeddings


 class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
@@ -148,9 +151,12 @@ class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
         )

     def _get_lc_class(self):
-        import langchain.embeddings
+        try:
+            from langchain_community.embeddings import AzureOpenAIEmbeddings
+        except ImportError:
+            from langchain.embeddings import AzureOpenAIEmbeddings

-        return langchain.embeddings.AzureOpenAIEmbeddings
+        return AzureOpenAIEmbeddings


 class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
@@ -173,9 +179,12 @@ class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
         )

     def _get_lc_class(self):
-        import langchain.embeddings
+        try:
+            from langchain_community.embeddings import CohereEmbeddings
+        except ImportError:
+            from langchain.embeddings import CohereEmbeddings

-        return langchain.embeddings.CohereEmbeddings
+        return CohereEmbeddings


 class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
@@ -192,6 +201,9 @@ class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
         )

     def _get_lc_class(self):
-        import langchain.embeddings
+        try:
+            from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+        except ImportError:
+            from langchain.embeddings import HuggingFaceBgeEmbeddings

-        return langchain.embeddings.HuggingFaceBgeEmbeddings
+        return HuggingFaceBgeEmbeddings
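All four `_get_lc_class` implementations switch from a module-level `import langchain.embeddings` (and, in the removed `OpenAIEmbeddings` line, what appears to be a `langchain.emebddings` typo) to resolving the class through a try/except, so both the post-split `langchain_community` package and the legacy `langchain` layout keep working. The pattern in isolation:

```python
try:
    # newer installs ship the integrations in langchain_community
    from langchain_community.embeddings import OpenAIEmbeddings
except ImportError:
    # older installs still expose them under langchain.embeddings
    from langchain.embeddings import OpenAIEmbeddings
```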
@@ -11,6 +11,7 @@ from kotaemon.loaders import (
     MathpixPDFReader,
     OCRReader,
+    PandasExcelReader,
     UnstructuredReader,
 )

@@ -19,8 +20,16 @@ class DocumentIngestor(BaseComponent):

     Document types:
         - pdf
-        - xlsx
-        - docx
+        - xlsx, xls
+        - docx, doc
+
+    Args:
+        pdf_mode: mode for pdf extraction, one of "normal", "mathpix", "ocr"
+            - normal: parse pdf text
+            - mathpix: parse pdf text using mathpix
+            - ocr: parse pdf image using flax
+        doc_parsers: list of document parsers to parse the document
+        text_splitter: splitter to split the document into text nodes
     """

     pdf_mode: str = "normal"  # "normal", "mathpix", "ocr"
@@ -34,6 +43,9 @@ class DocumentIngestor(BaseComponent):
         """Get appropriate readers for the input files based on file extension"""
         file_extractor: dict[str, AutoReader | BaseReader] = {
             ".xlsx": PandasExcelReader(),
+            ".docx": UnstructuredReader(),
+            ".xls": UnstructuredReader(),
+            ".doc": UnstructuredReader(),
         }

         if self.pdf_mode == "normal":
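The `.docx`, `.xls`, and `.doc` entries extend the same suffix-to-reader dispatch the method already uses for `.xlsx`. The mechanism in isolation (string stand-ins instead of real reader instances):

```python
from pathlib import Path

file_extractor = {
    ".xlsx": "PandasExcelReader",
    ".docx": "UnstructuredReader",
    ".xls": "UnstructuredReader",
    ".doc": "UnstructuredReader",
}

# unknown suffixes fall through to whatever default the caller chooses
reader = file_extractor.get(Path("report.doc").suffix, "AutoReader")
print(reader)  # UnstructuredReader
```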
@@ -64,11 +64,7 @@ class CitationPipeline(BaseComponent):

     llm: BaseLLM

-    def run(
-        self,
-        context: str,
-        question: str,
-    ) -> QuestionAnswer:
+    def run(self, context: str, question: str):
         schema = QuestionAnswer.schema()
         function = {
             "name": schema["title"],
@@ -1,6 +1,6 @@
 import os

-from kotaemon.base import BaseComponent, Document, RetrievedDocument
+from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
 from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate

 from .citation import CitationPipeline
@@ -21,6 +21,9 @@ class CitationQAPipeline(BaseComponent):
         temperature=0,
         request_timeout=60,
     )
+    citation_pipeline: CitationPipeline = Node(
+        default_callback=lambda self: CitationPipeline(llm=self.llm)
+    )

     def _format_doc_text(self, text: str) -> str:
         """Format the text of each document"""
@@ -52,9 +55,7 @@ class CitationQAPipeline(BaseComponent):
         self.log_progress(".prompt", prompt=prompt)
         answer_text = self.llm(prompt).text
         if use_citation:
-            # run citation pipeline
-            citation_pipeline = CitationPipeline(llm=self.llm)
-            citation = citation_pipeline(context=context, question=question)
+            citation = self.citation_pipeline(context=context, question=question)
         else:
             citation = None
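Declaring `citation_pipeline` as a `Node` with a `default_callback` means the sub-pipeline is built from the parent's own `llm` instead of being re-instantiated inside `run` on every call. A rough plain-Python analogue (illustrative only; theflow's `Node` has its own semantics):

```python
from functools import cached_property

class Parent:
    def __init__(self, llm: str) -> None:
        self.llm = llm

    @cached_property
    def citation_pipeline(self) -> dict:
        # built lazily from the parent's state, then reused
        return {"llm": self.llm}

p = Parent("my-llm")
assert p.citation_pipeline is p.citation_pipeline  # created once
```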
@@ -23,17 +23,18 @@ class CohereReranking(BaseReranking):
         )

         cohere_client = cohere.Client(self.cohere_api_key)
-        compressed_docs: list[Document] = []
-
-        if len(documents) > 0:  # to avoid empty api call
-            _docs = [d.content for d in documents]
-            results = cohere_client.rerank(
-                model=self.model_name, query=query, documents=_docs, top_n=self.top_k
-            )
-            for r in results:
-                doc = documents[r.index]
-                doc.metadata["relevance_score"] = r.relevance_score
-                compressed_docs.append(doc)
+
+        # output documents
+        compressed_docs = []
+        if not documents:  # to avoid empty api call
+            return compressed_docs
+
+        _docs = [d.content for d in documents]
+        results = cohere_client.rerank(
+            model=self.model_name, query=query, documents=_docs, top_n=self.top_k
+        )
+        for r in results:
+            doc = documents[r.index]
+            doc.metadata["relevance_score"] = r.relevance_score
+            compressed_docs.append(doc)

         return compressed_docs
@@ -29,8 +29,19 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):


 class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
-    def __init__(self, window_size: int = 3, **params):
-        super().__init__(window_size=window_size, **params)
+    def __init__(
+        self,
+        window_size: int = 3,
+        window_metadata_key: str = "window",
+        original_text_metadata_key: str = "original_text",
+        **params,
+    ):
+        super().__init__(
+            window_size=window_size,
+            window_metadata_key=window_metadata_key,
+            original_text_metadata_key=original_text_metadata_key,
+            **params,
+        )

     def _get_li_class(self):
         from llama_index.node_parser import SentenceWindowNodeParser
@@ -62,7 +62,7 @@ class VectorIndexing(BaseIndexing):
         embeddings = self.embedding(input_)
         self.vector_store.add(
             embeddings=embeddings,
-            ids=[t.id_ for t in input_],
+            ids=[t.doc_id for t in input_],
         )
         if self.doc_store:
             self.doc_store.add(input_)
@@ -99,7 +99,7 @@ class VectorRetrieval(BaseRetrieval):
         )

         emb: list[float] = self.embedding(text)[0].embedding
-        _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
+        _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
         docs = self.doc_store.get(ids)
         result = [
             RetrievedDocument(**doc.to_dict(), score=score)
@@ -15,15 +15,23 @@ class LCChatMixin:
             "Please return the relevant Langchain class in in _get_lc_class"
         )

-    def __init__(self, **params):
+    def __init__(self, stream: bool = False, **params):
         self._lc_class = self._get_lc_class()
         self._obj = self._lc_class(**params)
         self._kwargs: dict = params
+        self._stream = stream

         super().__init__()

     def run(
         self, messages: str | BaseMessage | list[BaseMessage], **kwargs
     ) -> LLMInterface:
+        if self._stream:
+            return self.stream(messages, **kwargs)  # type: ignore
+        return self.invoke(messages, **kwargs)
+
+    def invoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
         """Generate response from messages
@@ -68,6 +76,10 @@ class LCChatMixin:
             logits=[],
         )

+    def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
+        for response in self._obj.stream(input=messages, **kwargs):
+            yield LLMInterface(content=response.content)
+
     def to_langchain_format(self):
         return self._obj
@@ -150,6 +162,9 @@ class AzureChatOpenAI(LCChatMixin, ChatLLM):
         )

     def _get_lc_class(self):
-        import langchain.chat_models
+        try:
+            from langchain_community.chat_models import AzureChatOpenAI
+        except ImportError:
+            from langchain.chat_models import AzureChatOpenAI

-        return langchain.chat_models.AzureChatOpenAI
+        return AzureChatOpenAI
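With the new constructor flag, `run` either delegates to `invoke` (one `LLMInterface`) or to `stream`, which wraps langchain's native `.stream()` and yields one `LLMInterface` per chunk. A hypothetical call site (constructor arguments and deployment details are placeholders, not from this commit):

```python
llm = AzureChatOpenAI(stream=True)  # plus the usual Azure endpoint/key params

# run() now returns a generator of LLMInterface chunks
for chunk in llm.run("Summarize the kotaemon MVP"):
    print(chunk.content, end="", flush=True)
```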
@@ -186,6 +186,9 @@ class AzureOpenAI(LCCompletionMixin, LLM):
         )

     def _get_lc_class(self):
-        import langchain.llms as langchain_llms
+        try:
+            from langchain_community.llms import AzureOpenAI
+        except ImportError:
+            from langchain.llms import AzureOpenAI

-        return langchain_llms.AzureOpenAI
+        return AzureOpenAI
@@ -26,11 +26,7 @@ class OCRReader(BaseReader):
         self.ocr_endpoint = endpoint
         self.use_ocr = use_ocr

-    def load_data(
-        self,
-        file_path: Path,
-        **kwargs,
-    ) -> List[Document]:
+    def load_data(self, file_path: Path, **kwargs) -> List[Document]:
         """Load data using OCR reader

         Args:
@@ -41,23 +37,24 @@ class OCRReader(BaseReader):
         Returns:
             List[Document]: list of documents extracted from the PDF file
         """
-        # create input params for the requests
-        content = open(file_path, "rb")
-        files = {"input": content}
-        data = {"job_id": uuid4(), "table_only": not self.use_ocr}
+        file_path = Path(file_path).resolve()

-        # call the API from FullOCR endpoint
-        if "response_content" in kwargs:
-            # overriding response content if specified
-            ocr_results = kwargs["response_content"]
-        else:
-            # call original API
-            resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
-            ocr_results = resp.json()["result"]
+        with file_path.open("rb") as content:
+            files = {"input": content}
+            data = {"job_id": uuid4(), "table_only": not self.use_ocr}
+
+            debug_path = kwargs.pop("debug_path", None)
+            artifact_path = kwargs.pop("artifact_path", None)
+
+            # call the API from FullOCR endpoint
+            if "response_content" in kwargs:
+                # overriding response content if specified
+                ocr_results = kwargs["response_content"]
+            else:
+                # call original API
+                resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
+                ocr_results = resp.json()["result"]

         # read PDF through normal reader (unstructured)
         pdf_page_items = read_pdf_unstructured(file_path)
         # merge PDF text output with OCR output
@@ -77,6 +74,9 @@ class OCRReader(BaseReader):
                         "type": "table",
                         "page_label": page_id + 1,
                         "source": file_path.name,
+                        "file_path": str(file_path),
+                        "file_name": file_path.name,
+                        "filename": str(file_path),
                     },
                     metadata_template="",
                     metadata_seperator="",
@@ -91,6 +91,9 @@ class OCRReader(BaseReader):
                     metadata={
                         "page_label": page_id + 1,
                         "source": file_path.name,
+                        "file_path": str(file_path),
+                        "file_name": file_path.name,
+                        "filename": str(file_path),
                     },
                 )
                 for page_id, non_table_text in texts
@@ -74,9 +74,10 @@ class UnstructuredReader(BaseReader):
         """ Process elements """
         docs = []
         file_name = Path(file).name
+        file_path = str(Path(file).resolve())
         if split_documents:
             for node in elements:
-                metadata = {"file_name": file_name}
+                metadata = {"file_name": file_name, "file_path": file_path}
                 if hasattr(node, "metadata"):
                     """Load metadata fields"""
                     for field, val in vars(node.metadata).items():
@@ -99,7 +100,7 @@ class UnstructuredReader(BaseReader):

         else:
             text_chunks = [" ".join(str(el).split()) for el in elements]
-            metadata = {"file_name": file_name}
+            metadata = {"file_name": file_name, "file_path": file_path}

             if additional_metadata is not None:
                 metadata.update(additional_metadata)
@@ -16,6 +16,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         elasticsearch_url: str = "http://localhost:9200",
         k1: float = 2.0,
         b: float = 0.75,
+        **kwargs,
     ):
         try:
             from elasticsearch import Elasticsearch
@@ -31,7 +32,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         self.b = b

         # Create an Elasticsearch client instance
-        self.client = Elasticsearch(elasticsearch_url)
+        self.client = Elasticsearch(elasticsearch_url, **kwargs)
         self.es_bulk = bulk
         # Define the index settings and mappings
         settings = {
@@ -63,19 +64,16 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         self,
         docs: Union[Document, List[Document]],
         ids: Optional[Union[List[str], str]] = None,
-        **kwargs
+        refresh_indices: bool = True,
+        **kwargs,
     ):
         """Add document into document store

         Args:
             docs: list of documents to add
-            ids: specify the ids of documents to add or
-                use existing doc.doc_id
-            refresh_indices: request Elasticsearch to update
-                its index (default to True)
+            ids: specify the ids of documents to add or use existing doc.doc_id
+            refresh_indices: request Elasticsearch to update its index (default to True)
         """
-        refresh_indices = kwargs.pop("refresh_indices", True)
-
         if ids and not isinstance(ids, list):
             ids = [ids]
         if not isinstance(docs, list):
@@ -120,7 +118,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         )
         return docs

-    def query(self, query: str, top_k: int = 10) -> List[Document]:
+    def query(
+        self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
+    ) -> List[Document]:
         """Search Elasticsearch docstore using search query (BM25)

         Args:
@@ -131,7 +131,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         Returns:
             List[Document]: List of result documents
         """
-        query_dict = {"query": {"match": {"content": query}}, "size": top_k}
+        query_dict: dict = {"query": {"match": {"content": query}}, "size": top_k}
+        if doc_ids:
+            query_dict["query"]["match"]["_id"] = {"values": doc_ids}
         return self.query_raw(query_dict)

     def get(self, ids: Union[List[str], str]) -> List[Document]:
@@ -74,6 +74,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
         """Load document store from path"""
         with open(path) as f:
             store = json.load(f)
+        # TODO: save and load aren't lossless. A Document-subclass will lose
+        # information. Need to edit the `to_dict` and `from_dict` methods in
+        # the Document class.
+        # For better query support, utilize SQLite as the default document store.
+        # Also, for portability, use SQLAlchemy for document store.
         self._store = {key: Document.from_dict(value) for key, value in store.items()}

     def __persist_flow__(self):
@@ -15,6 +15,18 @@ class SimpleFileDocumentStore(InMemoryDocumentStore):
         if path is not None and Path(path).is_file():
             self.load(path)

+    def get(self, ids: Union[List[str], str]) -> List[Document]:
+        """Get document by id"""
+        if not isinstance(ids, list):
+            ids = [ids]
+
+        for doc_id in ids:
+            if doc_id not in self._store:
+                self.load(self._path)
+                break
+
+        return [self._store[doc_id] for doc_id in ids]
+
     def add(
         self,
         docs: Union[Document, List[Document]],
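The new `get` override reloads the backing file once per batch whenever a requested id is missing from the in-memory cache, so documents written by another process become visible here. The idea in isolation (a plain dict stands in for the JSON-backed store):

```python
store = {"a": 1}

def load() -> None:
    store.update({"b": 2})  # stand-in for re-reading the JSON file

def get(ids: list) -> list:
    for doc_id in ids:
        if doc_id not in store:
            load()  # refresh once, then serve the whole batch
            break
    return [store[doc_id] for doc_id in ids]

print(get(["a", "b"]))  # [1, 2]
```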
@@ -76,8 +76,15 @@ class LlamaIndexVectorStore(BaseVectorStore):
                 "Require `_li_class` to set a VectorStore class from LlamarIndex"
             )

+        from dataclasses import fields
+
         self._client = self._li_class(*args, **kwargs)

+        self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
+        for key in ["query_embedding", "similarity_top_k", "node_ids"]:
+            if key in self._vsq_kwargs:
+                self._vsq_kwargs.remove(key)
+
     def __setattr__(self, name: str, value: Any) -> None:
         if name.startswith("_"):
             return super().__setattr__(name, value)
@@ -122,13 +129,35 @@ class LlamaIndexVectorStore(BaseVectorStore):
         ids: Optional[list[str]] = None,
         **kwargs,
     ) -> tuple[list[list[float]], list[float], list[str]]:
         """Return the top k most similar vector embeddings

         Args:
             embedding: List of embeddings
             top_k: Number of most similar embeddings to return
             ids: List of ids of the embeddings to be queried
+            kwargs: extra query parameters. Depending on the name, these parameters
+                will be used when constructing the VectorStoreQuery object or when
+                performing querying of the underlying vector store.

         Returns:
             the matched embeddings, the similarity scores, and the ids
         """
+        vsq_kwargs = {}
+        vs_kwargs = {}
+        for kwkey, kwvalue in kwargs.items():
+            if kwkey in self._vsq_kwargs:
+                vsq_kwargs[kwkey] = kwvalue
+            else:
+                vs_kwargs[kwkey] = kwvalue
+
         output = self._client.query(
             query=VectorStoreQuery(
                 query_embedding=embedding,
                 similarity_top_k=top_k,
                 node_ids=ids,
-                **kwargs,
+                **vsq_kwargs,
             ),
+            **vs_kwargs,
         )

         embeddings = []
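`_vsq_kwargs` is computed once in the constructor from `dataclasses.fields(VectorStoreQuery)`, minus the three names that `query` passes explicitly; at query time every extra keyword is then routed either into the `VectorStoreQuery` or down to the underlying vector-store call. The routing pattern in isolation (the `Query` dataclass is illustrative):

```python
from dataclasses import dataclass, fields

@dataclass
class Query:
    text: str
    top_k: int = 5

allowed = {f.name for f in fields(Query)}

kwargs = {"top_k": 3, "timeout": 30}
query_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
client_kwargs = {k: v for k, v in kwargs.items() if k not in allowed}
print(query_kwargs, client_kwargs)  # {'top_k': 3} {'timeout': 30}
```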
@@ -64,7 +64,7 @@ class ChromaVectorStore(LlamaIndexVectorStore):
             ids: List of ids of the embeddings to be deleted
             kwargs: meant for vectorstore-specific parameters
         """
-        self._client._collection.delete(ids=ids)
+        self._client.client.delete(ids=ids)

     def delete_collection(self, collection_name: Optional[str] = None):
         """Delete entire collection under specified name from vector stores