Improve behavior of simple reasoning (#157)

* Add base reasoning implementation

* Provide explicit async and streaming capability

* Allow refreshing the information panel
Duc Nguyen (john) 2024-03-12 13:03:38 +07:00 committed by GitHub
parent cb01d27d19
commit 2950e6ed02
7 changed files with 174 additions and 28 deletions

View File

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Iterator, Optional
+from typing import AsyncGenerator, Iterator, Optional

 from theflow import Function, Node, Param, lazy

@@ -43,6 +43,18 @@ class BaseComponent(Function):
         if self._queue is not None:
             self._queue.put_nowait(output)

+    def invoke(self, *args, **kwargs) -> Document | list[Document] | None:
+        ...
+
+    async def ainvoke(self, *args, **kwargs) -> Document | list[Document] | None:
+        ...
+
+    def stream(self, *args, **kwargs) -> Iterator[Document] | None:
+        ...
+
+    async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
+        ...
+
     @abstractmethod
     def run(
         self, *args, **kwargs
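
The four stubs added to BaseComponent give every pipeline component a uniform surface: invoke/ainvoke for blocking calls and stream/astream for incremental output. A minimal sketch of a subclass filling them in (EchoComponent is illustrative, not part of this commit):

from typing import AsyncGenerator, Iterator

from kotaemon.base import BaseComponent, Document


class EchoComponent(BaseComponent):
    """Illustrative component: echoes its input through all four entry points."""

    def run(self, text: str) -> Document:
        return self.invoke(text)

    def invoke(self, text: str) -> Document:
        return Document(text=text)

    async def ainvoke(self, text: str) -> Document:
        return Document(text=text)

    def stream(self, text: str) -> Iterator[Document]:
        for token in text.split():
            yield Document(text=token)

    async def astream(self, text: str) -> AsyncGenerator[Document, None]:
        for token in text.split():
            yield Document(text=token)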

View File

@@ -65,6 +65,9 @@ class CitationPipeline(BaseComponent):
     llm: BaseLLM

     def run(self, context: str, question: str):
+        return self.invoke(context, question)
+
+    def prepare_llm(self, context: str, question: str):
         schema = QuestionAnswer.schema()
         function = {
             "name": schema["title"],

@@ -92,8 +95,37 @@ class CitationPipeline(BaseComponent):
                 )
             ),
         ]
-        llm_output = self.llm(messages, **llm_kwargs)
+        return messages, llm_kwargs
+
+    def invoke(self, context: str, question: str):
+        messages, llm_kwargs = self.prepare_llm(context, question)
+        try:
+            print("CitationPipeline: invoking LLM")
+            llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
+            print("CitationPipeline: finish invoking LLM")
+        except Exception as e:
+            print(e)
+            return None
+
+        function_output = llm_output.messages[0].additional_kwargs["function_call"][
+            "arguments"
+        ]
+        output = QuestionAnswer.parse_raw(function_output)
+        return output
+
+    async def ainvoke(self, context: str, question: str):
+        messages, llm_kwargs = self.prepare_llm(context, question)
+        try:
+            print("CitationPipeline: async invoking LLM")
+            llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
+            print("CitationPipeline: finish async invoking LLM")
+        except Exception as e:
+            print(e)
+            return None
+
         function_output = llm_output.messages[0].additional_kwargs["function_call"][
             "arguments"
         ]
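
Splitting run into prepare_llm plus invoke/ainvoke means the citation step can now be scheduled on the event loop instead of blocking the answer stream. A rough usage sketch (the import path and the llm object are assumptions for illustration):

import asyncio

from kotaemon.indices.qa import CitationPipeline  # assumed import path


async def cite(llm, context: str, question: str):
    pipeline = CitationPipeline(llm=llm)
    # start citation extraction without blocking other work
    task = asyncio.create_task(pipeline.ainvoke(context=context, question=question))
    # ... stream the main answer here ...
    citation = await task  # ainvoke returns None if the LLM call failed
    return citation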

View File

@@ -1,8 +1,22 @@
+from typing import AsyncGenerator, Iterator
+
 from langchain_core.language_models.base import BaseLanguageModel

-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, LLMInterface


 class BaseLLM(BaseComponent):
     def to_langchain_format(self) -> BaseLanguageModel:
         raise NotImplementedError

+    def invoke(self, *args, **kwargs) -> LLMInterface:
+        raise NotImplementedError
+
+    async def ainvoke(self, *args, **kwargs) -> LLMInterface:
+        raise NotImplementedError
+
+    def stream(self, *args, **kwargs) -> Iterator[LLMInterface]:
+        raise NotImplementedError
+
+    async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
+        raise NotImplementedError
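
BaseLLM now names all four call styles, so downstream code can rely on them regardless of the backing model. A sketch of a minimal subclass (FixedReplyLLM and its import path are assumptions):

from kotaemon.base import LLMInterface
from kotaemon.llms import BaseLLM  # assumed export location


class FixedReplyLLM(BaseLLM):
    """Illustrative LLM that always answers with the same text."""

    def run(self, *args, **kwargs) -> LLMInterface:
        return self.invoke(*args, **kwargs)

    def invoke(self, *args, **kwargs) -> LLMInterface:
        return LLMInterface(content="hello")

    async def ainvoke(self, *args, **kwargs) -> LLMInterface:
        # a real model would await its client here; this reuses the sync path
        return self.invoke(*args, **kwargs)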

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import logging
+from typing import AsyncGenerator, Iterator

 from kotaemon.base import BaseMessage, HumanMessage, LLMInterface

@@ -10,6 +11,8 @@ logger = logging.getLogger(__name__)
 class LCChatMixin:
+    """Mixin for langchain based chat models"""
+
     def _get_lc_class(self):
         raise NotImplementedError(
             "Please return the relevant Langchain class in in _get_lc_class"

@@ -30,18 +33,7 @@ class LCChatMixin:
             return self.stream(messages, **kwargs)  # type: ignore
         return self.invoke(messages, **kwargs)

-    def invoke(
-        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
-    ) -> LLMInterface:
-        """Generate response from messages
-
-        Args:
-            messages: history of messages to generate response from
-            **kwargs: additional arguments to pass to the langchain chat model
-        Returns:
-            LLMInterface: generated response
-        """
+    def prepare_message(self, messages: str | BaseMessage | list[BaseMessage]):
         input_: list[BaseMessage] = []
         if isinstance(messages, str):

@@ -51,7 +43,9 @@ class LCChatMixin:
         else:
             input_ = messages

-        pred = self._obj.generate(messages=[input_], **kwargs)
+        return input_
+
+    def prepare_response(self, pred):
         all_text = [each.text for each in pred.generations[0]]
         all_messages = [each.message for each in pred.generations[0]]

@@ -76,10 +70,41 @@ class LCChatMixin:
             logits=[],
         )

-    def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
+    def invoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """Generate response from messages
+
+        Args:
+            messages: history of messages to generate response from
+            **kwargs: additional arguments to pass to the langchain chat model
+        Returns:
+            LLMInterface: generated response
+        """
+        input_ = self.prepare_message(messages)
+        pred = self._obj.generate(messages=[input_], **kwargs)
+        return self.prepare_response(pred)
+
+    async def ainvoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        input_ = self.prepare_message(messages)
+        pred = await self._obj.agenerate(messages=[input_], **kwargs)
+        return self.prepare_response(pred)
+
+    def stream(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> Iterator[LLMInterface]:
         for response in self._obj.stream(input=messages, **kwargs):
             yield LLMInterface(content=response.content)

+    async def astream(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> AsyncGenerator[LLMInterface, None]:
+        async for response in self._obj.astream(input=messages, **kwargs):
+            yield LLMInterface(content=response.content)
+
     def to_langchain_format(self):
         return self._obj

@@ -140,7 +165,7 @@ class LCChatMixin:
         raise ValueError(f"Invalid param {path}")


-class AzureChatOpenAI(LCChatMixin, ChatLLM):
+class AzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
     def __init__(
         self,
         azure_endpoint: str | None = None,
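
With prepare_message and prepare_response factored out, invoke and ainvoke share the same pre- and post-processing, and astream exposes async token-level streaming. A rough usage sketch of the async path (constructor arguments other than azure_endpoint, and all the values, are assumed placeholders):

import asyncio

from kotaemon.llms import AzureChatOpenAI  # assumed export location


async def main():
    llm = AzureChatOpenAI(
        azure_endpoint="https://<your-resource>.openai.azure.com/",  # placeholder
        openai_api_key="<key>",                                      # placeholder
        openai_api_version="2023-07-01-preview",                     # placeholder
        deployment_name="<deployment>",                              # placeholder
        temperature=0,
    )
    # astream yields LLMInterface chunks as the model produces them
    async for chunk in llm.astream("Tell me a short joke"):
        print(chunk.content, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())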

View File

@@ -209,7 +209,10 @@ class ChatPage(BasePage):
             if "output" in response:
                 text += response["output"]
             if "evidence" in response:
-                refs += response["evidence"]
+                if response["evidence"] is None:
+                    refs = ""
+                else:
+                    refs += response["evidence"]

             if len(refs) > len_ref:
                 print(f"Len refs: {len(refs)}")

View File

@@ -1,5 +1,49 @@
+from typing import Optional
+
 from kotaemon.base import BaseComponent


 class BaseReasoning(BaseComponent):
-    retrievers: list = []
+    """The reasoning pipeline that handles each of the user chat messages
+
+    This reasoning pipeline has access to:
+        - the retrievers
+        - the user settings
+        - the message
+        - the conversation id
+        - the message history
+    """
+
+    @classmethod
+    def get_info(cls) -> dict:
+        """Get the pipeline information for the app to organize and display
+
+        Returns:
+            a dictionary that contains the following keys:
+                - "id": the unique id of the pipeline
+                - "name": the human-friendly name of the pipeline
+                - "description": the overview short description of the pipeline, for
+                    user to grasp what does the pipeline do
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_user_settings(cls) -> dict:
+        """Get the default user settings for this pipeline"""
+        return {}
+
+    @classmethod
+    def get_pipeline(
+        cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None
+    ) -> "BaseReasoning":
+        """Get the reasoning pipeline for the app to execute
+
+        Args:
+            user_setting: user settings
+            retrievers (list): List of retrievers
+        """
+        return cls()
+
+    def run(self, message: str, conv_id: str, history: list, **kwargs):  # type: ignore
+        """Execute the reasoning pipeline"""
+        raise NotImplementedError
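
BaseReasoning defines the contract the app uses to discover, configure, and run a reasoning pipeline. A minimal sketch of a concrete subclass (EchoReasoning and the import path are assumptions for illustration):

from typing import Optional

from kotaemon.base import BaseComponent, Document

from ktem.reasoning.base import BaseReasoning  # assumed import path


class EchoReasoning(BaseReasoning):
    """Illustrative pipeline that just echoes the user message."""

    @classmethod
    def get_info(cls) -> dict:
        return {
            "id": "echo",
            "name": "Echo",
            "description": "Returns the user message unchanged.",
        }

    @classmethod
    def get_user_settings(cls) -> dict:
        return {}

    @classmethod
    def get_pipeline(
        cls, user_settings: dict, retrievers: Optional[list[BaseComponent]] = None
    ) -> "EchoReasoning":
        return cls()

    def run(self, message: str, conv_id: str, history: list, **kwargs) -> Document:
        return Document(text=message)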

View File

@@ -200,22 +200,24 @@ class AnswerWithContextPipeline(BaseComponent):
             lang=self.lang,
         )

+        citation_task = asyncio.create_task(
+            self.citation_pipeline.ainvoke(context=evidence, question=question)
+        )
+        print("Citation task created")
+
         messages = []
         if self.system_prompt:
             messages.append(SystemMessage(content=self.system_prompt))
         messages.append(HumanMessage(content=prompt))

         output = ""
-        for text in self.llm(messages):
+        for text in self.llm.stream(messages):
             output += text.text
             self.report_output({"output": text.text})
             await asyncio.sleep(0)

-        try:
-            citation = self.citation_pipeline(context=evidence, question=question)
-        except Exception as e:
-            print(e)
-            citation = None
+        # retrieve the citation
+        print("Waiting for citation task")
+        citation = await citation_task

         answer = Document(text=output, metadata={"citation": citation})
         return answer
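
The answering pipeline now overlaps citation extraction with answer streaming: the citation call is scheduled first, answer tokens are streamed, and the task is awaited only at the end. The same pattern in isolation, using stand-in coroutines rather than the real pipelines:

import asyncio


async def stream_answer() -> str:
    # stand-in for streaming LLM output token by token
    out = ""
    for chunk in ["The ", "answer ", "is ", "42."]:
        out += chunk
        await asyncio.sleep(0)  # yield control so the citation task can run
    return out


async def extract_citation() -> str:
    # stand-in for CitationPipeline.ainvoke
    await asyncio.sleep(0.2)
    return "source: doc-1"


async def answer_with_citation() -> tuple[str, str]:
    citation_task = asyncio.create_task(extract_citation())  # start early
    answer = await stream_answer()                           # stream while it runs
    citation = await citation_task                           # join at the end
    return answer, citation


print(asyncio.run(answer_with_citation()))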
@@ -242,6 +244,19 @@ class FullQAPipeline(BaseReasoning):
                 if doc.doc_id not in doc_ids:
                     docs.append(doc)
                     doc_ids.append(doc.doc_id)
+        for doc in docs:
+            self.report_output(
+                {
+                    "evidence": (
+                        "<details open>"
+                        f"<summary>{doc.metadata['file_name']}</summary>"
+                        f"{doc.text}"
+                        "</details><br>"
+                    )
+                }
+            )
+            await asyncio.sleep(0.1)
+
         evidence_mode, evidence = self.evidence_pipeline(docs).content
         answer = await self.answering_pipeline(
             question=message, evidence=evidence, evidence_mode=evidence_mode

@@ -266,6 +281,7 @@ class FullQAPipeline(BaseReasoning):
         id2docs = {doc.doc_id: doc for doc in docs}
         lack_evidence = True
         not_detected = set(id2docs.keys()) - set(spans.keys())
+        self.report_output({"evidence": None})
         for id, ss in spans.items():
             if not ss:
                 not_detected.add(id)

@@ -282,7 +298,7 @@ class FullQAPipeline(BaseReasoning):
             self.report_output(
                 {
                     "evidence": (
-                        "<details>"
+                        "<details open>"
                         f"<summary>{id2docs[id].metadata['file_name']}</summary>"
                         f"{text}"
                         "</details><br>"