Add relevant chat context when querying the index (#42)

* Add context to the query

* Include older messages from the chat

* Update the indexing

* Make some hard-coded values configurable

* Remove hard-coded values
Duc Nguyen (john) 2024-04-22 14:32:30 +07:00 committed by GitHub
parent 749c9e5641
commit fbe983ccb3
2 changed files with 112 additions and 17 deletions

View File

@@ -47,6 +47,8 @@ class DocumentIngestor(BaseComponent):
text_splitter: BaseSplitter = TokenSplitter.withx(
chunk_size=1024,
chunk_overlap=256,
separator="\n\n",
backup_separators=["\n", ".", " ", "\u200B"],
)
override_file_extractors: dict[str, Type[BaseReader]] = {}
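
The splitter settings above mean each chunk carries 1024 tokens with a 256-token overlap into the next, and splits prefer paragraph breaks ("\n\n") before falling back to the backup separators. A minimal sketch of the sliding-window idea in plain Python (illustrative only, not the actual TokenSplitter implementation, which also honors the separators):

    # Sketch: sliding-window chunking over a token list; consecutive
    # chunks share `chunk_overlap` tokens so context is not cut off
    # at chunk boundaries.
    def split_tokens(tokens: list[str], chunk_size: int = 1024,
                     chunk_overlap: int = 256) -> list[list[str]]:
        step = chunk_size - chunk_overlap  # each window starts 768 tokens later
        return [tokens[i : i + chunk_size]
                for i in range(0, max(len(tokens) - chunk_overlap, 1), step)]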

View File

@@ -11,6 +11,7 @@ from ktem.llms.manager import llms
from ktem.utils.render import Render
from kotaemon.base import (
AIMessage,
BaseComponent,
Document,
HumanMessage,
@@ -205,6 +206,7 @@ class AnswerWithContextPipeline(BaseComponent):
enable_citation: bool = False
system_prompt: str = ""
lang: str = "English" # support English and Japanese
n_last_interactions: int = 5
def get_prompt(self, question, evidence, evidence_mode: int):
"""Prepare the prompt and other information for LLM"""
@@ -244,6 +246,7 @@ class AnswerWithContextPipeline(BaseComponent):
def invoke(
self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Document:
history = kwargs.get("history", [])
prompt, images = self.get_prompt(question, evidence, evidence_mode)
output = ""
@@ -253,6 +256,9 @@ class AnswerWithContextPipeline(BaseComponent):
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
for human, ai in history[-self.n_last_interactions :]:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=prompt))
output = self.llm(messages).text
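
This history-replay block recurs in each of the answer methods below: the last n_last_interactions (human, ai) pairs are unrolled into alternating user/assistant messages ahead of the evidence-augmented prompt, so the LLM sees recent conversation without unbounded prompt growth. A self-contained sketch of the construction (plain dicts stand in for the kotaemon message types):

    # Sketch: fold a bounded slice of chat history into an LLM message list.
    def build_messages(system_prompt: str, history: list[tuple[str, str]],
                       prompt: str, n_last_interactions: int = 5) -> list[dict]:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # replay only the most recent interactions to bound prompt size
        for human, ai in history[-n_last_interactions:]:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": ai})
        # the evidence-augmented question always comes last
        messages.append({"role": "user", "content": prompt})
        return messages
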
@@ -292,6 +298,7 @@ class AnswerWithContextPipeline(BaseComponent):
(determined by retrieval pipeline)
evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
"""
history = kwargs.get("history", [])
prompt, images = self.get_prompt(question, evidence, evidence_mode)
citation_task = None
@@ -311,6 +318,9 @@ class AnswerWithContextPipeline(BaseComponent):
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
for human, ai in history[-self.n_last_interactions :]:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=prompt))
try:
@@ -339,6 +349,7 @@ class AnswerWithContextPipeline(BaseComponent):
def stream( # type: ignore
self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Generator[Document, None, Document]:
history = kwargs.get("history", [])
prompt, images = self.get_prompt(question, evidence, evidence_mode)
output = ""
@@ -350,6 +361,9 @@ class AnswerWithContextPipeline(BaseComponent):
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
for human, ai in history[-self.n_last_interactions :]:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=prompt))
try:
@@ -406,6 +420,50 @@ class RewriteQuestionPipeline(BaseComponent):
return self.llm(messages)
class AddQueryContextPipeline(BaseComponent):
n_last_interactions: int = 5
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
def run(self, question: str, history: list) -> Document:
messages = [
SystemMessage(
content="Below is a history of the conversation so far, and a new "
"question asked by the user that needs to be answered by searching "
"in a knowledge base.\nYou have access to a Search index "
"with 100's of documents.\nGenerate a search query based on the "
"conversation and the new question.\nDo not include cited source "
"filenames and document names e.g info.txt or doc.pdf in the search "
"query terms.\nDo not include any text inside [] or <<>> in the "
"search query terms.\nDo not include any special characters like "
"'+'.\nIf the question is not in English, rewrite the query in "
"the language used in the question.\n If the question contains enough "
"information, return just the number 1\n If it's unnecessary to do "
"the searching, return just the number 0."
),
HumanMessage(content="How did crypto do last year?"),
AIMessage(
content="Summarize Cryptocurrency Market Dynamics from last year"
),
HumanMessage(content="What are my health plans?"),
AIMessage(content="Show available health plans"),
]
for human, ai in history[-self.n_last_interactions :]:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=f"Generate search query for: {question}"))
resp = self.llm(messages).text
if resp == "0":
return Document(content="")
if resp == "1":
return Document(content=question)
return Document(content=resp)
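
The few-shot system prompt above makes the LLM emit either a rewritten search query or one of two sentinels: "0" when searching is unnecessary and "1" when the question is already self-contained. A hypothetical caller branches on the returned Document's content (add_query_context, history, and search are placeholders, not part of the diff):

    # Hypothetical usage; names are placeholders.
    question = "How much does the dental plan cost?"
    doc = add_query_context.run(question, history)
    if not doc.content:              # LLM replied "0": skip retrieval
        docs = []
    elif doc.content == question:    # LLM replied "1": use the question verbatim
        docs = search(question)
    else:                            # otherwise: a context-aware rewritten query
        docs = search(doc.content)
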
class FullQAPipeline(BaseReasoning):
"""Question answering pipeline. Handle from question to answer"""
@@ -417,13 +475,29 @@ class FullQAPipeline(BaseReasoning):
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()
trigger_context: int = 150
use_rewrite: bool = False
def retrieve(self, message: str) -> tuple[list[RetrievedDocument], list[Document]]:
def retrieve(
self, message: str, history: list
) -> tuple[list[RetrievedDocument], list[Document]]:
"""Retrieve the documents based on the message"""
if len(message) < self.trigger_context:
# prefer adding context for short user questions; avoid adding context for
# long questions, as they are likely to contain enough information.
# Also avoid the situation where the original message is already too long
# for the model to handle.
query = self.add_query_context(message, history).content
else:
query = message
print(f"Rewritten query: {query}")
if not query:
return [], []
docs, doc_ids = [], []
for retriever in self.retrievers:
for doc in retriever(text=message):
for doc in retriever(text=query):
if doc.doc_id not in doc_ids:
docs.append(doc)
doc_ids.append(doc.doc_id)
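
retrieve therefore gates on message length before touching the retrievers: questions shorter than trigger_context characters are assumed to lean on earlier turns and are rewritten with chat context, longer ones are used verbatim, and an empty rewritten query (the "0" sentinel) short-circuits retrieval. The gating, reduced to a standalone sketch (rewrite stands in for AddQueryContextPipeline):

    # Sketch of the length-gated query rewriting in retrieve().
    def choose_query(message: str, history: list, rewrite,
                     trigger_context: int = 150) -> str:
        if len(message) < trigger_context:
            # short questions likely depend on chat context: rewrite them
            return rewrite(message, history)
        # long questions are assumed self-contained, and rewriting them
        # risks overflowing the rewriter's own context window
        return message
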
@@ -522,7 +596,7 @@ class FullQAPipeline(BaseReasoning):
rewrite = await self.rewrite_pipeline(question=message)
message = rewrite.text
docs, infos = self.retrieve(message)
docs, infos = self.retrieve(message, history)
for _ in infos:
self.report_output(_)
await asyncio.sleep(0.1)
@@ -564,7 +638,8 @@ class FullQAPipeline(BaseReasoning):
if self.use_rewrite:
message = self.rewrite_pipeline(question=message).text
docs, infos = self.retrieve(message)
# should populate the context
docs, infos = self.retrieve(message, history)
for _ in infos:
yield _
@@ -604,24 +679,27 @@ class FullQAPipeline(BaseReasoning):
settings: the settings for the pipeline
retrievers: the retrievers to use
"""
_id = cls.get_info()["id"]
prefix = f"reasoning.options.{cls.get_info()['id']}"
pipeline = FullQAPipeline(retrievers=retrievers)
pipeline.answering_pipeline.llm = llms.get_default()
pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default()
pipeline.answering_pipeline.enable_citation = settings[
f"reasoning.options.{_id}.highlight_citation"
]
pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
# answering pipeline configuration
answer_pipeline = pipeline.answering_pipeline
answer_pipeline.llm = llms.get_default()
answer_pipeline.citation_pipeline.llm = llms.get_default()
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
answer_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English"
)
pipeline.answering_pipeline.system_prompt = settings[
f"reasoning.options.{_id}.system_prompt"
]
pipeline.answering_pipeline.qa_template = settings[
f"reasoning.options.{_id}.qa_prompt"
pipeline.add_query_context.llm = llms.get_default()
pipeline.add_query_context.n_last_interactions = settings[
f"{prefix}.n_last_interactions"
]
pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
pipeline.rewrite_pipeline.llm = llms.get_default()
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
@@ -645,6 +723,21 @@ class FullQAPipeline(BaseReasoning):
"name": "QA Prompt (contains {context}, {question}, {lang})",
"value": DEFAULT_QA_TEXT_PROMPT,
},
"n_last_interactions": {
"name": "Number of interactions to include",
"value": 5,
"component": "number",
"info": "The maximum number of chat interactions to include in the LLM",
},
"trigger_context": {
"name": "Maximum message length for context rewriting",
"value": 150,
"component": "number",
"info": (
"The maximum length of the message to trigger context addition. "
"Exceeding this length, the message will be used as is."
),
},
}
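
With the two new entries, the flattened settings for this reasoning option look roughly like this (illustrative; the real keys embed cls.get_info()["id"], written here as <id>):

    # Illustrative flattened settings; <id> is the reasoning option's id.
    {
        "reasoning.lang": "en",
        "reasoning.options.<id>.highlight_citation": True,
        "reasoning.options.<id>.system_prompt": "...",
        "reasoning.options.<id>.qa_prompt": "...",        # DEFAULT_QA_TEXT_PROMPT
        "reasoning.options.<id>.n_last_interactions": 5,  # shared by the answering
                                                          # and query-context pipelines
        "reasoning.options.<id>.trigger_context": 150,
    }
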
@classmethod