From d79b3744cbcc8a84038e46583d3030e4538c81a6 Mon Sep 17 00:00:00 2001 From: "Nguyen Trung Duc (john)" Date: Mon, 13 Nov 2023 15:10:18 +0700 Subject: [PATCH] Simplify the `BaseComponent` interface (#64) This change removes `BaseComponent`'s: - run_raw - run_batch_raw - run_document - run_batch_document - is_document - is_batch Each component is expected to support multiple types of inputs and a single type of output. Since we want the component to work out-of-the-box with both standardized and customized use cases, supporting multiple types of inputs is expected. At the same time, to reduce the complexity of understanding how to use a component, we restrict each component to a single output type. To accommodate these changes, we also refactor some components to remove their run_raw, run_batch_raw... methods, and to settle on a common output interface for those components. Tests are updated accordingly. Commit changes: * Add kwargs to vector store's query * Simplify the BaseComponent * Update tests * Remove support for Python 3.8 and 3.9 * Bump version 0.3.0 * Fix GitHub PR caching still using the old environment after bumping version --------- Co-authored-by: ian --- .github/workflows/unit-test.yaml | 4 +- knowledgehub/__init__.py | 2 +- knowledgehub/base.py | 70 ------------- knowledgehub/base/__init__.py | 3 + knowledgehub/base/component.py | 35 +++++++ knowledgehub/composite/linear.py | 4 +- knowledgehub/config.py | 0 knowledgehub/embeddings/base.py | 55 ++++------ knowledgehub/llms/chats/base.py | 89 ++++++++-------- knowledgehub/llms/completions/base.py | 48 +++++---- knowledgehub/pipelines/indexing.py | 49 ++++----- knowledgehub/pipelines/retrieving.py | 84 ++++----------- knowledgehub/pipelines/tools/base.py | 19 +--- knowledgehub/post_processing/extractor.py | 120 +++++++--------------- knowledgehub/prompt/base.py | 18 ---- knowledgehub/schema.py | 0 knowledgehub/vectorstores/base.py | 1 + setup.py | 2 +- tests/test_composite.py | 71 +++++++++---- tests/test_embedding_models.py | 19 ++-- tests/test_indexing_retrieval.py | 7 +- tests/test_llms_chat_models.py | 12 --- tests/test_llms_completion_models.py | 10 -- tests/test_post_processing.py | 14 +-- tests/test_prompt.py | 2 +- 25 files changed, 280 insertions(+), 458 deletions(-) delete mode 100644 knowledgehub/base.py create mode 100644 knowledgehub/base/__init__.py create mode 100644 knowledgehub/base/component.py delete mode 100644 knowledgehub/config.py delete mode 100644 knowledgehub/schema.py diff --git a/.github/workflows/unit-test.yaml b/.github/workflows/unit-test.yaml index ca89a4d..5b5a503 100644 --- a/.github/workflows/unit-test.yaml +++ b/.github/workflows/unit-test.yaml @@ -16,7 +16,7 @@ jobs: shell: ${{ matrix.shell }} strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] include: - os: ubuntu-latest shell: bash @@ -81,7 +81,7 @@ jobs: steps.check-cache-hit.outputs.check != 'true' run: | python -m pip install --upgrade pip - pip install -e .[dev] + pip install --ignore-installed -e .[dev] - name: New dependencies cache for key ${{ steps.restore-dependencies.outputs.cache-primary-key }} if: | diff --git a/knowledgehub/__init__.py b/knowledgehub/__init__.py index 943fd84..404c9ab 100644 --- a/knowledgehub/__init__.py +++ b/knowledgehub/__init__.py @@ -22,4 +22,4 @@ try: except ImportError: pass -__version__ = "0.2.0" +__version__ = "0.3.0" diff --git a/knowledgehub/base.py b/knowledgehub/base.py deleted file mode 100644 index db7fd88..0000000 --- a/knowledgehub/base.py +++
/dev/null @@ -1,70 +0,0 @@ -from abc import abstractmethod - -from theflow.base import Compose - - -class BaseComponent(Compose): - """Base class for component - - A component is a class that can be used to compose a pipeline. To use the - component, you should implement the following methods: - - - run_raw: run on raw input - - run_batch_raw: run on batch of raw input - - run_document: run on document - - run_batch_document: run on batch of documents - - is_document: check if input is document - - is_batch: check if input is batch - """ - - inflow = None - - def flow(self): - if self.inflow is None: - raise ValueError("No inflow provided.") - - if not isinstance(self.inflow, BaseComponent): - raise ValueError( - f"inflow must be a BaseComponent, found {type(self.inflow)}" - ) - - return self.__call__(self.inflow.flow()) - - @abstractmethod - def run_raw(self, *args, **kwargs): - ... - - @abstractmethod - def run_batch_raw(self, *args, **kwargs): - ... - - @abstractmethod - def run_document(self, *args, **kwargs): - ... - - @abstractmethod - def run_batch_document(self, *args, **kwargs): - ... - - @abstractmethod - def is_document(self, *args, **kwargs) -> bool: - ... - - @abstractmethod - def is_batch(self, *args, **kwargs) -> bool: - ... - - def run(self, *args, **kwargs): - """Run the component.""" - - is_document = self.is_document(*args, **kwargs) - is_batch = self.is_batch(*args, **kwargs) - - if is_document and is_batch: - return self.run_batch_document(*args, **kwargs) - elif is_document and not is_batch: - return self.run_document(*args, **kwargs) - elif not is_document and is_batch: - return self.run_batch_raw(*args, **kwargs) - else: - return self.run_raw(*args, **kwargs) diff --git a/knowledgehub/base/__init__.py b/knowledgehub/base/__init__.py new file mode 100644 index 0000000..f83cbe5 --- /dev/null +++ b/knowledgehub/base/__init__.py @@ -0,0 +1,3 @@ +from .component import BaseComponent + +__all__ = ["BaseComponent"] diff --git a/knowledgehub/base/component.py b/knowledgehub/base/component.py new file mode 100644 index 0000000..25505a7 --- /dev/null +++ b/knowledgehub/base/component.py @@ -0,0 +1,35 @@ +from abc import abstractmethod + +from theflow.base import Compose + + +class BaseComponent(Compose): + """A component is a class that can be used to compose a pipeline + + Benefits of component: + - Auto caching, logging + - Allow deployment + + For each component, the spirit is: + - Tolerate multiple input types, e.g. str, Document, List[str], List[Document] + - Enforce single output type. Hence, the output type of a component should be + as generic as possible. + """ + + inflow = None + + def flow(self): + if self.inflow is None: + raise ValueError("No inflow provided.") + + if not isinstance(self.inflow, BaseComponent): + raise ValueError( + f"inflow must be a BaseComponent, found {type(self.inflow)}" + ) + + return self.__call__(self.inflow.flow()) + + @abstractmethod + def run(self, *args, **kwargs): + """Run the component.""" + ... 
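With the dispatch methods gone, a component only implements `run`; the convention stated above is to tolerate flexible input (a `str`, a `Document`, or a list of either) while committing to a single output type. A minimal sketch of that contract, using the Python 3.10+ union syntax this patch now requires -- `WordCounter` is a hypothetical example rather than part of this patch, and the `kotaemon.*` import paths are assumed from the package layout:

    from kotaemon.base import BaseComponent
    from kotaemon.documents.base import Document


    class WordCounter(BaseComponent):
        """Tolerate str / Document / list input, always return list[int]."""

        def run(self, text: str | Document | list[str] | list[Document]) -> list[int]:
            # Normalize every accepted input shape into a flat list of strings
            items = text if isinstance(text, list) else [text]
            texts = [t.text if isinstance(t, Document) else t for t in items]
            # Single output type: one word count per input item
            return [len(t.split()) for t in texts]


    counter = WordCounter()
    print(counter("hello world"))                    # [2]
    print(counter([Document(text="a b c"), "d e"]))  # [3, 2]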
diff --git a/knowledgehub/composite/linear.py b/knowledgehub/composite/linear.py index 288ba62..64c6d55 100644 --- a/knowledgehub/composite/linear.py +++ b/knowledgehub/composite/linear.py @@ -70,7 +70,7 @@ class SimpleLinearPipeline(BaseComponent): prompt = self.prompt(**prompt_kwargs) llm_output = self.llm(prompt.text, **llm_kwargs) if self.post_processor is not None: - final_output = self.post_processor(llm_output, **post_processor_kwargs) + final_output = self.post_processor(llm_output, **post_processor_kwargs)[0] else: final_output = llm_output @@ -143,7 +143,7 @@ class GatedLinearPipeline(SimpleLinearPipeline): if condition_text is None: raise ValueError("`condition_text` must be provided") - if self.condition(condition_text): + if self.condition(condition_text)[0]: return super().run( llm_kwargs=llm_kwargs, post_processor_kwargs=post_processor_kwargs, diff --git a/knowledgehub/config.py b/knowledgehub/config.py deleted file mode 100644 index e69de29..0000000 diff --git a/knowledgehub/embeddings/base.py b/knowledgehub/embeddings/base.py index 3ea315f..e51aabf 100644 --- a/knowledgehub/embeddings/base.py +++ b/knowledgehub/embeddings/base.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from abc import abstractmethod -from typing import List, Type +from typing import Type from langchain.schema.embeddings import Embeddings as LCEmbeddings from theflow import Param @@ -10,33 +12,11 @@ from ..documents.base import Document class BaseEmbeddings(BaseComponent): @abstractmethod - def run_raw(self, text: str) -> List[float]: + def run( + self, text: str | list[str] | Document | list[Document] + ) -> list[list[float]]: ... - @abstractmethod - def run_batch_raw(self, text: List[str]) -> List[List[float]]: - ... - - @abstractmethod - def run_document(self, text: Document) -> List[float]: - ... - - @abstractmethod - def run_batch_document(self, text: List[Document]) -> List[List[float]]: - ... 
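Because a component now returns its single output type even when fed a single item, call sites unwrap the result with `[0]` -- that is why `SimpleLinearPipeline` and `GatedLinearPipeline` above index into the post-processor and condition outputs. A small sketch of the calling pattern, assuming a `RegexExtractor` (refactored later in this patch) configured with a plain digit pattern:

    from kotaemon.post_processing.extractor import RegexExtractor

    extractor = RegexExtractor(pattern=r"\d+")

    # A single input still comes back as a one-element list of ExtractorOutput,
    # so single-item callers take index 0.
    output = extractor("It costs 42 dollars")[0]
    print(output.text)     # "42"
    print(output.matches)  # ["42"]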
- - def is_document(self, text) -> bool: - if isinstance(text, Document): - return True - elif isinstance(text, List) and isinstance(text[0], Document): - return True - return False - - def is_batch(self, text) -> bool: - if isinstance(text, list): - return True - return False - class LangchainEmbeddings(BaseEmbeddings): _lc_class: Type[LCEmbeddings] @@ -64,14 +44,19 @@ class LangchainEmbeddings(BaseEmbeddings): def agent(self): return self._lc_class(**self._kwargs) - def run_raw(self, text: str) -> List[float]: - return self.agent.embed_query(text) # type: ignore + def run(self, text) -> list[list[float]]: + input_: list[str] = [] + if not isinstance(text, list): + text = [text] - def run_batch_raw(self, text: List[str]) -> List[List[float]]: - return self.agent.embed_documents(text) # type: ignore + for item in text: + if isinstance(item, str): + input_.append(item) + elif isinstance(item, Document): + input_.append(item.text) + else: + raise ValueError( + f"Invalid input type {type(item)}, should be str or Document" + ) - def run_document(self, text: Document) -> List[float]: - return self.agent.embed_query(text.text) # type: ignore - - def run_batch_document(self, text: List[Document]) -> List[List[float]]: - return self.agent.embed_documents([each.text for each in text]) # type: ignore + return self.agent.embed_documents(input_) diff --git a/knowledgehub/llms/chats/base.py b/knowledgehub/llms/chats/base.py index 5fc4b4e..301b548 100644 --- a/knowledgehub/llms/chats/base.py +++ b/knowledgehub/llms/chats/base.py @@ -1,13 +1,16 @@ -from typing import List, Type, TypeVar +from __future__ import annotations -from langchain.schema.language_model import BaseLanguageModel +import logging +from typing import Type + +from langchain.chat_models.base import BaseChatModel from langchain.schema.messages import BaseMessage, HumanMessage from theflow.base import Param from ...base import BaseComponent from ..base import LLMInterface -Message = TypeVar("Message", bound=BaseMessage) +logger = logging.getLogger(__name__) class ChatLLM(BaseComponent): @@ -25,7 +28,7 @@ class ChatLLM(BaseComponent): class LangchainChatLLM(ChatLLM): - _lc_class: Type[BaseLanguageModel] + _lc_class: Type[BaseChatModel] def __init__(self, **params): if self._lc_class is None: @@ -41,60 +44,62 @@ class LangchainChatLLM(ChatLLM): super().__init__(**params) @Param.auto(cache=False) - def agent(self) -> BaseLanguageModel: + def agent(self) -> BaseChatModel: return self._lc_class(**self._kwargs) - def run_raw(self, text: str, **kwargs) -> LLMInterface: - message = HumanMessage(content=text) - return self.run_document([message], **kwargs) + def run( + self, messages: str | BaseMessage | list[BaseMessage], **kwargs + ) -> LLMInterface: + """Generate response from messages - def run_batch_raw(self, text: List[str], **kwargs) -> List[LLMInterface]: - inputs = [[HumanMessage(content=each)] for each in text] - return self.run_batch_document(inputs, **kwargs) + Args: + messages: history of messages to generate response from + **kwargs: additional arguments to pass to the langchain chat model - def run_document(self, text: List[Message], **kwargs) -> LLMInterface: - pred = self.agent.generate([text], **kwargs) # type: ignore + Returns: + LLMInterface: generated response + """ + input_: list[BaseMessage] = [] + + if isinstance(messages, str): + input_ = [HumanMessage(content=messages)] + elif isinstance(messages, BaseMessage): + input_ = [messages] + else: + input_ = messages + + pred = self.agent.generate(messages=[input_], **kwargs) 
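Following the `LangchainEmbeddings` refactor above, `BaseEmbeddings.run` (and therefore a plain call to the component) always returns `list[list[float]]`, one vector per input item, whether it is given a string, a `Document`, or a list of either; the embedding tests below check `output[0][0]` for exactly this reason. A self-contained sketch of the contract using a fake embedding class (illustrative only, not part of the library; import paths assumed):

    from kotaemon.documents.base import Document
    from kotaemon.embeddings.base import BaseEmbeddings


    class FakeEmbeddings(BaseEmbeddings):
        """Deterministic stand-in: no API calls, one fixed-size vector per item."""

        def run(self, text) -> list[list[float]]:
            items = text if isinstance(text, list) else [text]
            texts = [t.text if isinstance(t, Document) else t for t in items]
            return [[float(len(t)), 0.0] for t in texts]


    model = FakeEmbeddings()
    assert isinstance(model("Hello world")[0][0], float)   # one vector, even for one input
    assert len(model([Document(text="a"), "bb"])) == 2     # one vector per list item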
all_text = [each.text for each in pred.generations[0]] + + completion_tokens, total_tokens, prompt_tokens = 0, 0, 0 + try: + if pred.llm_output is not None: + completion_tokens = pred.llm_output["token_usage"]["completion_tokens"] + total_tokens = pred.llm_output["token_usage"]["total_tokens"] + prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"] + except Exception: + logger.warning( + f"Cannot get token usage from LLM output for {self._lc_class.__name__}" + ) + return LLMInterface( text=all_text[0] if len(all_text) > 0 else "", candidates=all_text, - completion_tokens=pred.llm_output["token_usage"]["completion_tokens"], - total_tokens=pred.llm_output["token_usage"]["total_tokens"], - prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"], + completion_tokens=completion_tokens, + total_tokens=total_tokens, + prompt_tokens=prompt_tokens, logits=[], ) - def run_batch_document( - self, text: List[List[Message]], **kwargs - ) -> List[LLMInterface]: - outputs = [] - for each_text in text: - outputs.append(self.run_document(each_text, **kwargs)) - return outputs - - def is_document(self, text, **kwargs) -> bool: - if isinstance(text, str): - return False - elif isinstance(text, List) and isinstance(text[0], str): - return False - return True - - def is_batch(self, text, **kwargs) -> bool: - if isinstance(text, str): - return False - elif isinstance(text, List): - if isinstance(text[0], BaseMessage): - return False - return True - def __setattr__(self, name, value): if name in self._lc_class.__fields__: + self._kwargs[name] = value setattr(self.agent, name, value) else: super().__setattr__(name, value) def __getattr__(self, name): if name in self._lc_class.__fields__: - getattr(self.agent, name) - else: - super().__getattr__(name) + return getattr(self.agent, name) + + return super().__getattr__(name) # type: ignore diff --git a/knowledgehub/llms/completions/base.py b/knowledgehub/llms/completions/base.py index 05d9bc2..238f0f2 100644 --- a/knowledgehub/llms/completions/base.py +++ b/knowledgehub/llms/completions/base.py @@ -1,18 +1,21 @@ -from typing import List, Type +import logging +from typing import Type -from langchain.schema.language_model import BaseLanguageModel +from langchain.llms.base import BaseLLM from theflow.base import Param from ...base import BaseComponent from ..base import LLMInterface +logger = logging.getLogger(__name__) + class LLM(BaseComponent): pass class LangchainLLM(LLM): - _lc_class: Type[BaseLanguageModel] + _lc_class: Type[BaseLLM] def __init__(self, **params): if self._lc_class is None: @@ -31,38 +34,33 @@ class LangchainLLM(LLM): def agent(self): return self._lc_class(**self._kwargs) - def run_raw(self, text: str) -> LLMInterface: + def run(self, text: str) -> LLMInterface: pred = self.agent.generate([text]) all_text = [each.text for each in pred.generations[0]] + + completion_tokens, total_tokens, prompt_tokens = 0, 0, 0 + try: + if pred.llm_output is not None: + completion_tokens = pred.llm_output["token_usage"]["completion_tokens"] + total_tokens = pred.llm_output["token_usage"]["total_tokens"] + prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"] + except Exception: + logger.warning( + f"Cannot get token usage from LLM output for {self._lc_class.__name__}" + ) + return LLMInterface( text=all_text[0] if len(all_text) > 0 else "", candidates=all_text, - completion_tokens=pred.llm_output["token_usage"]["completion_tokens"], - total_tokens=pred.llm_output["token_usage"]["total_tokens"], - 
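`LangchainChatLLM.run` now covers every chat use through one entry point: a plain string is wrapped into a `HumanMessage`, a single `BaseMessage` becomes a one-turn history, and a message list is passed through unchanged; the result is always one `LLMInterface`, with token counts falling back to 0 when the backend reports no usage. A sketch of the call side, reusing the placeholder credentials from the test fixtures (replace them to make a real call):

    from langchain.schema.messages import HumanMessage, SystemMessage

    from kotaemon.llms.chats.openai import AzureChatOpenAI

    model = AzureChatOpenAI(
        openai_api_base="OPENAI_API_BASE",        # placeholders, as in the tests
        openai_api_key="OPENAI_API_KEY",
        openai_api_version="OPENAI_API_VERSION",
        deployment_name="dummy-q2-gpt35",
        temperature=0,
    )

    # A bare string becomes a single HumanMessage turn
    reply = model("hello world")

    # A full message history goes through the same call
    reply = model(
        [
            SystemMessage(content="You are a philosopher"),
            HumanMessage(content="What is the meaning of life?"),
        ]
    )
    print(reply.text, reply.total_tokens)  # LLMInterface; counts are 0 if usage is unreported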
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"], + completion_tokens=completion_tokens, + total_tokens=total_tokens, + prompt_tokens=prompt_tokens, logits=[], ) - def run_batch_raw(self, text: List[str]) -> List[LLMInterface]: - outputs = [] - for each_text in text: - outputs.append(self.run_raw(each_text)) - return outputs - - def run_document(self, text: str) -> LLMInterface: - return self.run_raw(text) - - def run_batch_document(self, text: List[str]) -> List[LLMInterface]: - return self.run_batch_raw(text) - - def is_document(self, text) -> bool: - return False - - def is_batch(self, text) -> bool: - return False if isinstance(text, str) else True - def __setattr__(self, name, value): if name in self._lc_class.__fields__: + self._kwargs[name] = value setattr(self.agent, name, value) else: super().__setattr__(name, value) diff --git a/knowledgehub/pipelines/indexing.py b/knowledgehub/pipelines/indexing.py index fe69667..03e1d53 100644 --- a/knowledgehub/pipelines/indexing.py +++ b/knowledgehub/pipelines/indexing.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import uuid from pathlib import Path -from typing import List, Union from theflow import Node, Param @@ -26,44 +27,34 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent): vector_store: Param[BaseVectorStore] = Param() doc_store: Param[BaseDocumentStore] = Param() embedding: Node[BaseEmbeddings] = Node() - # TODO: refer to llama_index's storage as well - def run_raw(self, text: str) -> None: - document = Document(text=text, id_=str(uuid.uuid4())) - self.run_batch_document([document]) + def run(self, text: str | list[str] | Document | list[Document]) -> None: + input_: list[Document] = [] + if not isinstance(text, list): + text = [text] - def run_batch_raw(self, text: List[str]) -> None: - documents = [Document(text=t, id_=str(uuid.uuid4())) for t in text] - self.run_batch_document(documents) + for item in text: + if isinstance(item, str): + input_.append(Document(text=item, id_=str(uuid.uuid4()))) + elif isinstance(item, Document): + input_.append(item) + else: + raise ValueError( + f"Invalid input type {type(item)}, should be str or Document" + ) - def run_document(self, text: Document) -> None: - self.run_batch_document([text]) - - def run_batch_document(self, text: List[Document]) -> None: - embeddings = self.embedding(text) + embeddings = self.embedding(input_) self.vector_store.add( embeddings=embeddings, - ids=[t.id_ for t in text], + ids=[t.id_ for t in input_], ) if self.doc_store: - self.doc_store.add(text) - - def is_document(self, text) -> bool: - if isinstance(text, Document): - return True - elif isinstance(text, List) and isinstance(text[0], Document): - return True - return False - - def is_batch(self, text) -> bool: - if isinstance(text, list): - return True - return False + self.doc_store.add(input_) def save( self, - path: Union[str, Path], + path: str | Path, vectorstore_fname: str = VECTOR_STORE_FNAME, docstore_fname: str = DOC_STORE_FNAME, ): @@ -80,7 +71,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent): def load( self, - path: Union[str, Path], + path: str | Path, vectorstore_fname: str = VECTOR_STORE_FNAME, docstore_fname: str = DOC_STORE_FNAME, ): diff --git a/knowledgehub/pipelines/retrieving.py b/knowledgehub/pipelines/retrieving.py index eb6656b..5643f5e 100644 --- a/knowledgehub/pipelines/retrieving.py +++ b/knowledgehub/pipelines/retrieving.py @@ -1,6 +1,6 @@ -from abc import abstractmethod +from __future__ import annotations + from pathlib import Path -from 
typing import List, Union from theflow import Node, Param @@ -14,31 +14,7 @@ VECTOR_STORE_FNAME = "vectorstore" DOC_STORE_FNAME = "docstore" -class BaseRetrieval(BaseComponent): - """Define the base interface of a retrieval pipeline""" - - @abstractmethod - def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]: - ... - - @abstractmethod - def run_batch_raw( - self, text: List[str], top_k: int = 1 - ) -> List[List[RetrievedDocument]]: - ... - - @abstractmethod - def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]: - ... - - @abstractmethod - def run_batch_document( - self, text: List[Document], top_k: int = 1 - ) -> List[List[RetrievedDocument]]: - ... - - -class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval): +class RetrieveDocumentFromVectorStorePipeline(BaseComponent): """Retrieve list of documents from vector store""" vector_store: Param[BaseVectorStore] = Param() @@ -46,53 +22,33 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval): embedding: Node[BaseEmbeddings] = Node() # TODO: refer to llama_index's storage as well - def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]: - return self.run_batch_raw([text], top_k=top_k)[0] + def run(self, text: str | Document, top_k: int = 1) -> list[RetrievedDocument]: + """Retrieve a list of documents from vector store - def run_batch_raw( - self, text: List[str], top_k: int = 1 - ) -> List[List[RetrievedDocument]]: + Args: + text: the text to retrieve similar documents + + Returns: + list[RetrievedDocument]: list of retrieved documents + """ if self.doc_store is None: raise ValueError( "doc_store is not provided. Please provide a doc_store to " "retrieve the documents" ) - result = [] - for each_text in text: - emb = self.embedding(each_text) - _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k) - docs = self.doc_store.get(ids) - each_result = [ - RetrievedDocument(**doc.to_dict(), score=score) - for doc, score in zip(docs, scores) - ] - result.append(each_result) + emb: list[float] = self.embedding(text)[0] + _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k) + docs = self.doc_store.get(ids) + result = [ + RetrievedDocument(**doc.to_dict(), score=score) + for doc, score in zip(docs, scores) + ] return result - def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]: - return self.run_raw(text.text, top_k) - - def run_batch_document( - self, text: List[Document], top_k: int = 1 - ) -> List[List[RetrievedDocument]]: - return self.run_batch_raw(text=[t.text for t in text], top_k=top_k) - - def is_document(self, text, *args, **kwargs) -> bool: - if isinstance(text, Document): - return True - elif isinstance(text, List) and isinstance(text[0], Document): - return True - return False - - def is_batch(self, text, *args, **kwargs) -> bool: - if isinstance(text, list): - return True - return False - def save( self, - path: Union[str, Path], + path: str | Path, vectorstore_fname: str = VECTOR_STORE_FNAME, docstore_fname: str = DOC_STORE_FNAME, ): @@ -109,7 +65,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval): def load( self, - path: Union[str, Path], + path: str | Path, vectorstore_fname: str = VECTOR_STORE_FNAME, docstore_fname: str = DOC_STORE_FNAME, ): diff --git a/knowledgehub/pipelines/tools/base.py b/knowledgehub/pipelines/tools/base.py index fb5688e..413b362 100644 --- a/knowledgehub/pipelines/tools/base.py +++ b/knowledgehub/pipelines/tools/base.py @@ -92,7 +92,7 @@ class 
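The indexing and retrieval pipelines follow the same single-`run` convention: `IndexVectorStoreFromDocumentPipeline` accepts strings or `Document`s (bare strings are wrapped in a `Document` with a fresh uuid before being embedded and stored), and `RetrieveDocumentFromVectorStorePipeline` takes one query and returns a `list[RetrievedDocument]`, so batch retrieval is simply repeated calls. A call-side sketch, assuming pipelines already wired with a vector store, doc store and embedding as in tests/test_indexing_retrieval.py:

    from kotaemon.documents.base import Document

    # index_pipeline / retrieval_pipeline stand for already-constructed
    # IndexVectorStoreFromDocumentPipeline / RetrieveDocumentFromVectorStorePipeline
    # instances that share the same stores and embedding component.
    index_pipeline(text=Document(text="Hello world"))
    index_pipeline(text=["a plain string works too", Document(text="so does a mix")])

    hits = retrieval_pipeline(text="Hello world", top_k=1)
    for hit in hits:
        print(hit.score, hit.text)  # RetrievedDocument: stored Document fields plus a score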
BaseTool(BaseComponent): """Convert this tool to Langchain format to use with its agent""" return LCTool(name=self.name, description=self.description, func=self.run) - def run_raw( + def run( self, tool_input: Union[str, Dict], verbose: Optional[bool] = None, @@ -110,23 +110,6 @@ class BaseTool(BaseComponent): else: return observation - def run_document(self, *args, **kwargs): - pass - - def run_batch_raw(self, *args, **kwargs): - pass - - def run_batch_document(self, *args, **kwargs): - pass - - def is_document(self, *args, **kwargs) -> bool: - """Tool does not support processing document""" - return False - - def is_batch(self, *args, **kwargs) -> bool: - """Tool does not support processing batch""" - return False - @classmethod def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool": """Wrapper for Langchain Tool""" diff --git a/knowledgehub/post_processing/extractor.py b/knowledgehub/post_processing/extractor.py index ee5633b..fbc2285 100644 --- a/knowledgehub/post_processing/extractor.py +++ b/knowledgehub/post_processing/extractor.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import re -from typing import Callable, Dict, List, Union +from typing import Callable from theflow import Param @@ -12,7 +14,7 @@ class ExtractorOutput(Document): Represents the output of an extractor. """ - matches: List[str] + matches: list[str] class RegexExtractor(BaseComponent): @@ -28,18 +30,18 @@ class RegexExtractor(BaseComponent): class Config: middleware_switches = {"theflow.middleware.CachingMiddleware": False} - pattern: List[str] - output_map: Union[Dict[str, str], Callable[[str], str]] = Param( + pattern: list[str] + output_map: dict[str, str] | Callable[[str], str] = Param( default_callback=lambda *_: {} ) - def __init__(self, pattern: Union[str, List[str]], **kwargs): + def __init__(self, pattern: str | list[str], **kwargs): if isinstance(pattern, str): pattern = [pattern] super().__init__(pattern=pattern, **kwargs) @staticmethod - def run_raw_static(pattern: str, text: str) -> List[str]: + def run_raw_static(pattern: str, text: str) -> list[str]: """ Finds all non-overlapping occurrences of a pattern in a string. @@ -86,9 +88,9 @@ class RegexExtractor(BaseComponent): Returns: ExtractorOutput: The processed output as a list of ExtractorOutput. """ - output = sum( + output: list[str] = sum( [self.run_raw_static(p, text) for p in self.pattern], [] - ) # type: List[str] + ) output = [self.map_output(text, self.output_map) for text in output] return ExtractorOutput( @@ -97,100 +99,48 @@ class RegexExtractor(BaseComponent): metadata={"origin": "RegexExtractor"}, ) - def run_batch_raw(self, text_batch: List[str]) -> List[ExtractorOutput]: - """ - Runs a batch of raw text inputs through the `run_raw()` method and returns the - output for each input. + def run( + self, text: str | list[str] | Document | list[Document] + ) -> list[ExtractorOutput]: + """Match the input against a pattern and return the output for each input Parameters: - text_batch (List[str]): A list of raw text inputs to process. + text: contains the input string to be processed Returns: - List[ExtractorOutput]: A list containing the output for each input in the - batch. - """ - batch_output = [self.run_raw(each_text) for each_text in text_batch] - - return batch_output - - def run_document(self, document: Document) -> ExtractorOutput: - """ - Run the document through the regex extractor and return an extracted document. - - Args: - document (Document): The input document. 
- - Returns: - ExtractorOutput: The extracted content. - """ - return self.run_raw(document.text) - - def run_batch_document( - self, document_batch: List[Document] - ) -> List[ExtractorOutput]: - """ - Runs a batch of documents through the `run_document` function and returns the - output for each document. - - - Parameters: - document_batch (List[Document]): A list of Document objects representing the - batch of documents to process. - - Returns: - List[ExtractorOutput]: A list contains the output ExtractorOutput for each - input Document in the batch. + A list contains the output ExtractorOutput for each input Example: document1 = Document(...) document2 = Document(...) document_batch = [document1, document2] - batch_output = self.run_batch_document(document_batch) + batch_output = self(document_batch) # batch_output will be [output1_document1, output1_document2] """ + # TODO: this conversion seems common + input_: list[str] = [] + if not isinstance(text, list): + text = [text] - batch_output = [ - self.run_document(each_document) for each_document in document_batch - ] + for item in text: + if isinstance(item, str): + input_.append(item) + elif isinstance(item, Document): + input_.append(item.text) + else: + raise ValueError( + f"Invalid input type {type(item)}, should be str or Document" + ) - return batch_output + output = [] + for each_input in input_: + output.append(self.run_raw(each_input)) - def is_document(self, text) -> bool: - """ - Check if the given text is an instance of the Document class. - - Args: - text: The text to check. - - Returns: - bool: True if the text is an instance of Document, False otherwise. - """ - if isinstance(text, Document): - return True - - return False - - def is_batch(self, text) -> bool: - """ - Check if the given text is a batch of documents. - - Parameters: - text (List): The text to be checked. - - Returns: - bool: True if the text is a batch of documents, False otherwise. 
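`RegexExtractor` keeps `run_raw` for one string, but its public `run` now mirrors the rest of the library: any mix of strings and `Document`s in, a `list[ExtractorOutput]` out, one entry per input (the updated tests below unwrap single inputs with `[0]`). A short sketch with several inputs, again assuming a digit pattern:

    from kotaemon.documents.base import Document
    from kotaemon.post_processing.extractor import RegexExtractor

    extractor = RegexExtractor(pattern=r"\d+")
    outputs = extractor([Document(text="room 12"), "floor 3", Document(text="no digits")])

    print([o.matches for o in outputs])  # [['12'], ['3'], []]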
- """ - if not isinstance(text, List): - return False - - if len(set(self.is_document(each_text) for each_text in text)) <= 1: - return True - - return False + return output class FirstMatchRegexExtractor(RegexExtractor): - pattern: List[str] + pattern: list[str] def run_raw(self, text: str) -> ExtractorOutput: for p in self.pattern: diff --git a/knowledgehub/prompt/base.py b/knowledgehub/prompt/base.py index 48b370c..0459fc3 100644 --- a/knowledgehub/prompt/base.py +++ b/knowledgehub/prompt/base.py @@ -174,23 +174,5 @@ class BasePromptComponent(BaseComponent): text = self.template.populate(**prepared_kwargs) return Document(text=text, metadata={"origin": "PromptComponent"}) - def run_raw(self, *args, **kwargs): - pass - - def run_batch_raw(self, *args, **kwargs): - pass - - def run_document(self, *args, **kwargs): - pass - - def run_batch_document(self, *args, **kwargs): - pass - - def is_document(self, *args, **kwargs): - pass - - def is_batch(self, *args, **kwargs): - pass - def flow(self): return self.__call__() diff --git a/knowledgehub/schema.py b/knowledgehub/schema.py deleted file mode 100644 index e69de29..0000000 diff --git a/knowledgehub/vectorstores/base.py b/knowledgehub/vectorstores/base.py index 51f85b3..0760f8d 100644 --- a/knowledgehub/vectorstores/base.py +++ b/knowledgehub/vectorstores/base.py @@ -59,6 +59,7 @@ class BaseVectorStore(ABC): embedding: List[float], top_k: int = 1, ids: Optional[List[str]] = None, + **kwargs, ) -> Tuple[List[List[float]], List[float], List[str]]: """Return the top k most similar vector embeddings diff --git a/setup.py b/setup.py index bb20417..f1f9a4b 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setuptools.setup( ], }, entry_points={"console_scripts": ["kh=kotaemon.cli:main"]}, - python_requires=">=3.8", + python_requires=">=3.10", classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", diff --git a/tests/test_composite.py b/tests/test_composite.py index 66f7e94..dadffb5 100644 --- a/tests/test_composite.py +++ b/tests/test_composite.py @@ -1,4 +1,7 @@ +from copy import deepcopy + import pytest +from openai.types.chat.chat_completion import ChatCompletion from kotaemon.composite import ( GatedBranchingPipeline, @@ -10,6 +13,29 @@ from kotaemon.llms.chats.openai import AzureChatOpenAI from kotaemon.post_processing.extractor import RegexExtractor from kotaemon.prompt.base import BasePromptComponent +_openai_chat_completion_response = ChatCompletion.parse_obj( + { + "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", + "object": "chat.completion", + "created": 1692338378, + "model": "gpt-35-turbo", + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "This is a test 123", + "finish_reason": "length", + "logprobs": None, + }, + } + ], + "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, + } +) + @pytest.fixture def mock_llm(): @@ -19,7 +45,6 @@ def mock_llm(): openai_api_version="OPENAI_API_VERSION", deployment_name="dummy-q2-gpt35", temperature=0, - request_timeout=600, ) @@ -61,11 +86,12 @@ def mock_gated_linear_pipeline_negative(mock_prompt, mock_llm, mock_post_process def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline): - openai_mocker = mocker.patch.object( - AzureChatOpenAI, "run", return_value="This is a test 123" + openai_mocker = mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=_openai_chat_completion_response, ) - result = 
mock_simple_linear_pipeline.run(value="abc") + result = mock_simple_linear_pipeline(value="abc") assert result.text == "123" assert openai_mocker.call_count == 1 @@ -74,11 +100,12 @@ def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline): def test_gated_linear_pipeline_run_positive( mocker, mock_gated_linear_pipeline_positive ): - openai_mocker = mocker.patch.object( - AzureChatOpenAI, "run", return_value="This is a test 123." + openai_mocker = mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=_openai_chat_completion_response, ) - result = mock_gated_linear_pipeline_positive.run( + result = mock_gated_linear_pipeline_positive( value="abc", condition_text="positive condition" ) @@ -89,11 +116,12 @@ def test_gated_linear_pipeline_run_positive( def test_gated_linear_pipeline_run_negative( mocker, mock_gated_linear_pipeline_positive ): - openai_mocker = mocker.patch.object( - AzureChatOpenAI, "run", return_value="This is a test 123." + openai_mocker = mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=_openai_chat_completion_response, ) - result = mock_gated_linear_pipeline_positive.run( + result = mock_gated_linear_pipeline_positive( value="abc", condition_text="negative condition" ) @@ -102,14 +130,14 @@ def test_gated_linear_pipeline_run_negative( def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline): - openai_mocker = mocker.patch.object( - AzureChatOpenAI, - "run", - side_effect=[ - "This is a test 123.", - "a quick brown fox", - "jumps over the lazy dog 456", - ], + response0: ChatCompletion = _openai_chat_completion_response + response1: ChatCompletion = deepcopy(_openai_chat_completion_response) + response1.choices[0].message.content = "a quick brown fox" + response2: ChatCompletion = deepcopy(_openai_chat_completion_response) + response2.choices[0].message.content = "jumps over the lazy dog 456" + openai_mocker = mocker.patch( + "openai.resources.chat.completions.Completions.create", + side_effect=[response0, response1, response2], ) pipeline = SimpleBranchingPipeline() for _ in range(3): @@ -126,8 +154,11 @@ def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline): def test_simple_gated_branching_pipeline_run( mocker, mock_gated_linear_pipeline_positive, mock_gated_linear_pipeline_negative ): - openai_mocker = mocker.patch.object( - AzureChatOpenAI, "run", return_value="a quick brown fox" + response0: ChatCompletion = deepcopy(_openai_chat_completion_response) + response0.choices[0].message.content = "a quick brown fox" + openai_mocker = mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=response0, ) pipeline = GatedBranchingPipeline() diff --git a/tests/test_embedding_models.py b/tests/test_embedding_models.py index 6162353..5353006 100644 --- a/tests/test_embedding_models.py +++ b/tests/test_embedding_models.py @@ -26,7 +26,8 @@ def test_azureopenai_embeddings_raw(openai_embedding_call): ) output = model("Hello world") assert isinstance(output, list) - assert isinstance(output[0], float) + assert isinstance(output[0], list) + assert isinstance(output[0][0], float) openai_embedding_call.assert_called() @@ -53,8 +54,8 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call): side_effect=lambda *args, **kwargs: None, ) @patch( - "langchain.embeddings.huggingface.HuggingFaceBgeEmbeddings.embed_query", - side_effect=lambda *args, **kwargs: [1.0, 2.1, 3.2], + 
"langchain.embeddings.huggingface.HuggingFaceBgeEmbeddings.embed_documents", + side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]], ) def test_huggingface_embddings( langchain_huggingface_embedding_call, sentence_transformers_init @@ -67,21 +68,23 @@ def test_huggingface_embddings( output = model("Hello World") assert isinstance(output, list) - assert isinstance(output[0], float) + assert isinstance(output[0], list) + assert isinstance(output[0][0], float) sentence_transformers_init.assert_called() langchain_huggingface_embedding_call.assert_called() @patch( - "langchain.embeddings.cohere.CohereEmbeddings.embed_query", - side_effect=lambda *args, **kwargs: [1.0, 2.1, 3.2], + "langchain.embeddings.cohere.CohereEmbeddings.embed_documents", + side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]], ) -def test_cohere_embddings(langchain_cohere_embedding_call): +def test_cohere_embeddings(langchain_cohere_embedding_call): model = CohereEmbdeddings( model="embed-english-light-v2.0", cohere_api_key="my-api-key" ) output = model("Hello World") assert isinstance(output, list) - assert isinstance(output[0], float) + assert isinstance(output[0], list) + assert isinstance(output[0][0], float) langchain_cohere_embedding_call.assert_called() diff --git a/tests/test_indexing_retrieval.py b/tests/test_indexing_retrieval.py index fd4dd74..79f3bd7 100644 --- a/tests/test_indexing_retrieval.py +++ b/tests/test_indexing_retrieval.py @@ -60,7 +60,8 @@ def test_retrieving(mock_openai_embedding, tmp_path): ) index_pipeline(text=Document(text="Hello world")) - output = retrieval_pipeline(text=["Hello world", "Hello world"]) + output = retrieval_pipeline(text="Hello world") + output1 = retrieval_pipeline(text="Hello world") - assert len(output) == 2, "Expect 2 results" - assert output[0] == output[1], "Expect identical results" + assert len(output) == 1, "Expect 1 results" + assert output == output1, "Expect identical results" diff --git a/tests/test_llms_chat_models.py b/tests/test_llms_chat_models.py index 7ce4ec3..932447b 100644 --- a/tests/test_llms_chat_models.py +++ b/tests/test_llms_chat_models.py @@ -54,12 +54,6 @@ def test_azureopenai_model(openai_completion): ), "Output for single text is not LLMInterface" openai_completion.assert_called() - # test for list[str] input - batch mode - output = model(["hello world"]) - assert isinstance(output, list), "Output for batch string is not a list" - assert isinstance(output[0], LLMInterface), "Output for text is not LLMInterface" - openai_completion.assert_called() - # test for list[message] input - stream mode messages = [ SystemMessage(content="You are a philosohper"), @@ -73,9 +67,3 @@ def test_azureopenai_model(openai_completion): output, LLMInterface ), "Output for single text is not LLMInterface" openai_completion.assert_called() - - # test for list[list[message]] input - batch mode - output = model([messages]) - assert isinstance(output, list), "Output for batch string is not a list" - assert isinstance(output[0], LLMInterface), "Output for text is not LLMInterface" - openai_completion.assert_called() diff --git a/tests/test_llms_completion_models.py b/tests/test_llms_completion_models.py index a10d1f5..ef001a5 100644 --- a/tests/test_llms_completion_models.py +++ b/tests/test_llms_completion_models.py @@ -44,11 +44,6 @@ def test_azureopenai_model(openai_completion): model.agent, AzureOpenAILC ), "Agent not wrapped in Langchain's AzureOpenAI" - output = model(["hello world"]) - assert isinstance(output, list), "Output for batch is not a list" - assert 
isinstance(output[0], LLMInterface), "Output for text is not LLMInterface" - openai_completion.assert_called() - output = model("hello world") assert isinstance( output, LLMInterface @@ -72,11 +67,6 @@ def test_openai_model(openai_completion): model.agent, OpenAILC ), "Agent is not wrapped in Langchain's OpenAI" - output = model(["hello world"]) - assert isinstance(output, list), "Output for batch is not a list" - assert isinstance(output[0], LLMInterface), "Output for text is not LLMInterface" - openai_completion.assert_called() - output = model("hello world") assert isinstance( output, LLMInterface diff --git a/tests/test_post_processing.py b/tests/test_post_processing.py index aefa128..8f14384 100644 --- a/tests/test_post_processing.py +++ b/tests/test_post_processing.py @@ -13,23 +13,13 @@ def regex_extractor(): def test_run_document(regex_extractor): document = Document(text="This is a test. 1 2 3") - extracted_document = regex_extractor(document) + extracted_document = regex_extractor(document)[0] assert extracted_document.text == "One" assert extracted_document.matches == ["One", "Two", "Three"] -def test_is_document(regex_extractor): - assert regex_extractor.is_document(Document(text="Test")) - assert not regex_extractor.is_document("Test") - - -def test_is_batch(regex_extractor): - assert regex_extractor.is_batch([Document(text="Test")]) - assert not regex_extractor.is_batch(Document(text="Test")) - - def test_run_raw(regex_extractor): - output = regex_extractor("This is a test. 123") + output = regex_extractor("This is a test. 123")[0] assert output.text == "123" assert output.matches == ["123"] diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 51c5154..6eb73c3 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -54,7 +54,7 @@ def test_run(): result = prompt() - assert result.text == "str = Alice, int = 30, doc = Helloo, Alice!, comp = One" + assert result.text == "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One']" def test_set_method():
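Because the chat components no longer expose a `run_raw`/`run_document` dispatch to patch, the composite tests above mock at the OpenAI SDK boundary instead: they build a `ChatCompletion` object once and patch `openai.resources.chat.completions.Completions.create` to return it. A condensed sketch of that pattern for local experiments (pytest-mock's `mocker` fixture; the response payload is copied from tests/test_composite.py, and the test name is illustrative):

    from openai.types.chat.chat_completion import ChatCompletion

    _response = ChatCompletion.parse_obj(
        {
            "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
            "object": "chat.completion",
            "created": 1692338378,
            "model": "gpt-35-turbo",
            "system_fingerprint": None,
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "This is a test 123",
                        "finish_reason": "length",
                        "logprobs": None,
                    },
                }
            ],
            "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
        }
    )


    def test_simple_linear_pipeline_offline(mocker, mock_simple_linear_pipeline):
        # No network: the SDK call is replaced with the canned ChatCompletion above
        mocker.patch(
            "openai.resources.chat.completions.Completions.create",
            return_value=_response,
        )
        result = mock_simple_linear_pipeline(value="abc")
        assert result.text == "123"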