Enforce all IO objects to be subclassed from Document (#88)
* enforce Document as IO
* Separate rerankers, splitters and extractors (#85)
* partially refactor importing
* add text to embedding outputs

---------

Co-authored-by: Nguyen Trung Duc (john) <trungduc1992@gmail.com>
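Note on the new IO contract, as a minimal sketch: every component's run()
wraps its payload in kotaemon.base.Document instead of returning a bare dict,
and callers unwrap it through .content. The component below is hypothetical;
the Document constructor and the .content attribute are inferred from the
diff below, not a full description of the class.

    from kotaemon.base import BaseComponent, Document

    class Echo(BaseComponent):
        """Hypothetical component following the Document-as-IO contract."""

        def run(self, **kwargs) -> Document:
            payload = dict(kwargs)    # arbitrary intermediate result
            return Document(payload)  # wrap the payload instead of returning it

    # calling the component executes run(); the dict payload sits on .content
    doc = Echo()(question="What is 2 + 2?")
    state = doc.content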
@@ -3,7 +3,7 @@ from typing import Callable, List
 
 from theflow import Function, Node, Param
 
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, Document
 from kotaemon.llms import LLM, BasePromptComponent
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 
@@ -65,15 +65,19 @@ class Thought(BaseComponent):
     """
 
     prompt: str = Param(
-        help="The prompt template string. This prompt template has Python-like "
-        "variable placeholders, that then will be subsituted with real values when "
-        "this component is executed"
+        help=(
+            "The prompt template string. This prompt template has Python-like "
+            "variable placeholders, that then will be subsituted with real values when "
+            "this component is executed"
+        )
     )
     llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
     post_process: Function = Node(
-        help="The function post-processor that post-processes LLM output prediction ."
-        "It should take a string as input (this is the LLM output text) and return "
-        "a dictionary, where the key should"
+        help=(
+            "The function post-processor that post-processes LLM output prediction ."
+            "It should take a string as input (this is the LLM output text) and return "
+            "a dictionary, where the key should"
+        )
     )
 
     @Node.auto(depends_on="prompt")
@@ -81,11 +85,13 @@ class Thought(BaseComponent):
         """Automatically wrap around param prompt. Can ignore"""
         return BasePromptComponent(self.prompt)
 
-    def run(self, **kwargs) -> dict:
+    def run(self, **kwargs) -> Document:
         """Run the chain of thought"""
         prompt = self.prompt_template(**kwargs).text
         response = self.llm(prompt).text
-        return self.post_process(response)
+        response = self.post_process(response)
+
+        return Document(response)
 
     def get_variables(self) -> List[str]:
         return []
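With this hunk, Thought.run returns a Document rather than the post-processor's
dict; the dict payload moves behind .content. A hedged usage sketch (the prompt
and post-processor here are illustrative, and llm falls back to the
AzureChatOpenAI node declared above):

    thought = Thought(
        prompt="Answer concisely: {question}",
        post_process=lambda text: {"answer": text.strip()},
    )
    output = thought(question="What is 2 + 2?")
    answer = output.content["answer"]  # previously: thought(...)["answer"]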
@@ -146,7 +152,7 @@ class ManualSequentialChainOfThought(BaseComponent):
         help="Callback on terminate condition. Default to always return False",
     )
 
-    def run(self, **kwargs) -> dict:
+    def run(self, **kwargs) -> Document:
         """Run the manual chain of thought"""
 
         inputs = deepcopy(kwargs)
@@ -156,11 +162,11 @@ class ManualSequentialChainOfThought(BaseComponent):
             self._prepare_child(thought, f"thought{idx}")
 
             output = thought(**inputs)
-            inputs.update(output)
+            inputs.update(output.content)
             if self.terminate(inputs):
                 break
 
-        return inputs
+        return Document(inputs)
 
     def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought":
         return ManualSequentialChainOfThought(
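The chain-level run mirrors the per-thought change: each thought's Document
payload is merged back into the running inputs via output.content, and the
accumulated dict is wrapped one final time. A sketch of composing thoughts
with __add__, assuming thought1 and thought2 are configured Thought instances:

    chain = thought1 + thought2  # builds a ManualSequentialChainOfThought
    final = chain(question="What is 2 + 2?")
    state = final.content        # the accumulated inputs dict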