From 0dede9c82d212a27718919d1ffc28276c1db4da3 Mon Sep 17 00:00:00 2001
From: ian_Cin
Date: Mon, 27 Nov 2023 10:38:19 +0700
Subject: [PATCH] Subclass chat messages from Document (#86)

---
 knowledgehub/base/component.py     |  1 +
 knowledgehub/base/schema.py        | 25 ++++++++++++++++++++++---
 knowledgehub/chatbot/base.py       |  3 ++-
 knowledgehub/llms/__init__.py      |  4 ++--
 knowledgehub/llms/chats/base.py    |  3 ++-
 knowledgehub/pipelines/citation.py |  2 +-
 tests/test_llms_chat_models.py     |  8 ++++++--
 7 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/knowledgehub/base/component.py b/knowledgehub/base/component.py
index 132d597..5ae1fde 100644
--- a/knowledgehub/base/component.py
+++ b/knowledgehub/base/component.py
@@ -31,5 +31,6 @@ class BaseComponent(Function):
 
     @abstractmethod
     def run(self, *args, **kwargs):
+        # enforce output type to be compatible with Document
        """Run the component."""
        ...
diff --git a/knowledgehub/base/schema.py b/knowledgehub/base/schema.py
index 3da62b3..9a4ced7 100644
--- a/knowledgehub/base/schema.py
+++ b/knowledgehub/base/schema.py
@@ -2,7 +2,9 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
-from langchain.schema.messages import AIMessage
+from langchain.schema.messages import AIMessage as LCAIMessage
+from langchain.schema.messages import HumanMessage as LCHumanMessage
+from langchain.schema.messages import SystemMessage as LCSystemMessage
 from llama_index.bridge.pydantic import Field
 from llama_index.schema import Document as BaseDocument
 
@@ -63,11 +65,28 @@ class Document(BaseDocument):
         return str(self.content)
 
 
+class BaseMessage(Document):
+    def __add__(self, other: Any):
+        raise NotImplementedError
+
+
+class SystemMessage(BaseMessage, LCSystemMessage):
+    pass
+
+
+class AIMessage(BaseMessage, LCAIMessage):
+    pass
+
+
+class HumanMessage(BaseMessage, LCHumanMessage):
+    pass
+
+
 class RetrievedDocument(Document):
     """Subclass of Document with retrieval-related information
 
     Attributes:
-        score (float): score of the document (from 0.0 to 1.0)
+        score (float): score of the document (from 0.0 to 1.0)
         retrieval_metadata (dict): metadata from the retrieval process, can be used
             by different components in a retrieved pipeline to communicate with each
             other
@@ -77,7 +96,7 @@ class RetrievedDocument(Document):
     retrieval_metadata: dict = Field(default={})
 
 
-class LLMInterface(Document):
+class LLMInterface(AIMessage):
     candidates: list[str] = Field(default_factory=list)
     completion_tokens: int = -1
     total_tokens: int = -1
diff --git a/knowledgehub/chatbot/base.py b/knowledgehub/chatbot/base.py
index 846f26d..3c305c0 100644
--- a/knowledgehub/chatbot/base.py
+++ b/knowledgehub/chatbot/base.py
@@ -1,9 +1,10 @@
 from abc import abstractmethod
 from typing import List, Optional
 
-from langchain.schema.messages import AIMessage, SystemMessage
 from theflow import SessionFunction
 
+from kotaemon.base.schema import AIMessage, SystemMessage
+
 from ..base import BaseComponent
 from ..base.schema import LLMInterface
 from ..llms.chats.base import BaseMessage, HumanMessage
diff --git a/knowledgehub/llms/__init__.py b/knowledgehub/llms/__init__.py
index cbf2860..5bb7c4f 100644
--- a/knowledgehub/llms/__init__.py
+++ b/knowledgehub/llms/__init__.py
@@ -1,7 +1,7 @@
-from langchain.schema.messages import AIMessage, SystemMessage
+from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
 
 from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
-from .chats import AzureChatOpenAI, BaseMessage, ChatLLM, HumanMessage
+from .chats import AzureChatOpenAI, ChatLLM
 from .completions import LLM, AzureOpenAI, OpenAI
 from .linear import GatedLinearPipeline, SimpleLinearPipeline
 from .prompts import BasePromptComponent, PromptTemplate
diff --git a/knowledgehub/llms/chats/base.py b/knowledgehub/llms/chats/base.py
index 664546f..7fdcd62 100644
--- a/knowledgehub/llms/chats/base.py
+++ b/knowledgehub/llms/chats/base.py
@@ -4,9 +4,10 @@ import logging
 from typing import Type
 
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema.messages import BaseMessage, HumanMessage
 from theflow.base import Param
 
+from kotaemon.base.schema import BaseMessage, HumanMessage
+
 from ...base import BaseComponent
 from ...base.schema import LLMInterface
 
diff --git a/knowledgehub/pipelines/citation.py b/knowledgehub/pipelines/citation.py
index 2577360..3f4f295 100644
--- a/knowledgehub/pipelines/citation.py
+++ b/knowledgehub/pipelines/citation.py
@@ -1,9 +1,9 @@
 from typing import Iterator, List, Union
 
-from langchain.schema.messages import HumanMessage, SystemMessage
 from pydantic import BaseModel, Field
 
 from kotaemon.base import BaseComponent
+from kotaemon.base.schema import HumanMessage, SystemMessage
 
 from ..llms.chats.base import ChatLLM
 from ..llms.completions.base import LLM
diff --git a/tests/test_llms_chat_models.py b/tests/test_llms_chat_models.py
index fd3c4e4..90b2ab5 100644
--- a/tests/test_llms_chat_models.py
+++ b/tests/test_llms_chat_models.py
@@ -1,10 +1,14 @@
 from unittest.mock import patch
 
 from langchain.chat_models import AzureChatOpenAI as AzureChatOpenAILC
-from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
 from openai.types.chat.chat_completion import ChatCompletion
 
-from kotaemon.base.schema import LLMInterface
+from kotaemon.base.schema import (
+    AIMessage,
+    HumanMessage,
+    LLMInterface,
+    SystemMessage,
+)
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 
 _openai_chat_completion_response = ChatCompletion.parse_obj(