Improve behavior of simple reasoning (#157)
* Add base reasoning implementation
* Provide explicit async and streaming capability
* Allow refreshing the information panel

parent cb01d27d19
commit 2950e6ed02

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Iterator, Optional
+from typing import AsyncGenerator, Iterator, Optional
 
 from theflow import Function, Node, Param, lazy
 

@@ -43,6 +43,18 @@ class BaseComponent(Function):
         if self._queue is not None:
             self._queue.put_nowait(output)
 
+    def invoke(self, *args, **kwargs) -> Document | list[Document] | None:
+        ...
+
+    async def ainvoke(self, *args, **kwargs) -> Document | list[Document] | None:
+        ...
+
+    def stream(self, *args, **kwargs) -> Iterator[Document] | None:
+        ...
+
+    async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
+        ...
+
     @abstractmethod
     def run(
         self, *args, **kwargs

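Note: the four stubs above give every component an explicit synchronous, asynchronous, and streaming call surface instead of routing everything through run. A minimal sketch of how a subclass might be exercised follows; the Echo component and the driver lines are illustrative only and not part of this commit, and Document is assumed to be importable from kotaemon.base as elsewhere in the codebase:

import asyncio

from kotaemon.base import BaseComponent, Document  # import path assumed


class Echo(BaseComponent):
    """Illustrative component: wraps the input text in a Document."""

    def run(self, text: str) -> Document:
        return Document(text=text)

    def invoke(self, text: str) -> Document:
        return self.run(text)

    async def ainvoke(self, text: str) -> Document:
        # nothing to await here; a real component would do async I/O
        return self.run(text)


print(Echo().invoke("hello").text)                # explicit blocking call
print(asyncio.run(Echo().ainvoke("world")).text)  # explicit async call
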
@@ -65,6 +65,9 @@ class CitationPipeline(BaseComponent):
     llm: BaseLLM
 
     def run(self, context: str, question: str):
+        return self.invoke(context, question)
+
+    def prepare_llm(self, context: str, question: str):
         schema = QuestionAnswer.schema()
         function = {
             "name": schema["title"],

@@ -92,8 +95,37 @@ class CitationPipeline(BaseComponent):
                 )
             ),
         ]
+        return messages, llm_kwargs
+
+    def invoke(self, context: str, question: str):
+        messages, llm_kwargs = self.prepare_llm(context, question)
+
+        try:
+            print("CitationPipeline: invoking LLM")
+            llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
+            print("CitationPipeline: finish invoking LLM")
+        except Exception as e:
+            print(e)
+            return None
+
+        function_output = llm_output.messages[0].additional_kwargs["function_call"][
+            "arguments"
+        ]
+        output = QuestionAnswer.parse_raw(function_output)
+
+        return output
+
+    async def ainvoke(self, context: str, question: str):
+        messages, llm_kwargs = self.prepare_llm(context, question)
+
+        try:
+            print("CitationPipeline: async invoking LLM")
+            llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
+            print("CitationPipeline: finish async invoking LLM")
+        except Exception as e:
+            print(e)
+            return None
 
-        llm_output = self.llm(messages, **llm_kwargs)
         function_output = llm_output.messages[0].additional_kwargs["function_call"][
             "arguments"
         ]

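The prepare_llm / invoke / ainvoke split exists so that citation extraction can be scheduled concurrently with other async work. A rough sketch of the intended call pattern (assuming pipeline is an already-configured CitationPipeline); the same shape is used by AnswerWithContextPipeline further down:

import asyncio


async def answer_with_citation(pipeline, evidence: str, question: str):
    # kick off citation extraction without blocking the caller
    citation_task = asyncio.create_task(
        pipeline.ainvoke(context=evidence, question=question)
    )
    # ... stream the main answer to the user here ...
    citation = await citation_task  # QuestionAnswer instance, or None on error
    return citation
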
@@ -1,8 +1,22 @@
+from typing import AsyncGenerator, Iterator
+
 from langchain_core.language_models.base import BaseLanguageModel
 
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, LLMInterface
 
 
 class BaseLLM(BaseComponent):
     def to_langchain_format(self) -> BaseLanguageModel:
         raise NotImplementedError
+
+    def invoke(self, *args, **kwargs) -> LLMInterface:
+        raise NotImplementedError
+
+    async def ainvoke(self, *args, **kwargs) -> LLMInterface:
+        raise NotImplementedError
+
+    def stream(self, *args, **kwargs) -> Iterator[LLMInterface]:
+        raise NotImplementedError
+
+    async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
+        raise NotImplementedError

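BaseLLM now spells out the full contract: invoke/ainvoke return a single LLMInterface, while stream/astream yield them incrementally. A hypothetical in-memory stub (e.g. for tests) implementing that contract might look like the sketch below, assuming LLMInterface can be constructed with content=..., as in the streaming code elsewhere in this commit:

# assumes BaseLLM and LLMInterface from the module shown in this hunk
class FakeLLM(BaseLLM):
    """Illustrative stub that returns a canned reply; not part of this commit."""

    def run(self, *args, **kwargs):
        return self.invoke(*args, **kwargs)

    def invoke(self, *args, **kwargs):
        return LLMInterface(content="canned reply")

    async def ainvoke(self, *args, **kwargs):
        return self.invoke(*args, **kwargs)

    def stream(self, *args, **kwargs):
        yield LLMInterface(content="canned ")
        yield LLMInterface(content="reply")

    async def astream(self, *args, **kwargs):
        for chunk in self.stream(*args, **kwargs):
            yield chunk
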
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+from typing import AsyncGenerator, Iterator
 
 from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
 

@@ -10,6 +11,8 @@ logger = logging.getLogger(__name__)
 
 
 class LCChatMixin:
+    """Mixin for langchain based chat models"""
+
     def _get_lc_class(self):
         raise NotImplementedError(
             "Please return the relevant Langchain class in in _get_lc_class"

@@ -30,18 +33,7 @@ class LCChatMixin:
             return self.stream(messages, **kwargs)  # type: ignore
         return self.invoke(messages, **kwargs)
 
-    def invoke(
-        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
-    ) -> LLMInterface:
-        """Generate response from messages
-
-        Args:
-            messages: history of messages to generate response from
-            **kwargs: additional arguments to pass to the langchain chat model
-
-        Returns:
-            LLMInterface: generated response
-        """
+    def prepare_message(self, messages: str | BaseMessage | list[BaseMessage]):
         input_: list[BaseMessage] = []
 
         if isinstance(messages, str):

@@ -51,7 +43,9 @@ class LCChatMixin:
         else:
             input_ = messages
 
-        pred = self._obj.generate(messages=[input_], **kwargs)
+        return input_
+
+    def prepare_response(self, pred):
         all_text = [each.text for each in pred.generations[0]]
         all_messages = [each.message for each in pred.generations[0]]
 

@@ -76,10 +70,41 @@ class LCChatMixin:
             logits=[],
         )
 
-    def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
+    def invoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """Generate response from messages
+
+        Args:
+            messages: history of messages to generate response from
+            **kwargs: additional arguments to pass to the langchain chat model
+
+        Returns:
+            LLMInterface: generated response
+        """
+        input_ = self.prepare_message(messages)
+        pred = self._obj.generate(messages=[input_], **kwargs)
+        return self.prepare_response(pred)
+
+    async def ainvoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        input_ = self.prepare_message(messages)
+        pred = await self._obj.agenerate(messages=[input_], **kwargs)
+        return self.prepare_response(pred)
+
+    def stream(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> Iterator[LLMInterface]:
         for response in self._obj.stream(input=messages, **kwargs):
             yield LLMInterface(content=response.content)
 
+    async def astream(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> AsyncGenerator[LLMInterface, None]:
+        async for response in self._obj.astream(input=messages, **kwargs):
+            yield LLMInterface(content=response.content)
+
     def to_langchain_format(self):
         return self._obj
 

@@ -140,7 +165,7 @@ class LCChatMixin:
             raise ValueError(f"Invalid param {path}")
 
 
-class AzureChatOpenAI(LCChatMixin, ChatLLM):
+class AzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
     def __init__(
         self,
         azure_endpoint: str | None = None,

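With LCChatMixin exposing all four entry points, a configured chat model can be driven as a blocking call, an async call, or a stream. A usage sketch (the endpoint value is a placeholder and the remaining credential/deployment kwargs are omitted; see the __init__ signature in this file for the exact parameter names):

llm = AzureChatOpenAI(
    azure_endpoint="https://<your-resource>.openai.azure.com/",
    # remaining credential/deployment kwargs omitted
)

print(llm.invoke("Hello").text)    # blocking, returns a single LLMInterface

for chunk in llm.stream("Hello"):  # incremental chunks
    print(chunk.text, end="")

# async variants: await llm.ainvoke("Hello")  /  async for chunk in llm.astream("Hello")
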
@@ -209,6 +209,9 @@ class ChatPage(BasePage):
             if "output" in response:
                 text += response["output"]
             if "evidence" in response:
-                refs += response["evidence"]
+                if response["evidence"] is None:
+                    refs = ""
+                else:
+                    refs += response["evidence"]
 
             if len(refs) > len_ref:

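This is the UI half of the "refresh the information panel" change: when a reasoning pipeline reports {"evidence": None}, the accumulated references are cleared before new evidence is appended. On the pipeline side the handshake looks roughly like this (a sketch mirroring the report_output calls added further down; the file name is a placeholder):

# inside a reasoning pipeline's run step (sketch):
self.report_output({"evidence": None})  # ask the chat page to reset the panel
self.report_output(
    {"evidence": "<details open><summary>doc.pdf</summary>...</details><br>"}
)  # then repopulate it with the evidence backing the current answer
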
@@ -1,5 +1,49 @@
+from typing import Optional
+
 from kotaemon.base import BaseComponent
 
 
 class BaseReasoning(BaseComponent):
-    retrievers: list = []
+    """The reasoning pipeline that handles each of the user chat messages
+
+    This reasoning pipeline has access to:
+        - the retrievers
+        - the user settings
+        - the message
+        - the conversation id
+        - the message history
+    """
+
+    @classmethod
+    def get_info(cls) -> dict:
+        """Get the pipeline information for the app to organize and display
+
+        Returns:
+            a dictionary that contains the following keys:
+                - "id": the unique id of the pipeline
+                - "name": the human-friendly name of the pipeline
+                - "description": the overview short description of the pipeline, for
+                    user to grasp what does the pipeline do
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_user_settings(cls) -> dict:
+        """Get the default user settings for this pipeline"""
+        return {}
+
+    @classmethod
+    def get_pipeline(
+        cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None
+    ) -> "BaseReasoning":
+        """Get the reasoning pipeline for the app to execute
+
+        Args:
+            user_setting: user settings
+            retrievers (list): List of retrievers
+        """
+        return cls()
+
+    def run(self, message: str, conv_id: str, history: list, **kwargs):  # type: ignore
+        """Execute the reasoning pipeline"""
+        raise NotImplementedError

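The new BaseReasoning contract is what the app discovers and instantiates per conversation. A bare-bones hypothetical subclass, just to show which hooks a concrete pipeline has to fill in (the name and the echo behaviour are illustrative only):

class EchoReasoning(BaseReasoning):
    """Illustrative pipeline that simply echoes the user message."""

    @classmethod
    def get_info(cls) -> dict:
        return {
            "id": "echo",
            "name": "Echo",
            "description": "Returns the user's message unchanged.",
        }

    @classmethod
    def get_user_settings(cls) -> dict:
        return {}  # no configurable settings

    @classmethod
    def get_pipeline(cls, user_settings: dict, retrievers=None) -> "BaseReasoning":
        return cls()

    def run(self, message: str, conv_id: str, history: list, **kwargs):
        # stream the "answer" back to the UI through the component queue
        self.report_output({"output": message})
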
@@ -200,22 +200,24 @@ class AnswerWithContextPipeline(BaseComponent):
             lang=self.lang,
         )
 
+        citation_task = asyncio.create_task(
+            self.citation_pipeline.ainvoke(context=evidence, question=question)
+        )
+        print("Citation task created")
         messages = []
         if self.system_prompt:
             messages.append(SystemMessage(content=self.system_prompt))
         messages.append(HumanMessage(content=prompt))
         output = ""
-        for text in self.llm(messages):
+        for text in self.llm.stream(messages):
             output += text.text
             self.report_output({"output": text.text})
             await asyncio.sleep(0)
 
-        try:
-            citation = self.citation_pipeline(context=evidence, question=question)
-        except Exception as e:
-            print(e)
-            citation = None
+        # retrieve the citation
+        print("Waiting for citation task")
+        citation = await citation_task
 
         answer = Document(text=output, metadata={"citation": citation})
 
         return answer

@@ -242,6 +244,19 @@ class FullQAPipeline(BaseReasoning):
                 if doc.doc_id not in doc_ids:
                     docs.append(doc)
                     doc_ids.append(doc.doc_id)
+        for doc in docs:
+            self.report_output(
+                {
+                    "evidence": (
+                        "<details open>"
+                        f"<summary>{doc.metadata['file_name']}</summary>"
+                        f"{doc.text}"
+                        "</details><br>"
+                    )
+                }
+            )
+            await asyncio.sleep(0.1)
+
         evidence_mode, evidence = self.evidence_pipeline(docs).content
         answer = await self.answering_pipeline(
             question=message, evidence=evidence, evidence_mode=evidence_mode

@@ -266,6 +281,7 @@ class FullQAPipeline(BaseReasoning):
         id2docs = {doc.doc_id: doc for doc in docs}
         lack_evidence = True
         not_detected = set(id2docs.keys()) - set(spans.keys())
+        self.report_output({"evidence": None})
         for id, ss in spans.items():
             if not ss:
                 not_detected.add(id)

@@ -282,7 +298,7 @@ class FullQAPipeline(BaseReasoning):
             self.report_output(
                 {
                     "evidence": (
-                        "<details>"
+                        "<details open>"
                         f"<summary>{id2docs[id].metadata['file_name']}</summary>"
                         f"{text}"
                         "</details><br>"