Add relevant chat context when querying the index (#42)
* Add context for the query
* Add older messages in the chat
* Update the indexing
* Make some hard-coded values configurable
* Remove hard-coded values
parent 749c9e5641
commit fbe983ccb3
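In short: when the user's question is brief, the pipeline now folds the last few chat turns into a standalone search query before hitting the retrievers. A minimal sketch of the intended call path, assuming the classes defined in the diff below are importable and a default LLM is configured (the conversation history here is invented for illustration):

# Sketch only: AddQueryContextPipeline is defined in this commit (see the
# diff below); the history contents are hypothetical.
history = [
    ("What does the ingestion pipeline do?",
     "It splits documents into chunks and indexes them."),
]

rewriter = AddQueryContextPipeline()  # llm falls back to llms.get_default()
result = rewriter.run("How large are those chunks?", history)

# The system prompt encodes a small sentinel protocol:
#   "0"  -> searching is unnecessary   -> result.content == ""
#   "1"  -> question is self-contained -> result.content == the question
#   else -> result.content is a history-aware rewritten search query
print(result.content)

FullQAPipeline.retrieve only applies this step when len(message) < trigger_context (default 150); longer messages are assumed to already carry enough context and are used as-is.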
@@ -47,6 +47,8 @@ class DocumentIngestor(BaseComponent):
     text_splitter: BaseSplitter = TokenSplitter.withx(
         chunk_size=1024,
         chunk_overlap=256,
+        separator="\n\n",
+        backup_separators=["\n", ".", " ", "\u200B"],
     )
     override_file_extractors: dict[str, Type[BaseReader]] = {}
 
@@ -11,6 +11,7 @@ from ktem.llms.manager import llms
 from ktem.utils.render import Render
 
 from kotaemon.base import (
+    AIMessage,
     BaseComponent,
     Document,
     HumanMessage,
@@ -205,6 +206,7 @@ class AnswerWithContextPipeline(BaseComponent):
     enable_citation: bool = False
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese
+    n_last_interactions: int = 5
 
     def get_prompt(self, question, evidence, evidence_mode: int):
         """Prepare the prompt and other information for LLM"""
@@ -244,6 +246,7 @@ class AnswerWithContextPipeline(BaseComponent):
     def invoke(
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Document:
+        history = kwargs.get("history", [])
         prompt, images = self.get_prompt(question, evidence, evidence_mode)
 
         output = ""
@@ -253,6 +256,9 @@ class AnswerWithContextPipeline(BaseComponent):
             messages = []
             if self.system_prompt:
                 messages.append(SystemMessage(content=self.system_prompt))
+            for human, ai in history[-self.n_last_interactions :]:
+                messages.append(HumanMessage(content=human))
+                messages.append(AIMessage(content=ai))
             messages.append(HumanMessage(content=prompt))
             output = self.llm(messages).text
 
@@ -292,6 +298,7 @@ class AnswerWithContextPipeline(BaseComponent):
                 (determined by retrieval pipeline)
             evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
         """
+        history = kwargs.get("history", [])
         prompt, images = self.get_prompt(question, evidence, evidence_mode)
 
         citation_task = None
@@ -311,6 +318,9 @@ class AnswerWithContextPipeline(BaseComponent):
             messages = []
             if self.system_prompt:
                 messages.append(SystemMessage(content=self.system_prompt))
+            for human, ai in history[-self.n_last_interactions :]:
+                messages.append(HumanMessage(content=human))
+                messages.append(AIMessage(content=ai))
             messages.append(HumanMessage(content=prompt))
 
             try:
@@ -339,6 +349,7 @@ class AnswerWithContextPipeline(BaseComponent):
     def stream(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Generator[Document, None, Document]:
+        history = kwargs.get("history", [])
         prompt, images = self.get_prompt(question, evidence, evidence_mode)
 
         output = ""
@@ -350,6 +361,9 @@ class AnswerWithContextPipeline(BaseComponent):
             messages = []
             if self.system_prompt:
                 messages.append(SystemMessage(content=self.system_prompt))
+            for human, ai in history[-self.n_last_interactions :]:
+                messages.append(HumanMessage(content=human))
+                messages.append(AIMessage(content=ai))
             messages.append(HumanMessage(content=prompt))
 
             try:
@@ -406,6 +420,50 @@ class RewriteQuestionPipeline(BaseComponent):
         return self.llm(messages)
 
 
+class AddQueryContextPipeline(BaseComponent):
+
+    n_last_interactions: int = 5
+    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
+
+    def run(self, question: str, history: list) -> Document:
+        messages = [
+            SystemMessage(
+                content="Below is a history of the conversation so far, and a new "
+                "question asked by the user that needs to be answered by searching "
+                "in a knowledge base.\nYou have access to a Search index "
+                "with 100's of documents.\nGenerate a search query based on the "
+                "conversation and the new question.\nDo not include cited source "
+                "filenames and document names e.g info.txt or doc.pdf in the search "
+                "query terms.\nDo not include any text inside [] or <<>> in the "
+                "search query terms.\nDo not include any special characters like "
+                "'+'.\nIf the question is not in English, rewrite the query in "
+                "the language used in the question.\n If the question contains enough "
+                "information, return just the number 1\n If it's unnecessary to do "
+                "the searching, return just the number 0."
+            ),
+            HumanMessage(content="How did crypto do last year?"),
+            AIMessage(
+                content="Summarize Cryptocurrency Market Dynamics from last year"
+            ),
+            HumanMessage(content="What are my health plans?"),
+            AIMessage(content="Show available health plans"),
+        ]
+        for human, ai in history[-self.n_last_interactions :]:
+            messages.append(HumanMessage(content=human))
+            messages.append(AIMessage(content=ai))
+
+        messages.append(HumanMessage(content=f"Generate search query for: {question}"))
+
+        resp = self.llm(messages).text
+        if resp == "0":
+            return Document(content="")
+
+        if resp == "1":
+            return Document(content=question)
+
+        return Document(content=resp)
+
+
 class FullQAPipeline(BaseReasoning):
     """Question answering pipeline. Handle from question to answer"""
 
@@ -417,13 +475,29 @@ class FullQAPipeline(BaseReasoning):
     evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
     answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
     rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
+    add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()
+    trigger_context: int = 150
     use_rewrite: bool = False
 
-    def retrieve(self, message: str) -> tuple[list[RetrievedDocument], list[Document]]:
+    def retrieve(
+        self, message: str, history: list
+    ) -> tuple[list[RetrievedDocument], list[Document]]:
         """Retrieve the documents based on the message"""
+        if len(message) < self.trigger_context:
+            # prefer adding context for short user questions, avoid adding context for
+            # long questions, as they are likely to contain enough information
+            # plus, avoid the situation where the original message is already too long
+            # for the model to handle
+            query = self.add_query_context(message, history).content
+        else:
+            query = message
+        print(f"Rewritten query: {query}")
+        if not query:
+            return [], []
+
         docs, doc_ids = [], []
         for retriever in self.retrievers:
-            for doc in retriever(text=message):
+            for doc in retriever(text=query):
                 if doc.doc_id not in doc_ids:
                     docs.append(doc)
                     doc_ids.append(doc.doc_id)
@@ -522,7 +596,7 @@ class FullQAPipeline(BaseReasoning):
             rewrite = await self.rewrite_pipeline(question=message)
             message = rewrite.text
 
-        docs, infos = self.retrieve(message)
+        docs, infos = self.retrieve(message, history)
         for _ in infos:
             self.report_output(_)
             await asyncio.sleep(0.1)
@@ -564,7 +638,8 @@ class FullQAPipeline(BaseReasoning):
         if self.use_rewrite:
             message = self.rewrite_pipeline(question=message).text
 
-        docs, infos = self.retrieve(message)
+        # should populate the context
+        docs, infos = self.retrieve(message, history)
         for _ in infos:
             yield _
 
@@ -604,24 +679,27 @@ class FullQAPipeline(BaseReasoning):
             settings: the settings for the pipeline
             retrievers: the retrievers to use
         """
-        _id = cls.get_info()["id"]
+        prefix = f"reasoning.options.{cls.get_info()['id']}"
 
         pipeline = FullQAPipeline(retrievers=retrievers)
-        pipeline.answering_pipeline.llm = llms.get_default()
-        pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default()
 
-        pipeline.answering_pipeline.enable_citation = settings[
-            f"reasoning.options.{_id}.highlight_citation"
-        ]
-        pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
+        # answering pipeline configuration
+        answer_pipeline = pipeline.answering_pipeline
+        answer_pipeline.llm = llms.get_default()
+        answer_pipeline.citation_pipeline.llm = llms.get_default()
+        answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
+        answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
+        answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
+        answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
+        answer_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
             settings["reasoning.lang"], "English"
         )
-        pipeline.answering_pipeline.system_prompt = settings[
-            f"reasoning.options.{_id}.system_prompt"
-        ]
-        pipeline.answering_pipeline.qa_template = settings[
-            f"reasoning.options.{_id}.qa_prompt"
-        ]
+
+        pipeline.add_query_context.llm = llms.get_default()
+        pipeline.add_query_context.n_last_interactions = settings[
+            f"{prefix}.n_last_interactions"
+        ]
+        pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
         pipeline.use_rewrite = states.get("app", {}).get("regen", False)
         pipeline.rewrite_pipeline.llm = llms.get_default()
         pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
@@ -645,6 +723,21 @@ class FullQAPipeline(BaseReasoning):
                 "name": "QA Prompt (contains {context}, {question}, {lang})",
                 "value": DEFAULT_QA_TEXT_PROMPT,
             },
+            "n_last_interactions": {
+                "name": "Number of interactions to include",
+                "value": 5,
+                "component": "number",
+                "info": "The maximum number of chat interactions to include in the LLM",
+            },
+            "trigger_context": {
+                "name": "Maximum message length for context rewriting",
+                "value": 150,
+                "component": "number",
+                "info": (
+                    "The maximum length of the message to trigger context addition. "
+                    "Exceeding this length, the message will be used as is."
+                ),
+            },
         }
 
     @classmethod
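For reference, both new knobs travel through the same flat settings dict that get_pipeline reads above. A hypothetical payload (the option id "simple" is an assumption; the reasoning.options.<id>.* key pattern comes from the code):

# Hypothetical settings payload: the two new keys are the additions of this
# commit; the remaining keys mirror options already read in get_pipeline.
settings = {
    "reasoning.lang": "en",
    "reasoning.options.simple.highlight_citation": False,
    "reasoning.options.simple.system_prompt": "",
    "reasoning.options.simple.qa_prompt": DEFAULT_QA_TEXT_PROMPT,
    "reasoning.options.simple.n_last_interactions": 5,  # chat turns to replay
    "reasoning.options.simple.trigger_context": 150,  # max length that triggers rewriting
}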