Feat/Add ReAct and ReWOO Reasoning Pipelines (#43)

* Add ReactAgentPipeline by wrapping the ReactAgent

* Implement stream processing for ReactAgentPipeline and RewooAgentPipeline

* Fix highlight_citation in Rewoo and remove highlight_citation from React

* Fix importing ktem.llms inside kotaemon

* fix: Change Rewoo::solver's output to LLMInterface instead of plain text

* Add more user_settings to the RewooAgentPipeline

* Fix LLMTool

* Add more user_settings to the ReactAgentPipeline

* Minor fix

* Stream the react agent immediately

* Yield the Rewoo progress to info panel

* Hide the agent in flowsettings

* Remove redundant comments

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
Author:    Albert
Date:      2024-05-09 16:06:24 +07:00
Committer: GitHub
parent ec11b54ff2
commit 466adf2d94
12 changed files with 1114 additions and 25 deletions

View File

@@ -253,5 +253,6 @@ class AgentOutput(LLMInterface):
     text: str
     type: str = "agent"
     agent_type: AgentType
-    status: Literal["finished", "stopped", "failed"]
+    status: Literal["thinking", "finished", "stopped", "failed"]
     error: Optional[str] = None
+    intermediate_steps: Optional[list] = None
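
Note: the new "thinking" status lets a pipeline stream intermediate progress before the final answer arrives. A minimal sketch of how a consumer might branch on it, using a simplified stand-in for the real AgentOutput (the actual class extends LLMInterface and carries more fields):

    from typing import Literal, Optional
    from pydantic import BaseModel

    class AgentOutput(BaseModel):  # simplified stand-in, not kotaemon's class
        text: str
        status: Literal["thinking", "finished", "stopped", "failed"]
        error: Optional[str] = None
        intermediate_steps: Optional[list] = None

    def render(output: AgentOutput) -> None:
        if output.status == "thinking":
            print("[info]", output.intermediate_steps)  # progress panel
        elif output.status == "finished":
            print("[chat]", output.text)  # final answer
        else:  # "stopped" or "failed"
            print("[error]", output.error or "agent did not finish")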

View File

@@ -1,11 +1,15 @@
 import logging
 import re
+from functools import partial
 from typing import Optional

+import tiktoken
+
 from kotaemon.agents.base import BaseAgent, BaseLLM
 from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
 from kotaemon.agents.tools import BaseTool
-from kotaemon.base import Param
+from kotaemon.base import Document, Param
+from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import PromptTemplate

 FINAL_ANSWER_ACTION = "Final Answer:"
@@ -22,6 +26,7 @@ class ReactAgent(BaseAgent):
     description: str = "ReactAgent for answering multi-step reasoning questions"
     llm: BaseLLM
     prompt_template: Optional[PromptTemplate] = None
+    output_lang: str = "English"
     plugins: list[BaseTool] = Param(
         default_callback=lambda _: [], help="List of tools to be used in the agent. "
     )
@@ -32,8 +37,18 @@ class ReactAgent(BaseAgent):
         default_callback=lambda _: [],
         help="List of AgentAction and observation (tool) output",
     )
-    max_iterations: int = 10
+    max_iterations: int = 5
     strict_decode: bool = False
+    trim_func: TokenSplitter = TokenSplitter.withx(
+        chunk_size=800,
+        chunk_overlap=0,
+        separator=" ",
+        tokenizer=partial(
+            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+            allowed_special=set(),
+            disallowed_special="all",
+        ),
+    )

     def _compose_plugin_description(self) -> str:
         """
@@ -119,6 +134,7 @@ class ReactAgent(BaseAgent):
             agent_scratchpad=agent_scratchpad,
             tool_description=tool_description,
             tool_names=tool_names,
+            lang=self.output_lang,
         )

     def _format_function_map(self) -> dict[str, BaseTool]:
@@ -133,6 +149,20 @@ class ReactAgent(BaseAgent):
                 function_map[plugin.name] = plugin
         return function_map

+    def _trim(self, text: str) -> str:
+        """
+        Trim the text to the maximum token length.
+        """
+        if isinstance(text, str):
+            texts = self.trim_func([Document(text=text)])
+        elif isinstance(text, Document):
+            texts = self.trim_func([text])
+        else:
+            raise ValueError("Invalid text type to trim")
+        trim_text = texts[0].text
+        logging.info(f"len (trimmed): {len(trim_text)}")
+        return trim_text
+
     def clear(self):
         """
         Clear and reset the agent.
@@ -183,6 +213,11 @@ class ReactAgent(BaseAgent):
                 logging.info(f"Action: {action_name}")
                 logging.info(f"Tool Input: {tool_input}")
                 result = self._format_function_map()[action_name](tool_input)
+                # trim the worker output, as we are appending all workers' logs
+                # and they can exceed the token limit if we don't limit each.
+                # Adjust this number according to the LLM's capacity.
+                result = self._trim(result)
                 logging.info(f"Result: {result}")

             self.intermediate_steps.append((action_step, result))
@@ -202,3 +237,100 @@ class ReactAgent(BaseAgent):
             intermediate_steps=self.intermediate_steps,
             max_iterations=max_iterations,
         )
+
+    def stream(self, instruction, max_iterations=None):
+        """
+        Stream the agent with the given instruction.
+
+        Args:
+            instruction: Instruction to run the agent with.
+            max_iterations: Maximum number of reasoning iterations,
+                defaults to self.max_iterations.
+
+        Return:
+            AgentOutput object.
+        """
+        if not max_iterations:
+            max_iterations = self.max_iterations
+        assert max_iterations > 0
+
+        self.clear()
+        logging.info(f"Running {self.name} with instruction: {instruction}")
+        print(f"Running {self.name} with instruction: {instruction}")
+        total_cost = 0.0
+        total_token = 0
+        status = "failed"
+        response_text = None
+
+        for step_count in range(1, max_iterations + 1):
+            prompt = self._compose_prompt(instruction)
+            logging.info(f"Prompt: {prompt}")
+            print(f"Prompt: {prompt}")
+            response = self.llm(
+                prompt, stop=["Observation:"]
+            )  # TODO: could cause bugs if the llm doesn't accept a `stop` parameter
+            response_text = response.text
+            logging.info(f"Response: {response_text}")
+            print(f"Response: {response_text}")
+            action_step = self._parse_output(response_text)
+            if action_step is None:
+                raise ValueError("Invalid action")
+            is_finished_chain = isinstance(action_step, AgentFinish)
+            if is_finished_chain:
+                result = response_text
+                if "Final Answer:" in response_text:
+                    result = response_text.split("Final Answer:")[-1].strip()
+            else:
+                assert isinstance(action_step, AgentAction)
+                action_name = action_step.tool
+                tool_input = action_step.tool_input
+                logging.info(f"Action: {action_name}")
+                print(f"Action: {action_name}")
+                logging.info(f"Tool Input: {tool_input}")
+                print(f"Tool Input: {tool_input}")
+                result = self._format_function_map()[action_name](tool_input)
+                # trim the worker output, as we are appending all workers' logs
+                # and they can exceed the token limit if we don't limit each.
+                # Adjust this number according to the LLM's capacity.
+                result = self._trim(result)
+                logging.info(f"Result: {result}")
+                print(f"Result: {result}")
+
+            self.intermediate_steps.append((action_step, result))
+            if is_finished_chain:
+                logging.info(f"Finished after {step_count} steps.")
+                status = "finished"
+                yield AgentOutput(
+                    text=result,
+                    agent_type=self.agent_type,
+                    status=status,
+                    intermediate_steps=self.intermediate_steps[-1],
+                )
+                break
+            else:
+                yield AgentOutput(
+                    text="",
+                    agent_type=self.agent_type,
+                    status="thinking",
+                    intermediate_steps=self.intermediate_steps[-1],
+                )
+        else:
+            status = "stopped"
+            yield AgentOutput(
+                text="",
+                agent_type=self.agent_type,
+                status=status,
+                intermediate_steps=self.intermediate_steps[-1],
+            )
+
+        return AgentOutput(
+            text=response_text,
+            agent_type=self.agent_type,
+            status=status,
+            total_tokens=total_token,
+            total_cost=total_cost,
+            intermediate_steps=self.intermediate_steps,
+            max_iterations=max_iterations,
+        )
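
The streaming variant yields a "thinking" AgentOutput after each reasoning step and hands back the final AgentOutput as the generator's return value. A hedged sketch of driving it by hand (my_llm and my_tools are placeholders, not names from this PR); the Generator wrapper added at the end of this PR packages the same StopIteration dance:

    agent = ReactAgent(llm=my_llm, plugins=my_tools)
    gen = agent.stream("What is the tallest building in Asia?")
    while True:
        try:
            step = next(gen)
            if step.status == "thinking":
                # intermediate_steps is the latest (action, observation) pair
                action, observation = step.intermediate_steps
                print("thinking:", action.tool, "->", str(observation)[:80])
        except StopIteration as stop:
            final = stop.value  # the AgentOutput returned after the loop
            print("final:", final.status, final.text)
            break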

View File

@@ -3,7 +3,7 @@
 from kotaemon.llms import PromptTemplate

 zero_shot_react_prompt = PromptTemplate(
-    template="""Answer the following questions as best you can. You have access to the following tools:
+    template="""Answer the following questions as best you can. Give answer in {lang}. You have access to the following tools:
 {tool_description}
 Use the following format:
@@ -12,7 +12,7 @@ Thought: you should always think about what to do
 Action: the action to take, should be one of [{tool_names}]
-Action Input: the input to the action
+Action Input: the input to the action, should be different from the action input of the same action in previous steps.
 Observation: the result of the action

View File

@@ -1,14 +1,18 @@
 import logging
 import re
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from typing import Any

+import tiktoken
+
 from kotaemon.agents.base import BaseAgent
 from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
 from kotaemon.agents.tools import BaseTool
 from kotaemon.agents.utils import get_plugin_response_content
-from kotaemon.base import Node, Param
-from kotaemon.indices.qa import CitationPipeline
+from kotaemon.base import Document, Node, Param
+from kotaemon.indices.qa.citation import CitationPipeline
+from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import BaseLLM, PromptTemplate

 from .planner import Planner
@@ -22,6 +26,7 @@ class RewooAgent(BaseAgent):
     name: str = "RewooAgent"
     agent_type: AgentType = AgentType.rewoo
     description: str = "RewooAgent for answering multi-step reasoning questions"
+    output_lang: str = "English"
     planner_llm: BaseLLM
     solver_llm: BaseLLM
     prompt_template: dict[str, PromptTemplate] = Param(
@@ -34,6 +39,16 @@ class RewooAgent(BaseAgent):
     examples: dict[str, str | list[str]] = Param(
         default_callback=lambda _: {}, help="Examples to be used in the agent."
     )
+    trim_func: TokenSplitter = TokenSplitter.withx(
+        chunk_size=3000,
+        chunk_overlap=0,
+        separator=" ",
+        tokenizer=partial(
+            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+            allowed_special=set(),
+            disallowed_special="all",
+        ),
+    )

     @Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
     def planner(self):
@@ -50,6 +65,7 @@ class RewooAgent(BaseAgent):
             model=self.solver_llm,
             prompt_template=self.prompt_template.get("Solver", None),
             examples=self.examples.get("Solver", None),
+            output_lang=self.output_lang,
         )

     def _parse_plan_map(
@@ -159,8 +175,13 @@ class RewooAgent(BaseAgent):
                 tool_input = tool_input[:-1]
             # find variables in input and replace with previous evidences
             for var in re.findall(r"#E\d+", tool_input):
+                print("Tool input: ", tool_input)
+                print("Var: ", var)
+                print("Worker evidences: ", worker_evidences)
                 if var in worker_evidences:
-                    tool_input = tool_input.replace(var, worker_evidences.get(var, ""))
+                    tool_input = tool_input.replace(
+                        var, worker_evidences.get(var, "") or ""
+                    )
             try:
                 selected_plugin = self._find_plugin(tool)
                 if selected_plugin is None:
@@ -216,7 +237,7 @@ class RewooAgent(BaseAgent):
                 resp = r.result()
                 plugin_cost += resp["plugin_cost"]
                 plugin_token += resp["plugin_token"]
-                worker_evidences[resp["e"]] = resp["evidence"]
+                worker_evidences[resp["e"]] = self._trim_evidence(resp["evidence"])
         output.done()

         return worker_evidences, plugin_cost, plugin_token
@@ -226,6 +247,13 @@ class RewooAgent(BaseAgent):
             if p.name == name:
                 return p

+    def _trim_evidence(self, evidence: str):
+        if evidence:
+            texts = self.trim_func([Document(text=evidence)])
+            evidence = texts[0].text
+            logging.info(f"len (trimmed): {len(evidence)}")
+        return evidence
+
     @BaseAgent.safeguard_run
     def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
         """
@@ -269,5 +297,69 @@ class RewooAgent(BaseAgent):
             total_tokens=total_token,
             total_cost=total_cost,
             citation=citation,
-            metadata={"citation": citation},
+            metadata={"citation": citation, "worker_log": worker_log},
+        )
+
+    def stream(self, instruction: str, use_citation: bool = False):
+        """
+        Stream the agent with a given instruction.
+        """
+        logging.info(f"Streaming {self.name} with instruction: {instruction}")
+        total_cost = 0.0
+        total_token = 0
+
+        # Plan
+        planner_output = self.planner(instruction)
+        planner_text_output = planner_output.text
+        plan_to_es, plans = self._parse_plan_map(planner_text_output)
+        planner_evidences, evidence_level = self._parse_planner_evidences(
+            planner_text_output
+        )
+        print("Planner output:", planner_text_output)
+
+        # Work
+        worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
+            planner_evidences, evidence_level
+        )
+        worker_log = ""
+        for plan in plan_to_es:
+            worker_log += f"{plan}: {plans[plan]}\n"
+            current_progress = f"{plan}: {plans[plan]}\n"
+            for e in plan_to_es[plan]:
+                worker_log += f"{e}: {worker_evidences[e]}\n"
+                current_progress += f"{e}: {worker_evidences[e]}\n"
+            yield AgentOutput(
+                text="",
+                agent_type=self.agent_type,
+                status="thinking",
+                intermediate_steps=[{"worker_log": current_progress}],
+            )
+
+        # Solve
+        solver_response = ""
+        for solver_output in self.solver.stream(instruction, worker_log):
+            solver_output_text = solver_output.text
+            solver_response += solver_output_text
+            yield AgentOutput(
+                text=solver_output_text,
+                agent_type=self.agent_type,
+                status="thinking",
+            )
+        if use_citation:
+            citation_pipeline = CitationPipeline(llm=self.solver_llm)
+            citation = citation_pipeline.invoke(
+                context=worker_log, question=instruction
+            )
+        else:
+            citation = None
+
+        return AgentOutput(
+            text="",
+            agent_type=self.agent_type,
+            status="finished",
+            total_tokens=total_token,
+            total_cost=total_cost,
+            citation=citation,
+            metadata={"citation": citation, "worker_log": worker_log},
         )
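
A hedged sketch of consuming this stream (my_agent stands in for a configured RewooAgent): plan/evidence progress arrives as "thinking" outputs carrying a worker_log entry in intermediate_steps, solver tokens arrive as "thinking" outputs carrying text, and the final AgentOutput is the generator's return value:

    answer_text = ""
    gen = my_agent.stream("Compare ReAct and ReWOO", use_citation=False)
    while True:
        try:
            chunk = next(gen)
            if chunk.intermediate_steps:
                for step in chunk.intermediate_steps:
                    print("[progress]", step["worker_log"], end="")
            if chunk.text:
                answer_text += chunk.text  # streamed solver tokens
        except StopIteration as stop:
            final = stop.value  # status "finished"; metadata carries the worker log
            break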

View File

@@ -81,3 +81,26 @@ class Planner(BaseComponent):
             raise ValueError("Planner failed to retrieve response from LLM") from e

         return response
+
+    def stream(self, instruction: str, output: BaseScratchPad = BaseScratchPad()):
+        response = None
+        output.info("Running Planner")
+        prompt = self._compose_prompt(instruction)
+        output.debug(f"Prompt: {prompt}")
+
+        response = ""
+        try:
+            for text in self.model.stream(prompt):
+                response += text
+                yield text
+            self.log_progress(".planner", response=response)
+            output.info("Planner run successful.")
+        except NotImplementedError:
+            print("Streaming is not supported, falling back to normal run")
+            response = self.model(prompt)
+            yield response
+        except ValueError as e:
+            output.error("Planner failed to retrieve response from LLM")
+            raise ValueError("Planner failed to retrieve response from LLM") from e
+
+        return response

View File

@@ -81,7 +81,7 @@ And so on...
 zero_shot_solver_prompt = PromptTemplate(
     template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful.
-Your task is to briefly summarize each step, then make a short final conclusion for your task.
+Your task is to briefly summarize each step, then make a short final conclusion for your task. Give answer in {lang}.

##My Plans and Evidences##
{plan_evidence}
@@ -99,7 +99,7 @@ So, <your conclusion>.
 few_shot_solver_prompt = PromptTemplate(
     template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful.
-Your task is to briefly summarize each step, then make a short final conclusion for your task.
+Your task is to briefly summarize each step, then make a short final conclusion for your task. Give answer in {lang}.

##My Plans and Evidences##
{plan_evidence}

View File

@@ -11,6 +11,7 @@ class Solver(BaseComponent):
     model: BaseLLM
     prompt_template: Optional[PromptTemplate] = None
     examples: Optional[Union[str, List[str]]] = None
+    output_lang: str = "English"

     def _compose_fewshot_prompt(self) -> str:
         if self.examples is None:
@@ -20,7 +21,7 @@ class Solver(BaseComponent):
         else:
             return "\n\n".join([e.strip("\n") for e in self.examples])

-    def _compose_prompt(self, instruction, plan_evidence) -> str:
+    def _compose_prompt(self, instruction, plan_evidence, output_lang) -> str:
         """
         Compose the prompt from template, plan&evidence, examples and instruction.
         """
@@ -28,20 +29,28 @@ class Solver(BaseComponent):
         if self.prompt_template is not None:
             if "fewshot" in self.prompt_template.placeholders:
                 return self.prompt_template.populate(
-                    plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
+                    plan_evidence=plan_evidence,
+                    fewshot=fewshot,
+                    task=instruction,
+                    lang=output_lang,
                 )
             else:
                 return self.prompt_template.populate(
-                    plan_evidence=plan_evidence, task=instruction
+                    plan_evidence=plan_evidence, task=instruction, lang=output_lang
                 )
         else:
             if self.examples is not None:
                 return few_shot_solver_prompt.populate(
-                    plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
+                    plan_evidence=plan_evidence,
+                    fewshot=fewshot,
+                    task=instruction,
+                    lang=output_lang,
                 )
             else:
                 return zero_shot_solver_prompt.populate(
-                    plan_evidence=plan_evidence, task=instruction
+                    plan_evidence=plan_evidence,
+                    task=instruction,
+                    lang=output_lang,
                 )

     def run(
@@ -54,7 +63,7 @@ class Solver(BaseComponent):
         output.info("Running Solver")
         output.debug(f"Instruction: {instruction}")
         output.debug(f"Plan Evidence: {plan_evidence}")
-        prompt = self._compose_prompt(instruction, plan_evidence)
+        prompt = self._compose_prompt(instruction, plan_evidence, self.output_lang)
         output.debug(f"Prompt: {prompt}")
         try:
             response = self.model(prompt)
@@ -63,3 +72,28 @@ class Solver(BaseComponent):
             output.error("Solver failed to retrieve response from LLM")

         return response
+
+    def stream(
+        self,
+        instruction: str,
+        plan_evidence: str,
+        output: BaseScratchPad = BaseScratchPad(),
+    ) -> Any:
+        response = ""
+        output.info("Running Solver")
+        output.debug(f"Instruction: {instruction}")
+        output.debug(f"Plan Evidence: {plan_evidence}")
+        prompt = self._compose_prompt(instruction, plan_evidence, self.output_lang)
+        output.debug(f"Prompt: {prompt}")
+        try:
+            for text in self.model.stream(prompt):
+                response += text.text
+                yield text
+            output.info("Solver run successful.")
+        except NotImplementedError:
+            response = self.model(prompt).text
+            output.info("Solver run successful.")
+        except ValueError:
+            output.error("Solver failed to retrieve response from LLM")
+
+        return response

View File

@@ -1,4 +1,5 @@
 from typing import AnyStr, Optional, Type
+from urllib.error import HTTPError

 from langchain.utilities import SerpAPIWrapper
 from pydantic import BaseModel, Field
@@ -26,12 +27,17 @@ class GoogleSearchTool(BaseTool):
                 "install googlesearch using `pip3 install googlesearch-python` to "
                 "use this tool"
             )

-        output = ""
-        search_results = search(query, advanced=True)
-        if search_results:
-            output = "\n".join(
-                "{} {}".format(item.title, item.description) for item in search_results
-            )
+        try:
+            output = ""
+            search_results = search(query, advanced=True)
+            if search_results:
+                output = "\n".join(
+                    "{} {}".format(item.title, item.description)
+                    for item in search_results
+                )
+        except HTTPError:
+            output = "No evidence found."

         return output

View File

@@ -2,9 +2,10 @@ from typing import AnyStr, Optional, Type

 from pydantic import BaseModel, Field

+from kotaemon.agents.tools.base import ToolException
 from kotaemon.llms import BaseLLM

-from .base import BaseTool, ToolException
+from .base import BaseTool


 class LLMArgs(BaseModel):

View File

@@ -0,0 +1,329 @@
import html
import logging
from typing import AnyStr, Optional, Type

from ktem.llms.manager import llms
from ktem.reasoning.base import BaseReasoning
from ktem.utils.generator import Generator
from ktem.utils.render import Render
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field

from kotaemon.agents import (
    BaseTool,
    GoogleSearchTool,
    LLMTool,
    ReactAgent,
    WikipediaTool,
)
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate

logger = logging.getLogger(__name__)


class DocSearchArgs(BaseModel):
    query: str = Field(..., description="a search query as input to the doc search")


class DocSearchTool(BaseTool):
    name: str = "docsearch"
    description: str = (
        "A storage that contains internal documents. If you lack any specific "
        "private information to answer the question, you can search in this "
        "document storage. Furthermore, if you are unsure which document the "
        "user refers to, the user has likely already selected the target "
        "document in this storage, so you just need to do a normal search. "
        "If possible, formulate the search query as specifically as possible."
    )
    args_schema: Optional[Type[BaseModel]] = DocSearchArgs
    retrievers: list[BaseComponent] = []

    def _run_tool(self, query: AnyStr) -> AnyStr:
        docs = []
        doc_ids = []
        for retriever in self.retrievers:
            for doc in retriever(text=query):
                if doc.doc_id not in doc_ids:
                    docs.append(doc)
                    doc_ids.append(doc.doc_id)

        return self.prepare_evidence(docs)

    def prepare_evidence(self, docs, trim_len: int = 4000):
        evidence = ""
        table_found = 0

        for _id, retrieved_item in enumerate(docs):
            retrieved_content = ""
            page = retrieved_item.metadata.get("page_label", None)
            source = filename = retrieved_item.metadata.get("file_name", "-")
            if page:
                source += f" (Page {page})"
            if retrieved_item.metadata.get("type", "") == "table":
                if table_found < 5:
                    retrieved_content = retrieved_item.metadata.get("table_origin", "")
                    if retrieved_content not in evidence:
                        table_found += 1
                        evidence += (
                            f"<br><b>Table from {source}</b>\n"
                            + retrieved_content
                            + "\n<br>"
                        )
            elif retrieved_item.metadata.get("type", "") == "chatbot":
                retrieved_content = retrieved_item.metadata["window"]
                evidence += (
                    f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                    + retrieved_content
                    + "\n<br>"
                )
            elif retrieved_item.metadata.get("type", "") == "image":
                retrieved_content = retrieved_item.metadata.get("image_origin", "")
                retrieved_caption = html.escape(retrieved_item.get_content())
                evidence += (
                    f"<br><b>Figure from {source}</b>\n" + retrieved_caption + "\n<br>"
                )
            else:
                if "window" in retrieved_item.metadata:
                    retrieved_content = retrieved_item.metadata["window"]
                else:
                    retrieved_content = retrieved_item.text
                retrieved_content = retrieved_content.replace("\n", " ")
                if retrieved_content not in evidence:
                    evidence += (
                        f"<br><b>Content from {source}: </b> "
                        + retrieved_content
                        + " \n<br>"
                    )

            print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
            print("Score", retrieved_item.metadata.get("relevance_score", None))

        # trim context by trim_len
        if evidence:
            text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=trim_len,
                chunk_overlap=0,
                separator=" ",
                model_name="gpt-3.5-turbo",
            )
            texts = text_splitter.split_text(evidence)
            evidence = texts[0]

        return Document(content=evidence)


TOOL_REGISTRY = {
    "Google": GoogleSearchTool(),
    "Wikipedia": WikipediaTool(),
    "LLM": LLMTool(),
    "SearchDoc": DocSearchTool(),
}

DEFAULT_QA_PROMPT = (
    "Answer the following questions as best you can. Give answer in {lang}. "
    "You have access to the following tools:\n"
    "{tool_description}\n"
    "Use the following format:\n\n"
    "Question: the input question you must answer\n"
    "Thought: you should always think about what to do\n\n"
    "Action: the action to take, should be one of [{tool_names}]\n\n"
    "Action Input: the input to the action, should be different from the action input "
    "of the same action in previous steps.\n\n"
    "Observation: the result of the action\n\n"
    "... (this Thought/Action/Action Input/Observation can repeat N times)\n"
    "#Thought: I now know the final answer\n"
    "Final Answer: the final answer to the original input question\n\n"
    "Begin! After each Action Input.\n\n"
    "Question: {instruction}\n"
    "Thought: {agent_scratchpad}\n"
)

DEFAULT_REWRITE_PROMPT = (
    "Given the following question, rephrase and expand it "
    "to help you do better answering. Maintain all information "
    "in the original question. Keep the question as concise as possible. "
    "Give answer in {lang}\n"
    "Original question: {question}\n"
    "Rephrased question: "
)


class RewriteQuestionPipeline(BaseComponent):
    """Rewrite user question

    Args:
        llm: the language model used to rewrite the question
        rewrite_template: the prompt template for the llm to paraphrase a text input
        lang: the language of the answer. Currently supports English and Japanese
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
    rewrite_template: str = DEFAULT_REWRITE_PROMPT
    lang: str = "English"

    def run(self, question: str) -> Document:  # type: ignore
        prompt_template = PromptTemplate(self.rewrite_template)
        prompt = prompt_template.populate(question=question, lang=self.lang)
        messages = [
            SystemMessage(content="You are a helpful assistant"),
            HumanMessage(content=prompt),
        ]
        return self.llm(messages)


class ReactAgentPipeline(BaseReasoning):
    """Question answering pipeline using the ReAct agent."""

    class Config:
        allow_extra = True

    retrievers: list[BaseComponent]
    agent: ReactAgent = ReactAgent.withx()
    rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
    use_rewrite: bool = False

    def prepare_citation(self, step_id, step, output, status) -> Document:
        header = "<b>Step {id}</b>: {log}".format(id=step_id, log=step.log)
        content = (
            "<b>Action</b>: <em>{tool}[{input}]</em>\n\n<b>Output</b>: {output}"
        ).format(
            tool=step.tool if status == "thinking" else "",
            input=step.tool_input.replace("\n", "") if status == "thinking" else "",
            output=output if status == "thinking" else "Finished",
        )

        return Document(
            channel="info",
            content=Render.collapsible(
                header=header,
                content=Render.table(content),
                open=True,
            ),
        )

    async def ainvoke(  # type: ignore
        self, message, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Document:
        if self.use_rewrite:
            rewrite = await self.rewrite_pipeline(question=message)
            message = rewrite.text

        answer = self.agent(message)
        self.report_output(Document(content=answer.text, channel="chat"))

        intermediate_steps = answer.intermediate_steps
        for _, step_output in intermediate_steps:
            self.report_output(Document(content=step_output, channel="info"))

        self.report_output(None)
        return answer

    def stream(self, message, conv_id: str, history: list, **kwargs):
        if self.use_rewrite:
            rewrite = self.rewrite_pipeline(question=message)
            message = rewrite.text
            yield Document(
                channel="info",
                content=f"Rewrote the message to: {rewrite.text}",
            )

        output_stream = Generator(self.agent.stream(message))
        idx = 0
        for item in output_stream:
            idx += 1
            if item.status == "thinking":
                step, step_output = item.intermediate_steps
                yield Document(
                    channel="info",
                    content=self.prepare_citation(idx, step, step_output, item.status),
                )
            else:
                yield Document(
                    channel="chat",
                    content=item.text,
                )
                step, step_output = item.intermediate_steps
                yield Document(
                    channel="info",
                    content=self.prepare_citation(idx, step, step_output, item.status),
                )

        return output_stream.value

    @classmethod
    def get_pipeline(
        cls, settings: dict, states: dict, retrievers: list | None = None
    ) -> BaseReasoning:
        _id = cls.get_info()["id"]
        prefix = f"reasoning.options.{_id}"

        llm_name = settings[f"{prefix}.llm"]
        llm = llms.get(llm_name, llms.get_default())

        pipeline = ReactAgentPipeline(retrievers=retrievers)
        pipeline.agent.llm = llm
        pipeline.agent.max_iterations = settings[f"{prefix}.max_iterations"]

        tools = []
        for tool_name in settings[f"{prefix}.tools"]:
            tool = TOOL_REGISTRY[tool_name]
            if tool_name == "SearchDoc":
                tool.retrievers = retrievers
            elif tool_name == "LLM":
                tool.llm = llm
            tools.append(tool)
        pipeline.agent.plugins = tools

        pipeline.agent.output_lang = {"en": "English", "ja": "Japanese"}.get(
            settings["reasoning.lang"], "English"
        )
        pipeline.use_rewrite = states.get("app", {}).get("regen", False)
        pipeline.agent.prompt_template = PromptTemplate(settings[f"{prefix}.qa_prompt"])

        return pipeline

    @classmethod
    def get_user_settings(cls) -> dict:
        llm = ""
        llm_choices = [("(default)", "")]
        try:
            llm_choices += [(_, _) for _ in llms.options().keys()]
        except Exception as e:
            logger.exception(f"Failed to get LLM options: {e}")

        tool_choices = ["Wikipedia", "Google", "LLM", "SearchDoc"]

        return {
            "llm": {
                "name": "Language model",
                "value": llm,
                "component": "dropdown",
                "choices": llm_choices,
                "info": (
                    "The language model to use for generating the answer. If None, "
                    "the application's default language model will be used."
                ),
            },
            "tools": {
                "name": "Tools for knowledge retrieval",
                "value": ["SearchDoc", "LLM"],
                "component": "checkboxgroup",
                "choices": tool_choices,
            },
            "max_iterations": {
                "name": "Maximum number of iterations the LLM can go through",
                "value": 5,
                "component": "number",
            },
            "qa_prompt": {
                "name": "QA Prompt",
                "value": DEFAULT_QA_PROMPT,
            },
        }

    @classmethod
    def get_info(cls) -> dict:
        return {
            "id": "ReAct",
            "name": "ReAct Agent",
            "description": "Implementing the ReAct paradigm",
        }
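
For reference, a hedged sketch of the flat settings dict that ReactAgentPipeline.get_pipeline() expects, with keys derived from get_info()["id"] and the fields declared in get_user_settings(); the concrete values are illustrative only:

    settings = {
        "reasoning.lang": "en",
        "reasoning.options.ReAct.llm": "",  # "" falls back to the default LLM
        "reasoning.options.ReAct.tools": ["SearchDoc", "LLM"],
        "reasoning.options.ReAct.max_iterations": 5,
        "reasoning.options.ReAct.qa_prompt": DEFAULT_QA_PROMPT,
    }
    pipeline = ReactAgentPipeline.get_pipeline(settings, states={}, retrievers=[])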

View File

@@ -0,0 +1,462 @@
import html
import logging
from difflib import SequenceMatcher
from typing import AnyStr, Generator, Optional, Type

from ktem.llms.manager import llms
from ktem.reasoning.base import BaseReasoning
from ktem.utils.generator import Generator as GeneratorWrapper
from ktem.utils.render import Render
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field

from kotaemon.agents import (
    BaseTool,
    GoogleSearchTool,
    LLMTool,
    RewooAgent,
    WikipediaTool,
)
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate

logger = logging.getLogger(__name__)

DEFAULT_PLANNER_PROMPT = (
    "You are an AI agent who makes step-by-step plans to solve a problem with the "
    "help of external tools. For each step, make one plan followed by one tool-call, "
    "which will be executed later to retrieve evidence for that step.\n"
    "You should store each evidence into a distinct variable #E1, #E2, #E3 ... that "
    "can be referred to in later tool-call inputs.\n\n"
    "##Available Tools##\n"
    "{tool_description}\n\n"
    "##Output Format (Replace '<...>')##\n"
    "#Plan1: <describe your plan here>\n"
    "#E1: <toolname>[<input here>] (eg. Search[What is Python])\n"
    "#Plan2: <describe next plan>\n"
    "#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]\n"
    "And so on...\n\n"
    "##Your Task##\n"
    "{task}\n\n"
    "##Now Begin##\n"
)

DEFAULT_SOLVER_PROMPT = (
    "You are an AI agent who solves a problem with my assistance. I will provide "
    "step-by-step plans(#Plan) and evidences(#E) that could be helpful.\n"
    "Your task is to briefly summarize each step, then make a short final conclusion "
    "for your task. Give answer in {lang}.\n\n"
    "##My Plans and Evidences##\n"
    "{plan_evidence}\n\n"
    "##Example Output##\n"
    "First, I <did something>, and I think <...>; Second, I <...>, "
    "and I think <...>; ....\n"
    "So, <your conclusion>.\n\n"
    "##Your Task##\n"
    "{task}\n\n"
    "##Now Begin##\n"
)


class DocSearchArgs(BaseModel):
    query: str = Field(..., description="a search query as input to the doc search")


class DocSearchTool(BaseTool):
    name: str = "docsearch"
    description: str = (
        "A storage that contains internal documents. If you lack any specific "
        "private information to answer the question, you can search in this "
        "document storage. Furthermore, if you are unsure which document the "
        "user refers to, the user has likely already selected the target "
        "document in this storage, so you just need to do a normal search. "
        "If possible, formulate the search query as specifically as possible."
    )
    args_schema: Optional[Type[BaseModel]] = DocSearchArgs
    retrievers: list[BaseComponent] = []

    def _run_tool(self, query: AnyStr) -> AnyStr:
        docs = []
        doc_ids = []
        for retriever in self.retrievers:
            for doc in retriever(text=query):
                if doc.doc_id not in doc_ids:
                    docs.append(doc)
                    doc_ids.append(doc.doc_id)

        return self.prepare_evidence(docs)

    def prepare_evidence(self, docs, trim_len: int = 3000):
        evidence = ""
        table_found = 0

        for _id, retrieved_item in enumerate(docs):
            retrieved_content = ""
            page = retrieved_item.metadata.get("page_label", None)
            source = filename = retrieved_item.metadata.get("file_name", "-")
            if page:
                source += f" (Page {page})"
            if retrieved_item.metadata.get("type", "") == "table":
                if table_found < 5:
                    retrieved_content = retrieved_item.metadata.get("table_origin", "")
                    if retrieved_content not in evidence:
                        table_found += 1
                        evidence += (
                            f"<br><b>Table from {source}</b>\n"
                            + retrieved_content
                            + "\n<br>"
                        )
            elif retrieved_item.metadata.get("type", "") == "chatbot":
                retrieved_content = retrieved_item.metadata["window"]
                evidence += (
                    f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                    + retrieved_content
                    + "\n<br>"
                )
            elif retrieved_item.metadata.get("type", "") == "image":
                retrieved_content = retrieved_item.metadata.get("image_origin", "")
                retrieved_caption = html.escape(retrieved_item.get_content())
                # PWS doesn't support VLM for images, so we just store the caption
                evidence += (
                    f"<br><b>Figure from {source}</b>\n" + retrieved_caption + "\n<br>"
                )
            else:
                if "window" in retrieved_item.metadata:
                    retrieved_content = retrieved_item.metadata["window"]
                else:
                    retrieved_content = retrieved_item.text
                retrieved_content = retrieved_content.replace("\n", " ")
                if retrieved_content not in evidence:
                    evidence += (
                        f"<br><b>Content from {source}: </b> "
                        + retrieved_content
                        + " \n<br>"
                    )

            print("Retrieved #{}: {}".format(_id, retrieved_content))
            print("Score", retrieved_item.metadata.get("relevance_score", None))

        # trim context by trim_len
        if evidence:
            text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=trim_len,
                chunk_overlap=0,
                separator=" ",
                model_name="gpt-3.5-turbo",
            )
            texts = text_splitter.split_text(evidence)
            evidence = texts[0]

        return Document(content=evidence)


TOOL_REGISTRY = {
    "Google": GoogleSearchTool(),
    "Wikipedia": WikipediaTool(),
    "LLM": LLMTool(),
    "SearchDoc": DocSearchTool(),
}

DEFAULT_REWRITE_PROMPT = (
    "Given the following question, rephrase and expand it "
    "to help you do better answering. Maintain all information "
    "in the original question. Keep the question as concise as possible. "
    "Give answer in {lang}\n"
    "Original question: {question}\n"
    "Rephrased question: "
)


class RewriteQuestionPipeline(BaseComponent):
    """Rewrite user question

    Args:
        llm: the language model used to rewrite the question
        rewrite_template: the prompt template for the llm to paraphrase a text input
        lang: the language of the answer. Currently supports English and Japanese
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
    rewrite_template: str = DEFAULT_REWRITE_PROMPT
    lang: str = "English"

    def run(self, question: str) -> Document:  # type: ignore
        prompt_template = PromptTemplate(self.rewrite_template)
        prompt = prompt_template.populate(question=question, lang=self.lang)
        messages = [
            SystemMessage(content="You are a helpful assistant"),
            HumanMessage(content=prompt),
        ]
        return self.llm(messages)


def find_text(llm_output, context):
    sentence_list = llm_output.split("\n")
    matches = []
    for sentence in sentence_list:
        match = SequenceMatcher(
            None, sentence, context, autojunk=False
        ).find_longest_match()
        matches.append((match.b, match.b + match.size))
    return matches


class RewooAgentPipeline(BaseReasoning):
    """Question answering pipeline using the ReWOO agent."""

    class Config:
        allow_extra = True

    retrievers: list[BaseComponent]
    agent: RewooAgent = RewooAgent.withx()
    rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
    use_rewrite: bool = False
    enable_citation: bool = False

    def format_info_panel(self, worker_log):
        header = ""
        content = []

        for line in worker_log.splitlines():
            if line.startswith("#Plan"):
                # a line starting with #Plan marks a new segment
                header = line
            elif line.startswith("#"):
                # escape the hash to stop markdown from rendering big headers
                line = "\\" + line
                content.append(line)
            else:
                content.append(line)

        if not header:
            return

        return Document(
            channel="info",
            content=Render.collapsible(
                header=header,
                content=Render.table("\n".join(content)),
                open=True,
            ),
        )

    def prepare_citation(self, answer) -> list[Document]:
        """Prepare citation to show on the UI"""
        segments = []
        split_indices = [
            0,
        ]
        start_indices = set()
        text = ""

        if "citation" in answer.metadata and answer.metadata["citation"] is not None:
            context = answer.metadata["worker_log"]
            for fact_with_evidence in answer.metadata["citation"].answer:
                for quote in fact_with_evidence.substring_quote:
                    matches = find_text(quote, context)
                    for match in matches:
                        split_indices.append(match[0])
                        split_indices.append(match[1])
                        start_indices.add(match[0])

            split_indices = sorted(list(set(split_indices)))
            spans = []
            prev = 0
            for index in split_indices:
                if index > prev:
                    spans.append(context[prev:index])
                    prev = index
            spans.append(context[split_indices[-1]:])

            prev = 0
            for span, start_idx in list(zip(spans, split_indices)):
                if start_idx in start_indices:
                    text += Render.highlight(span)
                else:
                    text += span
        else:
            text = answer.metadata["worker_log"]

        # separate text by detected header: #Plan
        for line in text.splitlines():
            if line.startswith("#Plan"):
                # a line starting with #Plan marks a new segment
                new_segment = [line]
                segments.append(new_segment)
            elif line.startswith("#"):
                # escape the hash to stop markdown from rendering big headers
                line = "\\" + line
                segments[-1].append(line)
            else:
                segments[-1].append(line)

        outputs = []
        for segment in segments:
            outputs.append(
                Document(
                    channel="info",
                    content=Render.collapsible(
                        header=segment[0],
                        content=Render.table("\n".join(segment[1:])),
                        open=True,
                    ),
                )
            )

        return outputs

    async def ainvoke(  # type: ignore
        self, message, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Document:
        answer = self.agent(message, use_citation=True)
        self.report_output(Document(content=answer.text, channel="chat"))

        refined_citations = self.prepare_citation(answer)
        for _ in refined_citations:
            self.report_output(_)

        self.report_output(None)
        return answer

    def stream(  # type: ignore
        self, message, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Generator[Document, None, Document] | None:
        if self.use_rewrite:
            rewrite = self.rewrite_pipeline(question=message)
            message = rewrite.text
            yield Document(
                channel="info",
                content=f"Rewrote the message to: {rewrite.text}",
            )

        output_stream = GeneratorWrapper(
            self.agent.stream(message, use_citation=self.enable_citation)
        )
        for item in output_stream:
            if item.intermediate_steps:
                for step in item.intermediate_steps:
                    yield Document(
                        channel="info",
                        content=self.format_info_panel(step["worker_log"]),
                    )
            if item.text:
                yield Document(channel="chat", content=item.text)

        answer = output_stream.value
        yield Document(channel="info", content=None)

        refined_citations = self.prepare_citation(answer)
        for _ in refined_citations:
            yield _

        return answer

    @classmethod
    def get_pipeline(
        cls, settings: dict, states: dict, retrievers: list | None = None
    ) -> BaseReasoning:
        _id = cls.get_info()["id"]
        prefix = f"reasoning.options.{_id}"
        pipeline = RewooAgentPipeline(retrievers=retrievers)

        planner_llm_name = settings[f"{prefix}.planner_llm"]
        planner_llm = llms.get(planner_llm_name, llms.get_default())
        solver_llm_name = settings[f"{prefix}.solver_llm"]
        solver_llm = llms.get(solver_llm_name, llms.get_default())

        pipeline.agent.planner_llm = planner_llm
        pipeline.agent.solver_llm = solver_llm

        tools = []
        for tool_name in settings[f"{prefix}.tools"]:
            tool = TOOL_REGISTRY[tool_name]
            if tool_name == "SearchDoc":
                tool.retrievers = retrievers
            elif tool_name == "LLM":
                tool.llm = solver_llm
            tools.append(tool)
        pipeline.agent.plugins = tools

        pipeline.agent.output_lang = {"en": "English", "ja": "Japanese"}.get(
            settings["reasoning.lang"], "English"
        )
        pipeline.agent.prompt_template["Planner"] = PromptTemplate(
            settings[f"{prefix}.planner_prompt"]
        )
        pipeline.agent.prompt_template["Solver"] = PromptTemplate(
            settings[f"{prefix}.solver_prompt"]
        )
        pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
        pipeline.use_rewrite = states.get("app", {}).get("regen", False)
        pipeline.rewrite_pipeline.llm = (
            planner_llm  # TODO: use a separate llm for rewriting if needed
        )

        return pipeline

    @classmethod
    def get_user_settings(cls) -> dict:
        llm = ""
        llm_choices = [("(default)", "")]
        try:
            llm_choices += [(_, _) for _ in llms.options().keys()]
        except Exception as e:
            logger.exception(f"Failed to get LLM options: {e}")

        tool_choices = ["Wikipedia", "Google", "LLM", "SearchDoc"]

        return {
            "planner_llm": {
                "name": "Language model for Planner",
                "value": llm,
                "component": "dropdown",
                "choices": llm_choices,
                "info": (
                    "The language model to use for planning. "
                    "This model will generate a plan based on the "
                    "instruction to find the answer."
                ),
            },
            "solver_llm": {
                "name": "Language model for Solver",
                "value": llm,
                "component": "dropdown",
                "choices": llm_choices,
                "info": (
                    "The language model to use for solving. "
                    "This model will generate the answer based on the "
                    "plan generated by the planner and evidences found by the tools."
                ),
            },
            "highlight_citation": {
                "name": "Highlight Citation",
                "value": False,
                "component": "checkbox",
            },
            "tools": {
                "name": "Tools for knowledge retrieval",
                "value": ["SearchDoc", "LLM"],
                "component": "checkboxgroup",
                "choices": tool_choices,
            },
            "planner_prompt": {
                "name": "Planner Prompt",
                "value": DEFAULT_PLANNER_PROMPT,
            },
            "solver_prompt": {
                "name": "Solver Prompt",
                "value": DEFAULT_SOLVER_PROMPT,
            },
        }

    @classmethod
    def get_info(cls) -> dict:
        return {
            "id": "ReWOO",
            "name": "ReWOO Agent",
            "description": (
                "Implementing the ReWOO paradigm: "
                "https://arxiv.org/pdf/2305.18323.pdf"
            ),
        }
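
The citation highlighting relies on find_text() above: each newline-separated piece of the LLM's quoted answer is located inside the worker log via difflib's longest-match search, and the resulting (start, end) spans are what prepare_citation() highlights. A small self-contained sketch (the strings are illustrative only):

    context = "#Plan1: find the release year\n#E1: Python was first released in 1991\n"
    for start, end in find_text("Python was first released in 1991", context):
        print(context[start:end])  # -> "Python was first released in 1991"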

View File

@@ -0,0 +1,9 @@
class Generator:
    """A generator that stores the return value from another generator"""

    def __init__(self, gen):
        self.gen = gen

    def __iter__(self):
        self.value = yield from self.gen
        return self.value
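
This wrapper exists because a plain for-loop discards a generator's return value (it only surfaces on StopIteration). Iterating through the wrapper keeps it on .value, which is how the ReAct and ReWOO pipelines above recover the final AgentOutput after streaming. A small usage sketch:

    def count_up():
        yield 1
        yield 2
        return "done"

    wrapped = Generator(count_up())
    for item in wrapped:
        print(item)        # 1, 2
    print(wrapped.value)   # "done"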