diff --git a/knowledgehub/base/schema.py b/knowledgehub/base/schema.py
index 648f5d0..3da62b3 100644
--- a/knowledgehub/base/schema.py
+++ b/knowledgehub/base/schema.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
+from langchain.schema.messages import AIMessage
 from llama_index.bridge.pydantic import Field
 from llama_index.schema import Document as BaseDocument
 
@@ -82,6 +83,7 @@ class LLMInterface(Document):
     total_tokens: int = -1
     prompt_tokens: int = -1
     logits: list[list[float]] = Field(default_factory=list)
+    messages: list[AIMessage] = Field(default_factory=list)
 
 
 class ExtractorOutput(Document):
diff --git a/knowledgehub/llms/chats/base.py b/knowledgehub/llms/chats/base.py
index beed9f6..664546f 100644
--- a/knowledgehub/llms/chats/base.py
+++ b/knowledgehub/llms/chats/base.py
@@ -70,6 +70,7 @@ class LangchainChatLLM(ChatLLM):
         pred = self.agent.generate(messages=[input_], **kwargs)
         all_text = [each.text for each in pred.generations[0]]
+        all_messages = [each.message for each in pred.generations[0]]
 
         completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
         try:
@@ -88,6 +89,7 @@ class LangchainChatLLM(ChatLLM):
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
             prompt_tokens=prompt_tokens,
+            messages=all_messages,
             logits=[],
         )
diff --git a/knowledgehub/pipelines/agents/react/agent.py b/knowledgehub/pipelines/agents/react/agent.py
index d900f35..0063be3 100644
--- a/knowledgehub/pipelines/agents/react/agent.py
+++ b/knowledgehub/pipelines/agents/react/agent.py
@@ -4,9 +4,10 @@ from typing import Dict, List, Optional, Tuple, Type, Union
 
 from pydantic import BaseModel, create_model
 
+from kotaemon.base.schema import Document
 from kotaemon.llms import PromptTemplate
 
-from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
+from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
 from ..output.base import AgentAction, AgentFinish
 
 FINAL_ANSWER_ACTION = "Final Answer:"
@@ -183,6 +184,11 @@ class ReactAgent(BaseAgent):
             if is_finished_chain:
                 break
 
-        return AgentOutput(
-            output=response_text, cost=total_cost, token_usage=total_token
+        return Document(
+            text=response_text,
+            metadata={
+                "agent": "react",
+                "cost": total_cost,
+                "usage": total_token,
+            },
         )
diff --git a/knowledgehub/pipelines/agents/rewoo/agent.py b/knowledgehub/pipelines/agents/rewoo/agent.py
index 78d1119..1dcc5ae 100644
--- a/knowledgehub/pipelines/agents/rewoo/agent.py
+++ b/knowledgehub/pipelines/agents/rewoo/agent.py
@@ -5,7 +5,9 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 from pydantic import BaseModel, create_model
 
+from kotaemon.base.schema import Document
 from kotaemon.llms import LLM, ChatLLM, PromptTemplate
+from kotaemon.pipelines.citation import CitationPipeline
 
 from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
 from ..output.base import BaseScratchPad
@@ -28,7 +30,7 @@ class RewooAgent(BaseAgent):
     plugins: List[BaseTool] = list()
     examples: Dict[str, Union[str, List[str]]] = dict()
     args_schema: Optional[Type[BaseModel]] = create_model(
-        "ReactArgsSchema", instruction=(str, ...)
+        "RewooArgsSchema", instruction=(str, ...)
     )
 
     def _get_llms(self):
@@ -218,7 +220,7 @@ class RewooAgent(BaseAgent):
             if p.name == name:
                 return p
 
-    def _run_tool(self, instruction: str) -> AgentOutput:
+    def _run_tool(self, instruction: str, use_citation: bool = False) -> Document:
         """
         Run the agent with a given instruction.
         """
@@ -262,7 +264,18 @@ class RewooAgent(BaseAgent):
         # Solve
         solver_output = solver(instruction, worker_log)
         solver_output_text = solver_output.text
+        if use_citation:
+            citation_pipeline = CitationPipeline(llm=solver_llm)
+            citation = citation_pipeline(context=worker_log, question=instruction)
+        else:
+            citation = None
 
-        return AgentOutput(
-            output=solver_output_text, cost=total_cost, token_usage=total_token
+        return Document(
+            text=solver_output_text,
+            metadata={
+                "agent": "rewoo",
+                "cost": total_cost,
+                "usage": total_token,
+                "citation": citation,
+            },
         )
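Both agents now return a generic `Document` instead of `AgentOutput`, so callers read the answer from `.text` and the bookkeeping from `.metadata`. A minimal sketch of the consuming side, assuming `agent` is an already-configured `ReactAgent` or `RewooAgent` (construction elided):

```python
# Sketch of consuming the new Document-based return type.
# `agent` is assumed to be a configured ReactAgent or RewooAgent.
response = agent("Tell me about Cinnamon AI company")

answer = response.text              # the final answer string
cost = response.metadata["cost"]    # accumulated LLM cost for the run
usage = response.metadata["usage"]  # total tokens consumed
# "citation" is only set by the rewoo agent (and only when use_citation=True),
# so use .get() to stay agent-agnostic.
citation = response.metadata.get("citation")
```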
""" @@ -262,7 +264,18 @@ class RewooAgent(BaseAgent): # Solve solver_output = solver(instruction, worker_log) solver_output_text = solver_output.text + if use_citation: + citation_pipeline = CitationPipeline(llm=solver_llm) + citation = citation_pipeline(context=worker_log, question=instruction) + else: + citation = None - return AgentOutput( - output=solver_output_text, cost=total_cost, token_usage=total_token + return Document( + text=solver_output_text, + metadata={ + "agent": "react", + "cost": total_cost, + "usage": total_token, + "citation": citation, + }, ) diff --git a/knowledgehub/pipelines/citation.py b/knowledgehub/pipelines/citation.py new file mode 100644 index 0000000..2577360 --- /dev/null +++ b/knowledgehub/pipelines/citation.py @@ -0,0 +1,110 @@ +from typing import Iterator, List, Union + +from langchain.schema.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from kotaemon.base import BaseComponent + +from ..llms.chats.base import ChatLLM +from ..llms.completions.base import LLM + +BaseLLM = Union[ChatLLM, LLM] + + +class FactWithEvidence(BaseModel): + """Class representing a single statement. + + Each fact has a body and a list of sources. + If there are multiple facts make sure to break them apart + such that each one only uses a set of sources that are relevant to it. + """ + + fact: str = Field(..., description="Body of the sentence, as part of a response") + substring_quote: List[str] = Field( + ..., + description=( + "Each source should be a direct quote from the context, " + "as a substring of the original content" + ), + ) + + def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]: + import regex + + minor = quote + major = context + + errs_ = 0 + s = regex.search(f"({minor}){{e<={errs_}}}", major) + while s is None and errs_ <= errs: + errs_ += 1 + s = regex.search(f"({minor}){{e<={errs_}}}", major) + + if s is not None: + yield from s.spans() + + def get_spans(self, context: str) -> Iterator[str]: + for quote in self.substring_quote: + yield from self._get_span(quote, context) + + +class QuestionAnswer(BaseModel): + """A question and its answer as a list of facts each one should have a source. + each sentence contains a body and a list of sources.""" + + question: str = Field(..., description="Question that was asked") + answer: List[FactWithEvidence] = Field( + ..., + description=( + "Body of the answer, each fact should be " + "its separate object with a body and a list of sources" + ), + ) + + +class CitationPipeline(BaseComponent): + """Citation pipeline to extract cited evidences from source + (based on input question)""" + + llm: BaseLLM + + def run( + self, + context: str, + question: str, + ) -> QuestionAnswer: + schema = QuestionAnswer.schema() + function = { + "name": schema["title"], + "description": schema["description"], + "parameters": schema, + } + llm_kwargs = { + "functions": [function], + "function_call": {"name": function["name"]}, + } + messages = [ + SystemMessage( + content=( + "You are a world class algorithm to answer " + "questions with correct and exact citations." + ) + ), + HumanMessage(content="Answer question using the following context"), + HumanMessage(content=context), + HumanMessage(content=f"Question: {question}"), + HumanMessage( + content=( + "Tips: Make sure to cite your sources, " + "and use the exact words from the context." 
diff --git a/knowledgehub/pipelines/qa.py b/knowledgehub/pipelines/qa.py
index 93a7147..7bb7322 100644
--- a/knowledgehub/pipelines/qa.py
+++ b/knowledgehub/pipelines/qa.py
@@ -6,11 +6,12 @@ from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
 
 from kotaemon.base import BaseComponent
-from kotaemon.base.schema import RetrievedDocument
+from kotaemon.base.schema import Document, RetrievedDocument
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.pipelines.agents import BaseAgent
+from kotaemon.pipelines.citation import CitationPipeline
 from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.pipelines.tools import ComponentTool
@@ -40,10 +41,10 @@ class QuestionAnsweringPipeline(BaseComponent):
     )
 
     llm: AzureChatOpenAI = AzureChatOpenAI.withx(
-        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
-        openai_api_version="2023-03-15-preview",
-        deployment_name="dummy-q2-gpt35",
+        openai_api_version="2023-07-01-preview",
+        deployment_name="dummy-q2-16k",
         temperature=0,
         request_timeout=60,
     )
@@ -90,7 +91,7 @@ class QuestionAnsweringPipeline(BaseComponent):
         ]
         return "\n\n".join(matched_texts)
 
-    def run(self, question: str) -> str:
+    def run(self, question: str, use_citation: bool = False) -> Document:
         # retrieve relevant documents as context
         documents = self.retrieving_pipeline(question, top_k=int(self.retrieval_top_k))
         context = self._format_retrieved_context(documents)
@@ -102,7 +103,15 @@ class QuestionAnsweringPipeline(BaseComponent):
             question=question,
         )
         self.log_progress(".prompt", prompt=prompt)
-        answer = self.llm(prompt).text
+        answer_text = self.llm(prompt).text
+        if use_citation:
+            # run citation pipeline
+            citation_pipeline = CitationPipeline(llm=self.llm)
+            citation = citation_pipeline(context=context, question=question)
+        else:
+            citation = None
+
+        answer = Document(text=answer_text, metadata={"citation": citation})
 
         return answer
 
@@ -130,6 +139,6 @@ class AgentQAPipeline(QuestionAnsweringPipeline):
         if search_tool not in self.agent.plugins:
             self.agent.plugins.append(search_tool)
 
-    def run(self, question: str) -> str:
-        answer = self.agent(question).output
+    def run(self, question: str, use_citation: bool = False) -> Document:
+        answer = self.agent(question, use_citation=use_citation)
         return answer
diff --git a/knowledgehub/pipelines/tools/base.py b/knowledgehub/pipelines/tools/base.py
index 413b362..5a55912 100644
--- a/knowledgehub/pipelines/tools/base.py
+++ b/knowledgehub/pipelines/tools/base.py
@@ -103,7 +103,8 @@ class BaseTool(BaseComponent):
         # TODO (verbose_): Add logging
         try:
             tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
-            observation = self._run_tool(*tool_args, **tool_kwargs)
+            call_kwargs = {**kwargs, **tool_kwargs}
+            observation = self._run_tool(*tool_args, **call_kwargs)
         except ToolException as e:
             observation = self._handle_tool_error(e)
         return observation
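The `BaseTool` change above is what ties the pieces together: keyword arguments given at call time are now merged into the `_run_tool` call (with parsed tool kwargs taking precedence), so a flag like `use_citation` passed to `AgentQAPipeline.run` travels through the agent into `RewooAgent._run_tool` without each layer declaring it. From the caller's side, a sketch assuming the pipeline's retrieval and LLM nodes are already wired up (their configuration is elided here):

```python
# Sketch of the caller-facing flow; node wiring is assumed configured elsewhere.
pipeline = QuestionAnsweringPipeline()

answer = pipeline("What benefits does the policy cover?", use_citation=True)

print(answer.text)                      # the generated answer
citation = answer.metadata["citation"]  # QuestionAnswer, or None when disabled
if citation is not None:
    for fact in citation.answer:
        print(fact.fact, fact.substring_quote)
```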
diff --git a/tests/test_agent.py b/tests/test_agent.py
index a8e2eea..74f8fa3 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -163,7 +163,7 @@ def test_rewoo_agent(openai_completion, llm, mock_google_search):
     response = agent("Tell me about Cinnamon AI company")
     openai_completion.assert_called()
-    assert response.output == FINAL_RESPONSE_TEXT
+    assert response.text == FINAL_RESPONSE_TEXT
 
 
 @patch(
@@ -180,7 +180,7 @@ def test_react_agent(openai_completion, llm, mock_google_search):
     response = agent("Tell me about Cinnamon AI company")
     openai_completion.assert_called()
-    assert response.output == FINAL_RESPONSE_TEXT
+    assert response.text == FINAL_RESPONSE_TEXT
 
 
 @patch(
@@ -224,4 +224,4 @@ def test_react_agent_with_langchain_tools(openai_completion, llm):
     response = agent("Tell me about Cinnamon AI company")
     openai_completion.assert_called()
-    assert response.output == FINAL_RESPONSE_TEXT
+    assert response.text == FINAL_RESPONSE_TEXT
diff --git a/tests/test_citation.py b/tests/test_citation.py
new file mode 100644
index 0000000..3fba999
--- /dev/null
+++ b/tests/test_citation.py
@@ -0,0 +1,61 @@
+# flake8: noqa
+from unittest.mock import patch
+
+import pytest
+from openai.types.chat.chat_completion import ChatCompletion
+
+from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.pipelines.citation import CitationPipeline
+
+function_output = '{\n "question": "What is the provided _example_ benefits?",\n "answer": [\n {\n "fact": "特約死亡保険金: 被保険者がこの特約の保険期間中に死亡したときに支払います。",\n "substring_quote": ["特約死亡保険金"]\n },\n {\n "fact": "特約特定疾病保険金: 被保険者がこの特約の保険期間中に特定の疾病(悪性新生物(がん)、急性心筋梗塞または脳卒中)により所定の状態に該当したときに支払います。",\n "substring_quote": ["特約特定疾病保険金"]\n },\n {\n "fact": "特約障害保険金: 被保険者がこの特約の保険期間中に傷害もしくは疾病により所定の身体障害の状態に該当したとき、または不慮の事故により所定の身体障害の状態に該当したときに支払います。",\n "substring_quote": ["特約障害保険金"]\n },\n {\n "fact": "特約介護保険金: 被保険者がこの特約の保険期間中に傷害または疾病により所定の要介護状態に該当したときに支払います。",\n "substring_quote": ["特約介護保険金"]\n }\n ]\n}'
+
+_openai_chat_completion_response = [
+    ChatCompletion.parse_obj(
+        {
+            "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
+            "object": "chat.completion",
+            "created": 1692338378,
+            "model": "gpt-35-turbo",
+            "system_fingerprint": None,
+            "choices": [
+                {
+                    "index": 0,
+                    "finish_reason": "function_call",
+                    "message": {
+                        "role": "assistant",
+                        "content": None,
+                        "function_call": {
+                            "arguments": function_output,
+                            "name": "QuestionAnswer",
+                        },
+                        "tool_calls": None,
+                    },
+                }
+            ],
+            "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
+        }
+    )
+]
+
+
+@pytest.fixture
+def llm():
+    return AzureChatOpenAI(
+        azure_endpoint="https://dummy.openai.azure.com/",
+        openai_api_key="dummy",
+        openai_api_version="2023-03-15-preview",
+        temperature=0,
+    )
+
+
+@patch(
+    "openai.resources.chat.completions.Completions.create",
+    side_effect=_openai_chat_completion_response,
+)
+def test_citation(openai_completion, llm):
+    question = "test query"
+    context = "document context"
+
+    citation = CitationPipeline(llm=llm)
+    result = citation(context, question)
+    assert len(result.answer) == 4
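The fixture data above ultimately feeds `FactWithEvidence.get_spans`, which downstream UIs would use to highlight evidence in the source text. A small self-contained sketch; it requires the third-party `regex` package, whose fuzzy `{e<=N}` syntax tolerates up to N edits, so quotes survive minor OCR noise:

```python
from kotaemon.pipelines.citation import FactWithEvidence

context = "特約死亡保険金: 被保険者がこの特約の保険期間中に死亡したときに支払います。"

fact = FactWithEvidence(
    fact="特約死亡保険金 is paid when the insured dies during the rider term.",
    substring_quote=["特約死亡保険金"],
)

# get_spans yields (start, end) pairs into the original context.
for start, end in fact.get_spans(context):
    print(context[start:end])  # -> 特約死亡保険金
```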