kotaemon/libs/ktem/ktem/reasoning/react.py
trducng ebf1315569
(bump:minor) Allow the indexing pipeline to report the indexing progress onto the UI (#81)
* Turn the file indexing event to generator to report progress

* Fix React text's trimming function

* Refactor delete file into a method
2024-05-25 22:09:41 +07:00

330 lines
12 KiB
Python

import html
import logging
from typing import AnyStr, Optional, Type
from ktem.llms.manager import llms
from ktem.reasoning.base import BaseReasoning
from ktem.utils.generator import Generator
from ktem.utils.render import Render
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field
from kotaemon.agents import (
BaseTool,
GoogleSearchTool,
LLMTool,
ReactAgent,
WikipediaTool,
)
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate
logger = logging.getLogger(__name__)
class DocSearchArgs(BaseModel):
    """Argument schema for the doc-search tool: a single free-text query."""

    query: str = Field(..., description="a search query as input to the doc search")
class DocSearchTool(BaseTool):
    """Tool that searches the internal document store via attached retrievers.

    Every retriever is queried with the same text; results are de-duplicated
    by ``doc_id`` and formatted into an HTML-ish evidence string which is
    trimmed to a token budget before being returned.
    """

    name: str = "docsearch"
    description: str = (
        "A storage that contains internal documents. If you lack any specific "
        "private information to answer the question, you can search in this "
        "document storage. Furthermore, if you are unsure about which document that "
        "the user refers to, likely the user already selects the target document in "
        "this document storage, you just need to do normal search. If possible, "
        "formulate the search query as specific as possible."
    )
    args_schema: Optional[Type[BaseModel]] = DocSearchArgs
    retrievers: list[BaseComponent] = []

    def _run_tool(self, query: AnyStr) -> AnyStr:
        """Run every retriever on `query` and return the formatted evidence."""
        docs = []
        seen_ids: set = set()  # set gives O(1) membership for de-duplication
        for retriever in self.retrievers:
            for doc in retriever(text=query):
                if doc.doc_id not in seen_ids:
                    docs.append(doc)
                    seen_ids.add(doc.doc_id)
        return self.prepare_evidence(docs)

    def prepare_evidence(self, docs, trim_len: int = 4000):
        """Format retrieved items into a single evidence string.

        Args:
            docs: retrieved items; each item's metadata may carry `type`
                ("table"/"chatbot"/"image" or plain text), `file_name`,
                `page_label`, `window`, `table_origin`, `image_origin`.
            trim_len: approximate token budget (gpt-3.5-turbo tokenizer)
                applied to the final evidence string.

        Returns:
            A `Document` whose content is the concatenated, trimmed evidence.
        """
        evidence = ""
        table_found = 0

        for _id, retrieved_item in enumerate(docs):
            retrieved_content = ""
            page = retrieved_item.metadata.get("page_label", None)
            source = retrieved_item.metadata.get("file_name", "-")
            if page:
                source += f" (Page {page})"

            if retrieved_item.metadata.get("type", "") == "table":
                # Cap the number of tables to keep the evidence compact
                if table_found < 5:
                    retrieved_content = retrieved_item.metadata.get("table_origin", "")
                    if retrieved_content not in evidence:
                        table_found += 1
                        evidence += (
                            f"<br><b>Table from {source}</b>\n"
                            + retrieved_content
                            + "\n<br>"
                        )
            elif retrieved_item.metadata.get("type", "") == "chatbot":
                retrieved_content = retrieved_item.metadata["window"]
                evidence += (
                    f"<br><b>Chatbot scenario from (unknown) (Row {page})</b>\n"
                    + retrieved_content
                    + "\n<br>"
                )
            elif retrieved_item.metadata.get("type", "") == "image":
                retrieved_content = retrieved_item.metadata.get("image_origin", "")
                # Only the escaped caption goes into the evidence text
                retrieved_caption = html.escape(retrieved_item.get_content())
                evidence += (
                    f"<br><b>Figure from {source}</b>\n" + retrieved_caption + "\n<br>"
                )
            else:
                if "window" in retrieved_item.metadata:
                    retrieved_content = retrieved_item.metadata["window"]
                else:
                    retrieved_content = retrieved_item.text
                retrieved_content = retrieved_content.replace("\n", " ")
                if retrieved_content not in evidence:
                    evidence += (
                        f"<br><b>Content from {source}: </b> "
                        + retrieved_content
                        + " \n<br>"
                    )

            # Debug output belongs on the module logger, not stdout
            logger.debug("Retrieved #%s: %s", _id, retrieved_content[:100])
            logger.debug(
                "Score %s", retrieved_item.metadata.get("relevance_score", None)
            )

        # Trim the evidence to roughly `trim_len` tokens, keeping the head
        if evidence:
            text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=trim_len,
                chunk_overlap=0,
                separator=" ",
                model_name="gpt-3.5-turbo",
            )
            texts = text_splitter.split_text(evidence)
            evidence = texts[0]

        return Document(content=evidence)
# Mapping from tool names (as shown in the settings UI) to tool instances.
# NOTE(review): these are module-level singleton instances shared by every
# pipeline built via `get_pipeline` — confirm the shared mutation is intended.
TOOL_REGISTRY = {
    "Google": GoogleSearchTool(),
    "Wikipedia": WikipediaTool(),
    "LLM": LLMTool(),
    "SearchDoc": DocSearchTool(),
}
# Default ReAct prompt: instructs the LLM to loop through
# Thought / Action / Action Input / Observation until a Final Answer.
# NOTE(review): the "#Thought:" line begins with a literal '#' inside the
# prompt text — confirm the leading '#' is intentional and matches what the
# agent's output parser expects.
DEFAULT_QA_PROMPT = (
    "Answer the following questions as best you can. Give answer in {lang}. "
    "You have access to the following tools:\n"
    "{tool_description}\n"
    "Use the following format:\n\n"
    "Question: the input question you must answer\n"
    "Thought: you should always think about what to do\n\n"
    "Action: the action to take, should be one of [{tool_names}]\n\n"
    "Action Input: the input to the action, should be different from the action input "
    "of the same action in previous steps.\n\n"
    "Observation: the result of the action\n\n"
    "... (this Thought/Action/Action Input/Observation can repeat N times)\n"
    "#Thought: I now know the final answer\n"
    "Final Answer: the final answer to the original input question\n\n"
    "Begin! After each Action Input.\n\n"
    "Question: {instruction}\n"
    "Thought: {agent_scratchpad}\n"
)
# Prompt used by RewriteQuestionPipeline to paraphrase/expand the user's
# question before it is handed to the agent.
DEFAULT_REWRITE_PROMPT = (
    "Given the following question, rephrase and expand it "
    "to help you do better answering. Maintain all information "
    "in the original question. Keep the question as concise as possible. "
    "Give answer in {lang}\n"
    "Original question: {question}\n"
    "Rephrased question: "
)
class RewriteQuestionPipeline(BaseComponent):
    """Paraphrase and expand the user's question via the LLM.

    Args:
        llm: the language model to rewrite question
        rewrite_template: the prompt template for llm to paraphrase a text input
        lang: the language of the answer. Currently support English and Japanese
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
    rewrite_template: str = DEFAULT_REWRITE_PROMPT
    lang: str = "English"

    def run(self, question: str) -> Document:  # type: ignore
        """Fill the rewrite template with `question` and return the LLM reply."""
        filled_prompt = PromptTemplate(self.rewrite_template).populate(
            question=question, lang=self.lang
        )
        return self.llm(
            [
                SystemMessage(content="You are a helpful assistant"),
                HumanMessage(content=filled_prompt),
            ]
        )
class ReactAgentPipeline(BaseReasoning):
    """Question answering pipeline using ReAct agent."""

    class Config:
        # NOTE(review): pydantic's config key is `extra` (e.g. extra="allow"),
        # not `allow_extra` — confirm this option actually takes effect.
        allow_extra = True

    # Retrievers backing the "SearchDoc" tool
    retrievers: list[BaseComponent]
    agent: ReactAgent = ReactAgent.withx()
    rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
    # When True, the question is rewritten by `rewrite_pipeline` first
    use_rewrite: bool = False

    def prepare_citation(self, step_id, step, output, status) -> Document:
        """Render one agent step as a collapsible info-channel Document.

        Args:
            step_id: index of the step, shown in the header.
            step: the agent action; read for `.log`, `.tool`, `.tool_input`.
            output: the tool's observation for this step.
            status: "thinking" while the agent is still acting; any other
                value renders the step as "Finished" with tool fields blanked.
        """
        header = "<b>Step {id}</b>: {log}".format(id=step_id, log=step.log)
        content = (
            "<b>Action</b>: <em>{tool}[{input}]</em>\n\n<b>Output</b>: {output}"
        ).format(
            tool=step.tool if status == "thinking" else "",
            input=step.tool_input.replace("\n", "") if status == "thinking" else "",
            output=output if status == "thinking" else "Finished",
        )
        return Document(
            channel="info",
            content=Render.collapsible(
                header=header,
                content=Render.table(content),
                open=True,
            ),
        )

    async def ainvoke(  # type: ignore
        self, message, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Document:
        """Answer `message` with the ReAct agent (non-streaming).

        Reports the final answer on the chat channel, each intermediate step
        output on the info channel, then a terminating None.
        """
        if self.use_rewrite:
            # NOTE(review): `RewriteQuestionPipeline.run` is synchronous —
            # confirm that awaiting the pipeline call is supported by
            # BaseComponent's __call__.
            rewrite = await self.rewrite_pipeline(question=message)
            message = rewrite.text
        answer = self.agent(message)
        self.report_output(Document(content=answer.text, channel="chat"))
        # Surface every intermediate (action, observation) pair to the UI
        intermediate_steps = answer.intermediate_steps
        for _, step_output in intermediate_steps:
            self.report_output(Document(content=step_output, channel="info"))
        # None marks the end of the output stream for the UI
        self.report_output(None)
        return answer

    def stream(self, message, conv_id: str, history: list, **kwargs):
        """Stream the agent's reasoning.

        Yields an info Document (a rendered citation) per thinking step and,
        for the final item, a chat Document with the answer text followed by
        its citation. Returns the underlying generator's final value.
        """
        if self.use_rewrite:
            rewrite = self.rewrite_pipeline(question=message)
            message = rewrite.text
            yield Document(
                channel="info",
                content=f"Rewrote the message to: {rewrite.text}",
            )
        # Generator wrapper exposes the stream's return value as `.value`
        output_stream = Generator(self.agent.stream(message))
        idx = 0
        for item in output_stream:
            idx += 1
            if item.status == "thinking":
                step, step_output = item.intermediate_steps
                yield Document(
                    channel="info",
                    content=self.prepare_citation(idx, step, step_output, item.status),
                )
            else:
                # Final item: emit the answer, then its citation entry
                yield Document(
                    channel="chat",
                    content=item.text,
                )
                step, step_output = item.intermediate_steps
                yield Document(
                    channel="info",
                    content=self.prepare_citation(idx, step, step_output, item.status),
                )
        return output_stream.value

    @classmethod
    def get_pipeline(
        cls, settings: dict, states: dict, retrievers: list | None = None
    ) -> BaseReasoning:
        """Build a configured pipeline from user settings and app state.

        Args:
            settings: flat settings dict keyed by "reasoning.options.<id>.*".
            states: application state; "app.regen" toggles question rewriting.
            retrievers: retrievers wired into the "SearchDoc" tool.
        """
        _id = cls.get_info()["id"]
        prefix = f"reasoning.options.{_id}"
        llm_name = settings[f"{prefix}.llm"]
        # Fall back to the application default LLM when the named one is absent
        llm = llms.get(llm_name, llms.get_default())
        pipeline = ReactAgentPipeline(retrievers=retrievers)
        pipeline.agent.llm = llm
        pipeline.agent.max_iterations = settings[f"{prefix}.max_iterations"]
        tools = []
        for tool_name in settings[f"reasoning.options.{_id}.tools"]:
            # NOTE(review): TOOL_REGISTRY holds module-level singletons, so
            # these assignments mutate shared instances — confirm intended.
            tool = TOOL_REGISTRY[tool_name]
            if tool_name == "SearchDoc":
                tool.retrievers = retrievers
            elif tool_name == "LLM":
                tool.llm = llm
            tools.append(tool)
        pipeline.agent.plugins = tools
        pipeline.agent.output_lang = {"en": "English", "ja": "Japanese"}.get(
            settings["reasoning.lang"], "English"
        )
        pipeline.use_rewrite = states.get("app", {}).get("regen", False)
        pipeline.agent.prompt_template = PromptTemplate(settings[f"{prefix}.qa_prompt"])
        return pipeline

    @classmethod
    def get_user_settings(cls) -> dict:
        """Describe the user-configurable settings for this reasoning mode."""
        llm = ""
        llm_choices = [("(default)", "")]
        try:
            llm_choices += [(_, _) for _ in llms.options().keys()]
        except Exception as e:
            logger.exception(f"Failed to get LLM options: {e}")
        tool_choices = ["Wikipedia", "Google", "LLM", "SearchDoc"]
        return {
            "llm": {
                "name": "Language model",
                "value": llm,
                "component": "dropdown",
                "choices": llm_choices,
                "info": (
                    "The language model to use for generating the answer. If None, "
                    "the application default language model will be used."
                ),
            },
            "tools": {
                "name": "Tools for knowledge retrieval",
                "value": ["SearchDoc", "LLM"],
                "component": "checkboxgroup",
                "choices": tool_choices,
            },
            "max_iterations": {
                "name": "Maximum number of iterations the LLM can go through",
                "value": 5,
                "component": "number",
            },
            "qa_prompt": {
                "name": "QA Prompt",
                "value": DEFAULT_QA_PROMPT,
            },
        }

    @classmethod
    def get_info(cls) -> dict:
        """Identify this reasoning mode to the application registry."""
        return {
            "id": "ReAct",
            "name": "ReAct Agent",
            "description": "Implementing ReAct paradigm",
        }