Enforce all IO objects to be subclassed from Document (#88)

* enforce Document as IO

* Separate rerankers, splitters and extractors (#85)

* partially refactor importing

* add text to embedding outputs

---------

Co-authored-by: Nguyen Trung Duc (john) <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin
2023-11-27 16:35:09 +07:00
committed by GitHub
parent 2186c5558f
commit 8e0779a22d
13 changed files with 108 additions and 59 deletions

View File

@@ -3,7 +3,7 @@ from typing import Callable, List
from theflow import Function, Node, Param
from kotaemon.base import BaseComponent
from kotaemon.base import BaseComponent, Document
from kotaemon.llms import LLM, BasePromptComponent
from kotaemon.llms.chats.openai import AzureChatOpenAI
@@ -65,15 +65,19 @@ class Thought(BaseComponent):
"""
prompt: str = Param(
help="The prompt template string. This prompt template has Python-like "
"variable placeholders, that then will be subsituted with real values when "
"this component is executed"
help=(
"The prompt template string. This prompt template has Python-like "
"variable placeholders, that then will be subsituted with real values when "
"this component is executed"
)
)
llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
post_process: Function = Node(
help="The function post-processor that post-processes LLM output prediction ."
"It should take a string as input (this is the LLM output text) and return "
"a dictionary, where the key should"
help=(
"The function post-processor that post-processes LLM output prediction ."
"It should take a string as input (this is the LLM output text) and return "
"a dictionary, where the key should"
)
)
@Node.auto(depends_on="prompt")
@@ -81,11 +85,13 @@ class Thought(BaseComponent):
"""Automatically wrap around param prompt. Can ignore"""
return BasePromptComponent(self.prompt)
def run(self, **kwargs) -> dict:
def run(self, **kwargs) -> Document:
"""Run the chain of thought"""
prompt = self.prompt_template(**kwargs).text
response = self.llm(prompt).text
return self.post_process(response)
response = self.post_process(response)
return Document(response)
def get_variables(self) -> List[str]:
return []
@@ -146,7 +152,7 @@ class ManualSequentialChainOfThought(BaseComponent):
help="Callback on terminate condition. Default to always return False",
)
def run(self, **kwargs) -> dict:
def run(self, **kwargs) -> Document:
"""Run the manual chain of thought"""
inputs = deepcopy(kwargs)
@@ -156,11 +162,11 @@ class ManualSequentialChainOfThought(BaseComponent):
self._prepare_child(thought, f"thought{idx}")
output = thought(**inputs)
inputs.update(output)
inputs.update(output.content)
if self.terminate(inputs):
break
return inputs
return Document(inputs)
def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought":
return ManualSequentialChainOfThought(