diff --git a/knowledgehub/agents/langchain.py b/knowledgehub/agents/langchain.py
index c7f4b43..c36f7b7 100644
--- a/knowledgehub/agents/langchain.py
+++ b/knowledgehub/agents/langchain.py
@@ -54,7 +54,7 @@ class LangchainAgent(BaseAgent):
         # reinit Langchain AgentExecutor
         self.agent = initialize_agent(
             langchain_plugins,
-            self.llm.agent,
+            self.llm._obj,
             agent=self.AGENT_TYPE_MAP[self.agent_type],
             handle_parsing_errors=True,
             verbose=True,
diff --git a/knowledgehub/agents/tools/llm.py b/knowledgehub/agents/tools/llm.py
index 948b8ff..25e9440 100644
--- a/knowledgehub/agents/tools/llm.py
+++ b/knowledgehub/agents/tools/llm.py
@@ -21,7 +21,7 @@ class LLMTool(BaseTool):
         "are confident in solving the problem "
         "yourself. Input can be any instruction."
     )
-    llm: BaseLLM = AzureChatOpenAI()
+    llm: BaseLLM = AzureChatOpenAI.withx()
     args_schema: Optional[Type[BaseModel]] = LLMArgs

     def _run_tool(self, query: AnyStr) -> str:
diff --git a/knowledgehub/embeddings/__init__.py b/knowledgehub/embeddings/__init__.py
index 4d32d13..1b8b452 100644
--- a/knowledgehub/embeddings/__init__.py
+++ b/knowledgehub/embeddings/__init__.py
@@ -1,4 +1,15 @@
 from .base import BaseEmbeddings
-from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+from .langchain_based import (
+    AzureOpenAIEmbeddings,
+    CohereEmbdeddings,
+    HuggingFaceEmbeddings,
+    OpenAIEmbeddings,
+)

-__all__ = ["BaseEmbeddings", "OpenAIEmbeddings", "AzureOpenAIEmbeddings"]
+__all__ = [
+    "BaseEmbeddings",
+    "OpenAIEmbeddings",
+    "AzureOpenAIEmbeddings",
+    "CohereEmbdeddings",
+    "HuggingFaceEmbeddings",
+]
diff --git a/knowledgehub/embeddings/base.py b/knowledgehub/embeddings/base.py
index 632a480..220b47a 100644
--- a/knowledgehub/embeddings/base.py
+++ b/knowledgehub/embeddings/base.py
@@ -1,10 +1,6 @@
 from __future__ import annotations

 from abc import abstractmethod
-from typing import Type
-
-from langchain.schema.embeddings import Embeddings as LCEmbeddings
-from theflow import Param

 from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding

@@ -15,52 +11,3 @@ class BaseEmbeddings(BaseComponent):
         self, text: str | list[str] | Document | list[Document]
     ) -> list[DocumentWithEmbedding]:
         ...
-
-
-class LangchainEmbeddings(BaseEmbeddings):
-    _lc_class: Type[LCEmbeddings]
-
-    def __init__(self, **params):
-        if self._lc_class is None:
-            raise AttributeError(
-                "Should set _lc_class attribute to the LLM class from Langchain "
-                "if using LLM from Langchain"
-            )
-
-        self._kwargs: dict = {}
-        for param in list(params.keys()):
-            if param in self._lc_class.__fields__:  # type: ignore
-                self._kwargs[param] = params.pop(param)
-        super().__init__(**params)
-
-    def __setattr__(self, name, value):
-        if name in self._lc_class.__fields__:
-            self._kwargs[name] = value
-        else:
-            super().__setattr__(name, value)
-
-    @Param.auto(cache=False)
-    def agent(self):
-        return self._lc_class(**self._kwargs)
-
-    def run(self, text):
-        input_: list[str] = []
-        if not isinstance(text, list):
-            text = [text]
-
-        for item in text:
-            if isinstance(item, str):
-                input_.append(item)
-            elif isinstance(item, Document):
-                input_.append(item.text)
-            else:
-                raise ValueError(
-                    f"Invalid input type {type(item)}, should be str or Document"
-                )
-
-        embeddings = self.agent.embed_documents(input_)
-
-        return [
-            DocumentWithEmbedding(text=each_text, embedding=each_embedding)
-            for each_text, each_embedding in zip(input_, embeddings)
-        ]
diff --git a/knowledgehub/embeddings/cohere.py b/knowledgehub/embeddings/cohere.py
deleted file mode 100644
index e161787..0000000
--- a/knowledgehub/embeddings/cohere.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from langchain.embeddings import CohereEmbeddings as LCCohereEmbeddings
-
-from kotaemon.embeddings.base import LangchainEmbeddings
-
-
-class CohereEmbdeddings(LangchainEmbeddings):
-    """Cohere embeddings.
-
-    This class wraps around the Langchain CohereEmbeddings class.
-    """
-
-    _lc_class = LCCohereEmbeddings
diff --git a/knowledgehub/embeddings/huggingface.py b/knowledgehub/embeddings/huggingface.py
deleted file mode 100644
index 37ba5e9..0000000
--- a/knowledgehub/embeddings/huggingface.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from langchain.embeddings import HuggingFaceBgeEmbeddings as LCHuggingFaceEmbeddings
-
-from kotaemon.embeddings.base import LangchainEmbeddings
-
-
-class HuggingFaceEmbeddings(LangchainEmbeddings):
-    """HuggingFace embeddings
-
-    This class wraps around the Langchain HuggingFaceEmbeddings class
-    """
-
-    _lc_class = LCHuggingFaceEmbeddings
diff --git a/knowledgehub/embeddings/langchain_based.py b/knowledgehub/embeddings/langchain_based.py
new file mode 100644
index 0000000..98090fe
--- /dev/null
+++ b/knowledgehub/embeddings/langchain_based.py
@@ -0,0 +1,194 @@
+from typing import Optional
+
+from kotaemon.base import Document, DocumentWithEmbedding
+
+from .base import BaseEmbeddings
+
+
+class LCEmbeddingMixin:
+    def _get_lc_class(self):
+        raise NotImplementedError(
+            "Please return the relevant Langchain class in _get_lc_class"
+        )
+
+    def __init__(self, **params):
+        self._lc_class = self._get_lc_class()
+        self._obj = self._lc_class(**params)
+        self._kwargs: dict = params
+
+        super().__init__()
+
+    def run(self, text):
+        input_: list[str] = []
+        if not isinstance(text, list):
+            text = [text]
+
+        for item in text:
+            if isinstance(item, str):
+                input_.append(item)
+            elif isinstance(item, Document):
+                input_.append(item.text)
+            else:
+                raise ValueError(
+                    f"Invalid input type {type(item)}, should be str or Document"
+                )
+
+        embeddings = self._obj.embed_documents(input_)
+
+        return [
+            DocumentWithEmbedding(text=each_text, embedding=each_embedding)
+            for each_text, each_embedding in zip(input_, embeddings)
+        ]
+
+    def __repr__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = repr(value_obj)
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __str__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = str(value_obj)
+            if len(value) > 20:
+                value = f"{value[:15]}..."
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __setattr__(self, name, value):
+        if name == "_lc_class":
+            return super().__setattr__(name, value)
+
+        if name in self._lc_class.__fields__:
+            self._kwargs[name] = value
+            self._obj = self._lc_class(**self._kwargs)
+        else:
+            super().__setattr__(name, value)
+
+    def __getattr__(self, name):
+        if name in self._kwargs:
+            return self._kwargs[name]
+        return getattr(self._obj, name)
+
+    def dump(self):
+        return {
+            "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
+            **self._kwargs,
+        }
+
+    def specs(self, path: str):
+        path = path.strip(".")
+        if "." in path:
+            raise ValueError("path should not contain '.'")
+
+        if path in self._lc_class.__fields__:
+            return {
+                "__type__": "theflow.base.ParamAttr",
+                "refresh_on_set": True,
+                "strict_type": True,
+            }
+
+        raise ValueError(f"Invalid param {path}")
+
+
+class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's OpenAI embedding, focusing on key parameters"""
+
+    def __init__(
+        self,
+        model: str = "text-embedding-ada-002",
+        openai_api_version: Optional[str] = None,
+        openai_api_base: Optional[str] = None,
+        openai_api_type: Optional[str] = None,
+        openai_api_key: Optional[str] = None,
+        request_timeout: Optional[float] = None,
+        **params,
+    ):
+        super().__init__(
+            model=model,
+            openai_api_version=openai_api_version,
+            openai_api_base=openai_api_base,
+            openai_api_type=openai_api_type,
+            openai_api_key=openai_api_key,
+            request_timeout=request_timeout,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.embeddings
+
+        return langchain.embeddings.OpenAIEmbeddings
+
+
+class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's AzureOpenAI embedding, focusing on key parameters"""
+
+    def __init__(
+        self,
+        azure_endpoint: Optional[str] = None,
+        deployment: Optional[str] = None,
+        openai_api_key: Optional[str] = None,
+        openai_api_version: Optional[str] = None,
+        request_timeout: Optional[float] = None,
+        **params,
+    ):
+        super().__init__(
+            azure_endpoint=azure_endpoint,
+            deployment=deployment,
+            openai_api_version=openai_api_version,
+            openai_api_key=openai_api_key,
+            request_timeout=request_timeout,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.embeddings
+
+        return langchain.embeddings.AzureOpenAIEmbeddings
+
+
+class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
+
+    def __init__(
+        self,
+        model: str = "embed-english-v2.0",
+        cohere_api_key: Optional[str] = None,
+        truncate: Optional[str] = None,
+        request_timeout: Optional[float] = None,
+        **params,
+    ):
+        super().__init__(
+            model=model,
+            cohere_api_key=cohere_api_key,
+            truncate=truncate,
+            request_timeout=request_timeout,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.embeddings
+
+        return langchain.embeddings.CohereEmbeddings
+
+
+class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
+
+    def __init__(
+        self,
+        model_name: str = "sentence-transformers/all-mpnet-base-v2",
+        **params,
+    ):
+        super().__init__(
+            model_name=model_name,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.embeddings
+
+        return langchain.embeddings.HuggingFaceBgeEmbeddings
diff --git a/knowledgehub/embeddings/openai.py b/knowledgehub/embeddings/openai.py
deleted file mode 100644
index da02755..0000000
--- a/knowledgehub/embeddings/openai.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from langchain import embeddings as lcembeddings
-
-from .base import LangchainEmbeddings
-
-
-class OpenAIEmbeddings(LangchainEmbeddings):
-    """OpenAI embeddings.
-
-    This method is wrapped around the Langchain OpenAIEmbeddings class.
-    """
-
-    _lc_class = lcembeddings.OpenAIEmbeddings
-
-
-class AzureOpenAIEmbeddings(LangchainEmbeddings):
-    """Azure OpenAI embeddings.
-
-    This method is wrapped around the Langchain AzureOpenAIEmbeddings class.
-    """
-
-    _lc_class = lcembeddings.AzureOpenAIEmbeddings
diff --git a/knowledgehub/indices/base.py b/knowledgehub/indices/base.py
index 6fee77a..8843aad 100644
--- a/knowledgehub/indices/base.py
+++ b/knowledgehub/indices/base.py
@@ -46,20 +46,48 @@ class LlamaIndexDocTransformerMixin:
             "Please return the relevant LlamaIndex class in _get_li_class"
         )

-    def __init__(self, *args, **kwargs):
-        _li_cls = self._get_li_class()
-        self._obj = _li_cls(*args, **kwargs)
+    def __init__(self, **params):
+        self._li_cls = self._get_li_class()
+        self._obj = self._li_cls(**params)
+        self._kwargs = params

         super().__init__()

+    def __repr__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = repr(value_obj)
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __str__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = str(value_obj)
+            if len(value) > 20:
+                value = f"{value[:15]}..."
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
     def __setattr__(self, name: str, value: Any) -> None:
         if name.startswith("_") or name in self._protected_keywords():
             return super().__setattr__(name, value)

+        self._kwargs[name] = value
         return setattr(self._obj, name, value)

     def __getattr__(self, name: str) -> Any:
+        if name in self._kwargs:
+            return self._kwargs[name]
         return getattr(self._obj, name)

+    def dump(self):
+        return {
+            "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
+            **self._kwargs,
+        }
+
     def run(
         self,
         documents: list[Document],
diff --git a/knowledgehub/indices/extractors/doc_parsers.py b/knowledgehub/indices/extractors/doc_parsers.py
index ed12fdd..7dad528 100644
--- a/knowledgehub/indices/extractors/doc_parsers.py
+++ b/knowledgehub/indices/extractors/doc_parsers.py
@@ -6,6 +6,14 @@ class BaseDocParser(DocTransformer):


 class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
+    def __init__(
+        self,
+        llm=None,
+        nodes: int = 5,
+        **params,
+    ):
+        super().__init__(llm=llm, nodes=nodes, **params)
+
     def _get_li_class(self):
         from llama_index.extractors import TitleExtractor

@@ -13,6 +21,14 @@ class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):


 class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
+    def __init__(
+        self,
+        llm=None,
+        summaries: list[str] = ["self"],
+        **params,
+    ):
+        super().__init__(llm=llm, summaries=summaries, **params)
+
     def _get_li_class(self):
         from llama_index.extractors import SummaryExtractor

diff --git a/knowledgehub/indices/rankings/base.py b/knowledgehub/indices/rankings/base.py
index 9515199..c9c0b9b 100644
--- a/knowledgehub/indices/rankings/base.py
+++ b/knowledgehub/indices/rankings/base.py
@@ -2,7 +2,7 @@ from __future__ import annotations

 from abc import abstractmethod

-from ...base import BaseComponent, Document
+from kotaemon.base import BaseComponent, Document


 class BaseReranking(BaseComponent):
diff --git a/knowledgehub/indices/rankings/cohere.py b/knowledgehub/indices/rankings/cohere.py
index fce039b..1f9c32a 100644
--- a/knowledgehub/indices/rankings/cohere.py
+++ b/knowledgehub/indices/rankings/cohere.py
@@ -2,7 +2,8 @@ from __future__ import annotations

 import os

-from ...base import Document
+from kotaemon.base import Document
+
 from .base import BaseReranking


diff --git a/knowledgehub/indices/rankings/llm.py b/knowledgehub/indices/rankings/llm.py
index f36eca5..bff81ff 100644
--- a/knowledgehub/indices/rankings/llm.py
+++ b/knowledgehub/indices/rankings/llm.py
@@ -1,17 +1,13 @@
 from __future__ import annotations

 from concurrent.futures import ThreadPoolExecutor
-from typing import Union

 from langchain.output_parsers.boolean import BooleanOutputParser

-from ...base import Document
-from ...llms import PromptTemplate
-from ...llms.chats.base import ChatLLM
-from ...llms.completions.base import LLM
-from .base import BaseReranking
+from kotaemon.base import Document
+from kotaemon.llms import BaseLLM, PromptTemplate

-BaseLLM = Union[ChatLLM, LLM]
+from .base import BaseReranking

 RERANK_PROMPT_TEMPLATE = """Given the following question and context, return YES if the context is relevant to the question and NO if it isn't.
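The `LlamaIndexDocTransformerMixin` changes above, and the `LC*Mixin` classes later in this diff, share one delegate pattern: constructor kwargs are recorded in `_kwargs`, attribute writes update that record (rebuilding the wrapped object where needed), and attribute reads that miss fall through to the wrapped object. A minimal, self-contained sketch of that pattern follows; `FakeSplitter` is a hypothetical stand-in for the wrapped third-party class, not part of this diff:

```python
class FakeSplitter:
    """Hypothetical third-party class, standing in for a Langchain/LlamaIndex object."""

    def __init__(self, chunk_size=1024):
        self.chunk_size = chunk_size

    def split(self, text):
        return [text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)]


class WrapperMixin:
    def __init__(self, **params):
        self._cls = self._get_wrapped_class()
        self._obj = self._cls(**params)  # the wrapped delegate
        self._kwargs = params            # authoritative record of init kwargs

    def _get_wrapped_class(self):
        raise NotImplementedError

    def __setattr__(self, name, value):
        if name.startswith("_"):
            return super().__setattr__(name, value)
        # Record the new value and rebuild the delegate so both stay consistent.
        self._kwargs[name] = value
        self._obj = self._cls(**self._kwargs)

    def __getattr__(self, name):
        # Only called when normal lookup fails: answer from the kwargs record,
        # otherwise fall through to the wrapped object.
        if name in self._kwargs:
            return self._kwargs[name]
        return getattr(self._obj, name)


class Splitter(WrapperMixin):
    def _get_wrapped_class(self):
        return FakeSplitter


splitter = Splitter(chunk_size=4)
print(splitter.split("abcdefgh"))  # ['abcd', 'efgh']
splitter.chunk_size = 2            # rebuilds the wrapped FakeSplitter
print(splitter.split("abcdefgh"))  # ['ab', 'cd', 'ef', 'gh']
```

Rebuilding the delegate on every write is what keeps `_kwargs` authoritative, which in turn is what lets the `dump()` methods in this diff serialize a component from its recorded kwargs alone.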
diff --git a/knowledgehub/indices/splitters/__init__.py b/knowledgehub/indices/splitters/__init__.py
index ea5c928..0c71a41 100644
--- a/knowledgehub/indices/splitters/__init__.py
+++ b/knowledgehub/indices/splitters/__init__.py
@@ -8,6 +8,20 @@ class BaseSplitter(DocTransformer):


 class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
+    def __init__(
+        self,
+        chunk_size: int = 1024,
+        chunk_overlap: int = 20,
+        separator: str = " ",
+        **params,
+    ):
+        super().__init__(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            separator=separator,
+            **params,
+        )
+
     def _get_li_class(self):
         from llama_index.text_splitter import TokenTextSplitter

@@ -15,6 +29,9 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):


 class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
+    def __init__(self, window_size: int = 3, **params):
+        super().__init__(window_size=window_size, **params)
+
     def _get_li_class(self):
         from llama_index.node_parser import SentenceWindowNodeParser

diff --git a/knowledgehub/llms/branching.py b/knowledgehub/llms/branching.py
index ccbcfcb..07671f2 100644
--- a/knowledgehub/llms/branching.py
+++ b/knowledgehub/llms/branching.py
@@ -154,8 +154,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
 if __name__ == "__main__":
     import dotenv

-    from kotaemon.llms import BasePromptComponent
-    from kotaemon.llms.chats.openai import AzureChatOpenAI
+    from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
     from kotaemon.parsers import RegexExtractor

     def identity(x):
diff --git a/knowledgehub/llms/chats/__init__.py b/knowledgehub/llms/chats/__init__.py
index d634222..9388c5f 100644
--- a/knowledgehub/llms/chats/__init__.py
+++ b/knowledgehub/llms/chats/__init__.py
@@ -1,4 +1,4 @@
-from .base import BaseMessage, ChatLLM, HumanMessage
-from .openai import AzureChatOpenAI
+from .base import ChatLLM
+from .langchain_based import AzureChatOpenAI

-__all__ = ["ChatLLM", "AzureChatOpenAI", "BaseMessage", "HumanMessage"]
+__all__ = ["ChatLLM", "AzureChatOpenAI"]
diff --git a/knowledgehub/llms/chats/base.py b/knowledgehub/llms/chats/base.py
index 0804dff..a5280f4 100644
--- a/knowledgehub/llms/chats/base.py
+++ b/knowledgehub/llms/chats/base.py
@@ -1,12 +1,8 @@
 from __future__ import annotations

 import logging
-from typing import Type

-from langchain.chat_models.base import BaseChatModel
-from theflow.base import Param
-
-from kotaemon.base import BaseComponent, BaseMessage, HumanMessage, LLMInterface
+from kotaemon.base import BaseComponent

 logger = logging.getLogger(__name__)

@@ -23,83 +19,3 @@ class ChatLLM(BaseComponent):
         text = self.inflow.flow().text

         return self.__call__(text)
-
-
-class LangchainChatLLM(ChatLLM):
-    _lc_class: Type[BaseChatModel]
-
-    def __init__(self, **params):
-        if self._lc_class is None:
-            raise AttributeError(
-                "Should set _lc_class attribute to the LLM class from Langchain "
-                "if using LLM from Langchain"
-            )
-
-        self._kwargs: dict = {}
-        for param in list(params.keys()):
-            if param in self._lc_class.__fields__:
-                self._kwargs[param] = params.pop(param)
-        super().__init__(**params)
-
-    @Param.auto(cache=False)
-    def agent(self) -> BaseChatModel:
-        return self._lc_class(**self._kwargs)
-
-    def run(
-        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
-    ) -> LLMInterface:
-        """Generate response from messages
-
-        Args:
-            messages: history of messages to generate response from
-            **kwargs: additional arguments to pass to the langchain chat model
-
-        Returns:
-            LLMInterface: generated response
-        """
-        input_: list[BaseMessage] = []
-
-        if isinstance(messages, str):
-            input_ = [HumanMessage(content=messages)]
-        elif isinstance(messages, BaseMessage):
-            input_ = [messages]
-        else:
-            input_ = messages
-
-        pred = self.agent.generate(messages=[input_], **kwargs)
-        all_text = [each.text for each in pred.generations[0]]
-        all_messages = [each.message for each in pred.generations[0]]
-
-        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
-        try:
-            if pred.llm_output is not None:
-                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
-                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
-                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
-        except Exception:
-            logger.warning(
-                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
-            )
-
-        return LLMInterface(
-            text=all_text[0] if len(all_text) > 0 else "",
-            candidates=all_text,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            prompt_tokens=prompt_tokens,
-            messages=all_messages,
-            logits=[],
-        )
-
-    def __setattr__(self, name, value):
-        if name in self._lc_class.__fields__:
-            self._kwargs[name] = value
-            setattr(self.agent, name, value)
-        else:
-            super().__setattr__(name, value)
-
-    def __getattr__(self, name):
-        if name in self._lc_class.__fields__:
-            return getattr(self.agent, name)
-
-        return super().__getattr__(name)  # type: ignore
diff --git a/knowledgehub/llms/chats/langchain_based.py b/knowledgehub/llms/chats/langchain_based.py
new file mode 100644
index 0000000..ccade14
--- /dev/null
+++ b/knowledgehub/llms/chats/langchain_based.py
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+import logging
+
+from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
+
+from .base import ChatLLM
+
+logger = logging.getLogger(__name__)
+
+
+class LCChatMixin:
+    def _get_lc_class(self):
+        raise NotImplementedError(
+            "Please return the relevant Langchain class in _get_lc_class"
+        )
+
+    def __init__(self, **params):
+        self._lc_class = self._get_lc_class()
+        self._obj = self._lc_class(**params)
+        self._kwargs: dict = params
+
+        super().__init__()
+
+    def run(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """Generate response from messages
+
+        Args:
+            messages: history of messages to generate response from
+            **kwargs: additional arguments to pass to the langchain chat model
+
+        Returns:
+            LLMInterface: generated response
+        """
+        input_: list[BaseMessage] = []
+
+        if isinstance(messages, str):
+            input_ = [HumanMessage(content=messages)]
+        elif isinstance(messages, BaseMessage):
+            input_ = [messages]
+        else:
+            input_ = messages
+
+        pred = self._obj.generate(messages=[input_], **kwargs)
+        all_text = [each.text for each in pred.generations[0]]
+        all_messages = [each.message for each in pred.generations[0]]
+
+        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
+        try:
+            if pred.llm_output is not None:
+                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
+                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
+                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
+        except Exception:
+            logger.warning(
+                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
+            )
+
+        return LLMInterface(
+            text=all_text[0] if len(all_text) > 0 else "",
+            candidates=all_text,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens=prompt_tokens,
+            messages=all_messages,
+            logits=[],
+        )
+
+    def __repr__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = repr(value_obj)
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __str__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = str(value_obj)
+            if len(value) > 20:
+                value = f"{value[:15]}..."
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __setattr__(self, name, value):
+        if name == "_lc_class":
+            return super().__setattr__(name, value)
+
+        if name in self._lc_class.__fields__:
+            self._kwargs[name] = value
+            self._obj = self._lc_class(**self._kwargs)
+        else:
+            super().__setattr__(name, value)
+
+    def __getattr__(self, name):
+        if name in self._kwargs:
+            return self._kwargs[name]
+        return getattr(self._obj, name)
+
+    def dump(self):
+        return {
+            "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
+            **self._kwargs,
+        }
+
+    def specs(self, path: str):
+        path = path.strip(".")
+        if "." in path:
+            raise ValueError("path should not contain '.'")
+
+        if path in self._lc_class.__fields__:
+            return {
+                "__type__": "theflow.base.ParamAttr",
+                "refresh_on_set": True,
+                "strict_type": True,
+            }
+
+        raise ValueError(f"Invalid param {path}")
+
+
+class AzureChatOpenAI(LCChatMixin, ChatLLM):
+    def __init__(
+        self,
+        azure_endpoint: str | None = None,
+        openai_api_key: str | None = None,
+        openai_api_version: str = "",
+        deployment_name: str | None = None,
+        temperature: float = 0.7,
+        request_timeout: float | None = None,
+        **params,
+    ):
+        super().__init__(
+            azure_endpoint=azure_endpoint,
+            openai_api_key=openai_api_key,
+            openai_api_version=openai_api_version,
+            deployment_name=deployment_name,
+            temperature=temperature,
+            request_timeout=request_timeout,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.chat_models
+
+        return langchain.chat_models.AzureChatOpenAI
diff --git a/knowledgehub/llms/chats/openai.py b/knowledgehub/llms/chats/openai.py
deleted file mode 100644
index 2d0ec52..0000000
--- a/knowledgehub/llms/chats/openai.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from langchain.chat_models import AzureChatOpenAI as AzureChatOpenAILC
-
-from .base import LangchainChatLLM
-
-
-class AzureChatOpenAI(LangchainChatLLM):
-    _lc_class = AzureChatOpenAILC
diff --git a/knowledgehub/llms/completions/__init__.py b/knowledgehub/llms/completions/__init__.py
index b980944..3ef6b8b 100644
--- a/knowledgehub/llms/completions/__init__.py
+++ b/knowledgehub/llms/completions/__init__.py
@@ -1,4 +1,4 @@
 from .base import LLM
-from .openai import AzureOpenAI, OpenAI
+from .langchain_based import AzureOpenAI, OpenAI

 __all__ = ["LLM", "OpenAI", "AzureOpenAI"]
diff --git a/knowledgehub/llms/completions/base.py b/knowledgehub/llms/completions/base.py
index 15adf89..e4a8540 100644
--- a/knowledgehub/llms/completions/base.py
+++ b/knowledgehub/llms/completions/base.py
@@ -1,66 +1,5 @@
-import logging
-from typing import Type
-
-from langchain.llms.base import BaseLLM
-from theflow.base import Param
-
-from ...base import BaseComponent
-from ...base.schema import LLMInterface
-
-logger = logging.getLogger(__name__)
+from kotaemon.base import BaseComponent


 class LLM(BaseComponent):
     pass
-
-
-class LangchainLLM(LLM):
-    _lc_class: Type[BaseLLM]
-
-    def __init__(self, **params):
-        if self._lc_class is None:
-            raise AttributeError(
-                "Should set _lc_class attribute to the LLM class from Langchain "
-                "if using LLM from Langchain"
-            )
-
-        self._kwargs: dict = {}
-        for param in list(params.keys()):
-            if param in self._lc_class.__fields__:
-                self._kwargs[param] = params.pop(param)
-        super().__init__(**params)
-
-    @Param.auto(cache=False)
-    def agent(self):
-        return self._lc_class(**self._kwargs)
-
-    def run(self, text: str) -> LLMInterface:
-        pred = self.agent.generate([text])
-        all_text = [each.text for each in pred.generations[0]]
-
-        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
-        try:
-            if pred.llm_output is not None:
-                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
-                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
-                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
-        except Exception:
-            logger.warning(
-                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
-            )
-
-        return LLMInterface(
-            text=all_text[0] if len(all_text) > 0 else "",
-            candidates=all_text,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            prompt_tokens=prompt_tokens,
-            logits=[],
-        )
-
-    def __setattr__(self, name, value):
-        if name in self._lc_class.__fields__:
-            self._kwargs[name] = value
-            setattr(self.agent, name, value)
-        else:
-            super().__setattr__(name, value)
diff --git a/knowledgehub/llms/completions/langchain_based.py b/knowledgehub/llms/completions/langchain_based.py
new file mode 100644
index 0000000..a8d36ab
--- /dev/null
+++ b/knowledgehub/llms/completions/langchain_based.py
@@ -0,0 +1,185 @@
+import logging
+from typing import Optional
+
+from kotaemon.base import LLMInterface
+
+from .base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+class LCCompletionMixin:
+    def _get_lc_class(self):
+        raise NotImplementedError(
+            "Please return the relevant Langchain class in _get_lc_class"
+        )
+
+    def __init__(self, **params):
+        self._lc_class = self._get_lc_class()
+        self._obj = self._lc_class(**params)
+        self._kwargs: dict = params
+
+        super().__init__()
+
+    def run(self, text: str) -> LLMInterface:
+        pred = self._obj.generate([text])
+        all_text = [each.text for each in pred.generations[0]]
+
+        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
+        try:
+            if pred.llm_output is not None:
+                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
+                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
+                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
+        except Exception:
+            logger.warning(
+                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
+            )
+
+        return LLMInterface(
+            text=all_text[0] if len(all_text) > 0 else "",
+            candidates=all_text,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens=prompt_tokens,
+            logits=[],
+        )
+
+    def __repr__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = repr(value_obj)
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __str__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = str(value_obj)
+            if len(value) > 20:
+                value = f"{value[:15]}..."
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __setattr__(self, name, value):
+        if name == "_lc_class":
+            return super().__setattr__(name, value)
+
+        if name in self._lc_class.__fields__:
+            self._kwargs[name] = value
+            self._obj = self._lc_class(**self._kwargs)
+        else:
+            super().__setattr__(name, value)
+
+    def __getattr__(self, name):
+        if name in self._kwargs:
+            return self._kwargs[name]
+        return getattr(self._obj, name)
+
+    def dump(self):
+        return {
+            "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
+            **self._kwargs,
+        }
+
+    def specs(self, path: str):
+        path = path.strip(".")
+        if "." in path:
+            raise ValueError("path should not contain '.'")
+
+        if path in self._lc_class.__fields__:
+            return {
+                "__type__": "theflow.base.ParamAttr",
+                "refresh_on_set": True,
+                "strict_type": True,
+            }
+
+        raise ValueError(f"Invalid param {path}")
+
+
+class OpenAI(LCCompletionMixin, LLM):
+    """Wrapper around Langchain's OpenAI class, focusing on key parameters"""
+
+    def __init__(
+        self,
+        openai_api_key: Optional[str] = None,
+        openai_api_base: Optional[str] = None,
+        model_name: str = "text-davinci-003",
+        temperature: float = 0.7,
+        max_tokens: int = 256,
+        top_p: float = 1,
+        frequency_penalty: float = 0,
+        n: int = 1,
+        best_of: int = 1,
+        request_timeout: Optional[float] = None,
+        max_retries: int = 2,
+        streaming: bool = False,
+        **params,
+    ):
+        super().__init__(
+            openai_api_key=openai_api_key,
+            openai_api_base=openai_api_base,
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            frequency_penalty=frequency_penalty,
+            n=n,
+            best_of=best_of,
+            request_timeout=request_timeout,
+            max_retries=max_retries,
+            streaming=streaming,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.llms as langchain_llms
+
+        return langchain_llms.OpenAI
+
+
+class AzureOpenAI(LCCompletionMixin, LLM):
+    """Wrapper around Langchain's AzureOpenAI class, focusing on key parameters"""
+
+    def __init__(
+        self,
+        azure_endpoint: Optional[str] = None,
+        deployment_name: Optional[str] = None,
+        openai_api_version: str = "",
+        openai_api_key: Optional[str] = None,
+        model_name: str = "text-davinci-003",
+        temperature: float = 0.7,
+        max_tokens: int = 256,
+        top_p: float = 1,
+        frequency_penalty: float = 0,
+        n: int = 1,
+        best_of: int = 1,
+        request_timeout: Optional[float] = None,
+        max_retries: int = 2,
+        streaming: bool = False,
+        **params,
+    ):
+        super().__init__(
+            azure_endpoint=azure_endpoint,
+            deployment_name=deployment_name,
+            openai_api_version=openai_api_version,
+            openai_api_key=openai_api_key,
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            frequency_penalty=frequency_penalty,
+            n=n,
+            best_of=best_of,
+            request_timeout=request_timeout,
+            max_retries=max_retries,
+            streaming=streaming,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.llms as langchain_llms
+
+        return langchain_llms.AzureOpenAI
diff --git a/knowledgehub/llms/completions/openai.py b/knowledgehub/llms/completions/openai.py
deleted file mode 100644
index 93a25ee..0000000
--- a/knowledgehub/llms/completions/openai.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import langchain.llms as langchain_llms
-
-from .base import LangchainLLM
-
-
-class OpenAI(LangchainLLM):
-    """Wrapper around Langchain's OpenAI class"""
-
-    _lc_class = langchain_llms.OpenAI
-
-
-class AzureOpenAI(LangchainLLM):
-    """Wrapper around Langchain's AzureOpenAI class"""
-
-    _lc_class = langchain_llms.AzureOpenAI
diff --git a/knowledgehub/llms/linear.py b/knowledgehub/llms/linear.py
index df6cc87..4fd95e4 100644
--- a/knowledgehub/llms/linear.py
+++ b/knowledgehub/llms/linear.py
@@ -21,8 +21,7 @@ class SimpleLinearPipeline(BaseComponent):
         post-processor component or function.

     Example Usage:
-        from kotaemon.llms.chats.openai import AzureChatOpenAI
-        from kotaemon.llms import BasePromptComponent
+        from kotaemon.llms import AzureChatOpenAI, BasePromptComponent

         def identity(x):
             return x
@@ -87,8 +86,7 @@ class GatedLinearPipeline(SimpleLinearPipeline):
         condition.

     Example Usage:
-        from kotaemon.llms.chats.openai import AzureChatOpenAI
-        from kotaemon.llms import BasePromptComponent
+        from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
         from kotaemon.parsers import RegexExtractor

         def identity(x):
diff --git a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py
index 8b98b5d..15c6ac4 100644
--- a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py
+++ b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py
@@ -6,7 +6,7 @@ from theflow.utils.modules import ObjectInitDeclaration as _

 from kotaemon.base import BaseComponent
 from kotaemon.embeddings import AzureOpenAIEmbeddings
-from kotaemon.llms.completions.openai import AzureOpenAI
+from kotaemon.llms import AzureOpenAI
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
diff --git a/tests/simple_pipeline.py b/tests/simple_pipeline.py
index 1e79d0b..5e9ffdb 100644
--- a/tests/simple_pipeline.py
+++ b/tests/simple_pipeline.py
@@ -6,7 +6,7 @@ from theflow.utils.modules import ObjectInitDeclaration as _

 from kotaemon.base import BaseComponent
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorRetrieval
-from kotaemon.llms.completions.openai import AzureOpenAI
+from kotaemon.llms import AzureOpenAI
 from kotaemon.storages import ChromaVectorStore

diff --git a/tests/test_agent.py b/tests/test_agent.py
index c7c3c1f..4b7ea15 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -8,7 +8,7 @@ from kotaemon.agents.langchain import LangchainAgent
 from kotaemon.agents.react import ReactAgent
 from kotaemon.agents.rewoo import RewooAgent
 from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool
-from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.llms import AzureChatOpenAI

 FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
@@ -195,7 +195,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
     langchain_plugins = [tool.to_langchain_format() for tool in plugins]
     agent = initialize_agent(
         langchain_plugins,
-        llm.agent,
+        llm._obj,
         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
         verbose=True,
     )
diff --git a/tests/test_citation.py b/tests/test_citation.py
index 69fc544..b92dd61 100644
--- a/tests/test_citation.py
+++ b/tests/test_citation.py
@@ -5,7 +5,7 @@
 import pytest
 from openai.types.chat.chat_completion import ChatCompletion

 from kotaemon.indices.qa import CitationPipeline
-from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.llms import AzureChatOpenAI

 function_output = '{\n "question": "What is the provided _example_ benefits?",\n "answer": [\n {\n "fact": "特約死亡保険金: 被保険者がこの特約の保険期間中に死亡したときに支払います。",\n "substring_quote": ["特約死亡保険金"]\n },\n {\n "fact": "特約特定疾病保険金: 被保険者がこの特約の保険期間中に特定の疾病(悪性新生物(がん)、急性心筋梗塞または脳卒中)により所定の状態に該当したときに支払います。",\n "substring_quote": ["特約特定疾病保険金"]\n },\n {\n "fact": "特約障害保険金: 被保険者がこの特約の保険期間中に傷害もしくは疾病により所定の身体障害の状態に該当したとき、または不慮の事故により所定の身体障害の状態に該当したときに支払います。",\n "substring_quote": ["特約障害保険金"]\n },\n {\n "fact": "特約介護保険金: 被保険者がこの特約の保険期間中に傷害または疾病により所定の要介護状態に該当したときに支払います。",\n "substring_quote": ["特約介護保険金"]\n }\n ]\n}'
diff --git a/tests/test_embedding_models.py b/tests/test_embedding_models.py
index d530964..f431c56 100644
--- a/tests/test_embedding_models.py
+++ b/tests/test_embedding_models.py
@@ -3,9 +3,11 @@ import json
 from pathlib import Path
 from unittest.mock import patch

 from kotaemon.base import Document
-from kotaemon.embeddings.cohere import CohereEmbdeddings
-from kotaemon.embeddings.huggingface import HuggingFaceEmbeddings
-from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
+from kotaemon.embeddings import (
+    AzureOpenAIEmbeddings,
+    CohereEmbdeddings,
+    HuggingFaceEmbeddings,
+)

 with open(Path(__file__).parent / "resources" / "embedding_openai_batch.json") as f:
     openai_embedding_batch = json.load(f)
@@ -60,7 +62,7 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
     "langchain.embeddings.huggingface.HuggingFaceBgeEmbeddings.embed_documents",
     side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
 )
-def test_huggingface_embddings(
+def test_huggingface_embeddings(
     langchain_huggingface_embedding_call, sentence_transformers_init
 ):
     model = HuggingFaceEmbeddings(
diff --git a/tests/test_indexing_retrieval.py b/tests/test_indexing_retrieval.py
index 367495f..17ce2fb 100644
--- a/tests/test_indexing_retrieval.py
+++ b/tests/test_indexing_retrieval.py
@@ -6,7 +6,7 @@ import pytest
 from openai.resources.embeddings import Embeddings

 from kotaemon.base import Document
-from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
+from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorIndexing, VectorRetrieval
 from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore

diff --git a/tests/test_llms_chat_models.py b/tests/test_llms_chat_models.py
index 90b2ab5..54c94ce 100644
--- a/tests/test_llms_chat_models.py
+++ b/tests/test_llms_chat_models.py
@@ -9,7 +9,7 @@ from kotaemon.base.schema import (
     LLMInterface,
     SystemMessage,
 )
-from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.llms import AzureChatOpenAI

 _openai_chat_completion_response = ChatCompletion.parse_obj(
     {
@@ -48,7 +48,7 @@ def test_azureopenai_model(openai_completion):
         temperature=0,
     )
     assert isinstance(
-        model.agent, AzureChatOpenAILC
+        model._obj, AzureChatOpenAILC
     ), "Agent not wrapped in Langchain's AzureChatOpenAI"

     # test for str input - stream mode
diff --git a/tests/test_llms_completion_models.py b/tests/test_llms_completion_models.py
index 04be9ba..e084935 100644
--- a/tests/test_llms_completion_models.py
+++ b/tests/test_llms_completion_models.py
@@ -5,7 +5,7 @@
 from langchain.llms import OpenAI as OpenAILC
 from openai.types.completion import Completion

 from kotaemon.base.schema import LLMInterface
-from kotaemon.llms.completions.openai import AzureOpenAI, OpenAI
+from kotaemon.llms import AzureOpenAI, OpenAI

 _openai_completion_response = Completion.parse_obj(
     {
@@ -41,7 +41,7 @@ def test_azureopenai_model(openai_completion):
         request_timeout=60,
     )
     assert isinstance(
-        model.agent, AzureOpenAILC
+        model._obj, AzureOpenAILC
     ), "Agent not wrapped in Langchain's AzureOpenAI"

     output = model("hello world")
@@ -64,7 +64,7 @@ def test_openai_model(openai_completion):
         request_timeout=60,
     )
     assert isinstance(
-        model.agent, OpenAILC
+        model._obj, OpenAILC
     ), "Agent is not wrapped in Langchain's OpenAI"

     output = model("hello world")
diff --git a/tests/test_reranking.py b/tests/test_reranking.py
index b1addca..03f0b72 100644
--- a/tests/test_reranking.py
+++ b/tests/test_reranking.py
@@ -5,7 +5,7 @@
 from openai.types.chat.chat_completion import ChatCompletion

 from kotaemon.base import Document
 from kotaemon.indices.rankings import LLMReranking
-from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.llms import AzureChatOpenAI

 _openai_chat_completion_responses = [
     ChatCompletion.parse_obj(
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 42a270d..e6f940b 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -6,7 +6,7 @@
 from openai.resources.embeddings import Embeddings

 from kotaemon.agents.tools import ComponentTool, GoogleSearchTool, WikipediaTool
 from kotaemon.base import Document
-from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
+from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices.vectorindex import VectorIndexing, VectorRetrieval
 from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
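A usage sketch of the public API after this refactor. This snippet is an illustration rather than part of the diff: the endpoint, key, and deployment values are placeholders, and actually calling the models assumes working Azure OpenAI credentials at runtime.

```python
from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.llms import AzureChatOpenAI

# Chat model: explicit __init__ exposes the key parameters directly.
llm = AzureChatOpenAI(
    azure_endpoint="https://<resource>.openai.azure.com/",  # placeholder
    openai_api_key="<key>",                                 # placeholder
    openai_api_version="2023-07-01-preview",                # placeholder
    deployment_name="gpt-35-turbo",                         # placeholder
    temperature=0,
)
print(llm)        # __str__ truncates parameter values longer than 20 chars
print(repr(llm))  # __repr__ shows the full values

embedding = AzureOpenAIEmbeddings(
    azure_endpoint="https://<resource>.openai.azure.com/",  # placeholder
    deployment="text-embedding-ada-002",                    # placeholder
    openai_api_key="<key>",                                 # placeholder
    openai_api_version="2023-07-01-preview",                # placeholder
)

# Setting a Langchain-backed field rebuilds the wrapped object and is
# reflected on the next read, per the __setattr__/__getattr__ pair above.
llm.temperature = 0.2
assert llm.temperature == 0.2

# dump() serializes the component type together with its init kwargs.
print(llm.dump())
```

Compared with the old `Param.auto`-based wrappers, the explicit `__init__` signatures make the key parameters discoverable in editors and docs, and `specs()` reports Langchain-backed fields as `refresh_on_set`, matching the rebuild-on-write behavior of `__setattr__`.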