diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py
index 4e6f7b8..9acd39f 100644
--- a/libs/kotaemon/kotaemon/base/component.py
+++ b/libs/kotaemon/kotaemon/base/component.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
-from typing import Iterator, Optional
+from typing import AsyncGenerator, Iterator, Optional
from theflow import Function, Node, Param, lazy
@@ -43,6 +43,18 @@ class BaseComponent(Function):
if self._queue is not None:
self._queue.put_nowait(output)
+ def invoke(self, *args, **kwargs) -> Document | list[Document] | None:
+ ...
+
+ async def ainvoke(self, *args, **kwargs) -> Document | list[Document] | None:
+ ...
+
+ def stream(self, *args, **kwargs) -> Iterator[Document] | None:
+ ...
+
+ async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
+ ...
+
@abstractmethod
def run(
self, *args, **kwargs
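
Note: the four stubs above give every `BaseComponent` one uniform call surface (sync, async, streaming, async streaming). A minimal sketch of how a concrete component might satisfy it; the `EchoComponent` name and its trivial behaviour are made up for illustration and are not part of this diff:

    from kotaemon.base import BaseComponent, Document

    class EchoComponent(BaseComponent):
        def run(self, text: str) -> Document:
            return Document(text=text)

        def invoke(self, *args, **kwargs) -> Document:
            return self.run(*args, **kwargs)

        async def ainvoke(self, *args, **kwargs) -> Document:
            # a real component would await its backend here
            return self.run(*args, **kwargs)

        def stream(self, *args, **kwargs):
            # single-item stream; a real component would yield chunks
            yield self.run(*args, **kwargs)
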
diff --git a/libs/kotaemon/kotaemon/indices/qa/citation.py b/libs/kotaemon/kotaemon/indices/qa/citation.py
index 4c1281a..4fe8600 100644
--- a/libs/kotaemon/kotaemon/indices/qa/citation.py
+++ b/libs/kotaemon/kotaemon/indices/qa/citation.py
@@ -65,6 +65,9 @@ class CitationPipeline(BaseComponent):
llm: BaseLLM
def run(self, context: str, question: str):
+ return self.invoke(context, question)
+
+ def prepare_llm(self, context: str, question: str):
schema = QuestionAnswer.schema()
function = {
"name": schema["title"],
@@ -92,8 +95,37 @@ class CitationPipeline(BaseComponent):
)
),
]
+ return messages, llm_kwargs
+
+ def invoke(self, context: str, question: str):
+ messages, llm_kwargs = self.prepare_llm(context, question)
+
+ try:
+ print("CitationPipeline: invoking LLM")
+ llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
+ print("CitationPipeline: finish invoking LLM")
+ except Exception as e:
+ print(e)
+ return None
+
+ function_output = llm_output.messages[0].additional_kwargs["function_call"][
+ "arguments"
+ ]
+ output = QuestionAnswer.parse_raw(function_output)
+
+ return output
+
+ async def ainvoke(self, context: str, question: str):
+ messages, llm_kwargs = self.prepare_llm(context, question)
+
+ try:
+ print("CitationPipeline: async invoking LLM")
+ llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
+ print("CitationPipeline: finish async invoking LLM")
+ except Exception as e:
+ print(e)
+ return None
- llm_output = self.llm(messages, **llm_kwargs)
function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments"
]
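
With the prompt construction factored into `prepare_llm`, the sync and async paths share everything except the LLM call itself. A sketch of how a caller might use the new `ainvoke`; the `llm` object and the context/question strings are placeholders, not from this diff:

    import asyncio
    from kotaemon.indices.qa.citation import CitationPipeline

    async def cite(llm, context: str, question: str):
        pipeline = CitationPipeline(llm=llm)
        # returns a parsed QuestionAnswer, or None if the LLM call failed
        return await pipeline.ainvoke(context=context, question=question)

    # asyncio.run(cite(llm, "some evidence text", "what does it say?"))
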
diff --git a/libs/kotaemon/kotaemon/llms/base.py b/libs/kotaemon/kotaemon/llms/base.py
index ff315ea..6ef7afc 100644
--- a/libs/kotaemon/kotaemon/llms/base.py
+++ b/libs/kotaemon/kotaemon/llms/base.py
@@ -1,8 +1,22 @@
+from typing import AsyncGenerator, Iterator
+
from langchain_core.language_models.base import BaseLanguageModel
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, LLMInterface
class BaseLLM(BaseComponent):
def to_langchain_format(self) -> BaseLanguageModel:
raise NotImplementedError
+
+ def invoke(self, *args, **kwargs) -> LLMInterface:
+ raise NotImplementedError
+
+ async def ainvoke(self, *args, **kwargs) -> LLMInterface:
+ raise NotImplementedError
+
+ def stream(self, *args, **kwargs) -> Iterator[LLMInterface]:
+ raise NotImplementedError
+
+ async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
+ raise NotImplementedError
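
`BaseLLM` now spells out the contract that concrete LLM wrappers are expected to fill in. A minimal fake that satisfies it, e.g. for tests; the class and its canned reply are hypothetical, not part of this change (`stream`/`astream` would be overridden the same way when needed):

    from kotaemon.base import LLMInterface
    from kotaemon.llms.base import BaseLLM

    class FakeLLM(BaseLLM):
        def run(self, *args, **kwargs) -> LLMInterface:
            return self.invoke(*args, **kwargs)

        def invoke(self, *args, **kwargs) -> LLMInterface:
            return LLMInterface(content="canned reply")

        async def ainvoke(self, *args, **kwargs) -> LLMInterface:
            return self.invoke(*args, **kwargs)
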
diff --git a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py
index c5c2469..14064ba 100644
--- a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py
+++ b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
+from typing import AsyncGenerator, Iterator
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
@@ -10,6 +11,8 @@ logger = logging.getLogger(__name__)
class LCChatMixin:
+ """Mixin for langchain based chat models"""
+
def _get_lc_class(self):
raise NotImplementedError(
"Please return the relevant Langchain class in in _get_lc_class"
@@ -30,18 +33,7 @@ class LCChatMixin:
return self.stream(messages, **kwargs) # type: ignore
return self.invoke(messages, **kwargs)
- def invoke(
- self, messages: str | BaseMessage | list[BaseMessage], **kwargs
- ) -> LLMInterface:
- """Generate response from messages
-
- Args:
- messages: history of messages to generate response from
- **kwargs: additional arguments to pass to the langchain chat model
-
- Returns:
- LLMInterface: generated response
- """
+ def prepare_message(self, messages: str | BaseMessage | list[BaseMessage]):
input_: list[BaseMessage] = []
if isinstance(messages, str):
@@ -51,7 +43,9 @@ class LCChatMixin:
else:
input_ = messages
- pred = self._obj.generate(messages=[input_], **kwargs)
+ return input_
+
+ def prepare_response(self, pred):
all_text = [each.text for each in pred.generations[0]]
all_messages = [each.message for each in pred.generations[0]]
@@ -76,10 +70,41 @@ class LCChatMixin:
logits=[],
)
- def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
+ def invoke(
+ self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+ ) -> LLMInterface:
+ """Generate response from messages
+
+ Args:
+ messages: history of messages to generate response from
+ **kwargs: additional arguments to pass to the langchain chat model
+
+ Returns:
+ LLMInterface: generated response
+ """
+ input_ = self.prepare_message(messages)
+ pred = self._obj.generate(messages=[input_], **kwargs)
+ return self.prepare_response(pred)
+
+ async def ainvoke(
+ self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+ ) -> LLMInterface:
+ input_ = self.prepare_message(messages)
+ pred = await self._obj.agenerate(messages=[input_], **kwargs)
+ return self.prepare_response(pred)
+
+ def stream(
+ self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+ ) -> Iterator[LLMInterface]:
for response in self._obj.stream(input=messages, **kwargs):
yield LLMInterface(content=response.content)
+ async def astream(
+ self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+ ) -> AsyncGenerator[LLMInterface, None]:
+ async for response in self._obj.astream(input=messages, **kwargs):
+ yield LLMInterface(content=response.content)
+
def to_langchain_format(self):
return self._obj
@@ -140,7 +165,7 @@ class LCChatMixin:
raise ValueError(f"Invalid param {path}")
-class AzureChatOpenAI(LCChatMixin, ChatLLM):
+class AzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
def __init__(
self,
azure_endpoint: str | None = None,
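
Splitting the old `invoke` into `prepare_message` / `prepare_response` is what lets `ainvoke` and `astream` reuse the same message conversion around Langchain's `agenerate` / `astream`. Rough usage sketch; the endpoint value is a placeholder, and a real call also needs API key and deployment settings, which are omitted here:

    import asyncio
    from kotaemon.llms.chats.langchain_based import AzureChatOpenAI

    llm = AzureChatOpenAI(azure_endpoint="https://<resource>.openai.azure.com/")

    # synchronous, single response
    print(llm.invoke("Hello").text)

    # asynchronous streaming, e.g. from inside a reasoning pipeline
    async def ask():
        async for chunk in llm.astream("Hello"):
            print(chunk.content, end="")

    # asyncio.run(ask())
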
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index b06795d..6648c2f 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -209,7 +209,10 @@ class ChatPage(BasePage):
if "output" in response:
text += response["output"]
if "evidence" in response:
- refs += response["evidence"]
+ if response["evidence"] is None:
+ refs = ""
+ else:
+ refs += response["evidence"]
if len(refs) > len_ref:
print(f"Len refs: {len(refs)}")
diff --git a/libs/ktem/ktem/reasoning/base.py b/libs/ktem/ktem/reasoning/base.py
index c122dfa..80cf016 100644
--- a/libs/ktem/ktem/reasoning/base.py
+++ b/libs/ktem/ktem/reasoning/base.py
@@ -1,5 +1,49 @@
+from typing import Optional
+
from kotaemon.base import BaseComponent
class BaseReasoning(BaseComponent):
- retrievers: list = []
+ """The reasoning pipeline that handles each of the user chat messages
+
+ This reasoning pipeline has access to:
+ - the retrievers
+ - the user settings
+ - the message
+ - the conversation id
+ - the message history
+ """
+
+ @classmethod
+ def get_info(cls) -> dict:
+ """Get the pipeline information for the app to organize and display
+
+ Returns:
+ a dictionary that contains the following keys:
+ - "id": the unique id of the pipeline
+ - "name": the human-friendly name of the pipeline
+ - "description": the overview short description of the pipeline, for
+ user to grasp what does the pipeline do
+ """
+ raise NotImplementedError
+
+ @classmethod
+ def get_user_settings(cls) -> dict:
+ """Get the default user settings for this pipeline"""
+ return {}
+
+ @classmethod
+ def get_pipeline(
+ cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None
+ ) -> "BaseReasoning":
+ """Get the reasoning pipeline for the app to execute
+
+ Args:
+ user_settings: the user settings
+ retrievers (list): List of retrievers
+ """
+ return cls()
+
+ def run(self, message: str, conv_id: str, history: list, **kwargs): # type: ignore
+ """Execute the reasoning pipeline"""
+ raise NotImplementedError
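
Taken together, `get_info`, `get_user_settings`, `get_pipeline` and `run` define the plugin surface for reasoning pipelines. A bare-bones sketch of a custom pipeline against this interface; the `EchoReasoning` class is hypothetical:

    from typing import Optional

    from kotaemon.base import BaseComponent
    from ktem.reasoning.base import BaseReasoning

    class EchoReasoning(BaseReasoning):
        """Toy pipeline that streams the user message back unchanged."""

        @classmethod
        def get_info(cls) -> dict:
            return {"id": "echo", "name": "Echo", "description": "Echoes the message"}

        @classmethod
        def get_pipeline(
            cls, user_settings: dict, retrievers: Optional[list[BaseComponent]] = None
        ) -> "EchoReasoning":
            return cls()

        def run(self, message: str, conv_id: str, history: list, **kwargs):
            self.report_output({"output": message})
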
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index c8653a7..acd768f 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -200,22 +200,24 @@ class AnswerWithContextPipeline(BaseComponent):
lang=self.lang,
)
+ citation_task = asyncio.create_task(
+ self.citation_pipeline.ainvoke(context=evidence, question=question)
+ )
+ print("Citation task created")
+
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
messages.append(HumanMessage(content=prompt))
output = ""
- for text in self.llm(messages):
+ for text in self.llm.stream(messages):
output += text.text
self.report_output({"output": text.text})
await asyncio.sleep(0)
- try:
- citation = self.citation_pipeline(context=evidence, question=question)
- except Exception as e:
- print(e)
- citation = None
-
+ # retrieve the citation
+ print("Waiting for citation task")
+ citation = await citation_task
answer = Document(text=output, metadata={"citation": citation})
return answer
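
The citation call is now scheduled with `asyncio.create_task` before the answer starts streaming, so the citation LLM request runs concurrently with the answer LLM request; the `await asyncio.sleep(0)` inside the streaming loop is what hands control back to the event loop so that task can make progress. The same pattern in isolation (self-contained, nothing kotaemon-specific):

    import asyncio

    async def side_work():
        await asyncio.sleep(1)                  # stands in for the citation LLM call
        return "citation"

    async def main():
        task = asyncio.create_task(side_work()) # kicked off before streaming starts
        for chunk in ["streamed ", "answer"]:   # stands in for llm.stream(messages)
            print(chunk, end="", flush=True)
            await asyncio.sleep(0)              # yield so the task can run
        print("\n" + await task)                # gather the citation at the end

    asyncio.run(main())
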
@@ -242,6 +244,19 @@ class FullQAPipeline(BaseReasoning):
if doc.doc_id not in doc_ids:
docs.append(doc)
doc_ids.append(doc.doc_id)
+ for doc in docs:
+ self.report_output(
+ {
+ "evidence": (
+ ""
+ f"{doc.metadata['file_name']}
"
+ f"{doc.text}"
+ "
"
+ )
+ }
+ )
+ await asyncio.sleep(0.1)
+
evidence_mode, evidence = self.evidence_pipeline(docs).content
answer = await self.answering_pipeline(
question=message, evidence=evidence, evidence_mode=evidence_mode
@@ -266,6 +281,7 @@ class FullQAPipeline(BaseReasoning):
id2docs = {doc.doc_id: doc for doc in docs}
lack_evidence = True
not_detected = set(id2docs.keys()) - set(spans.keys())
+ self.report_output({"evidence": None})
for id, ss in spans.items():
if not ss:
not_detected.add(id)
@@ -282,7 +298,7 @@ class FullQAPipeline(BaseReasoning):
self.report_output(
{
"evidence": (
- ""
+ ""
f"{id2docs[id].metadata['file_name']}
"
f"{text}"
"
"