Enforce all IO objects to be subclassed from Document (#88)

* enforce Document as IO

* Separate rerankers, splitters and extractors (#85)

* partially refractor importing

* add text to embedding outputs

---------

Co-authored-by: Nguyen Trung Duc (john) <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin
2023-11-27 16:35:09 +07:00
committed by GitHub
parent 2186c5558f
commit 8e0779a22d
13 changed files with 108 additions and 59 deletions

View File

@@ -1,4 +1,25 @@
from .component import BaseComponent
from .schema import Document
from .schema import (
AIMessage,
BaseMessage,
Document,
DocumentWithEmbedding,
ExtractorOutput,
HumanMessage,
LLMInterface,
RetrievedDocument,
SystemMessage,
)
__all__ = ["BaseComponent", "Document"]
__all__ = [
"BaseComponent",
"Document",
"DocumentWithEmbedding",
"BaseMessage",
"SystemMessage",
"AIMessage",
"HumanMessage",
"RetrievedDocument",
"LLMInterface",
"ExtractorOutput",
]

View File

@@ -2,6 +2,8 @@ from abc import abstractmethod
from theflow.base import Function
from kotaemon.base.schema import Document
class BaseComponent(Function):
"""A component is a class that can be used to compose a pipeline
@@ -30,7 +32,6 @@ class BaseComponent(Function):
return self.__call__(self.inflow.flow())
@abstractmethod
def run(self, *args, **kwargs):
# enforce output type to be compatible with Document
def run(self, *args, **kwargs) -> Document | list[Document] | None:
"""Run the component."""
...

View File

@@ -32,6 +32,8 @@ class Document(BaseDocument):
kwargs["content"] = kwargs["text"]
elif kwargs.get("embedding", None) is not None:
kwargs["content"] = kwargs["embedding"]
# default text indicating this document only contains embedding
kwargs["text"] = "<EMBEDDING>"
elif isinstance(content, Document):
kwargs = content.dict()
else:
@@ -65,6 +67,17 @@ class Document(BaseDocument):
return str(self.content)
class DocumentWithEmbedding(Document):
"""Subclass of Document which must contains embedding
Use this if you want to enforce component's IOs to must contain embedding.
"""
def __init__(self, embedding: list[float], *args, **kwargs):
kwargs["embedding"] = embedding
super().__init__(*args, **kwargs)
class BaseMessage(Document):
def __add__(self, other: Any):
raise NotImplementedError