Add Citation pipeline (#78)

* add rerankers in retrieving pipeline

* update example MVP pipeline

* add citation pipeline and function call interface

* change return type of QA and AgentPipeline to Document
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2023-11-16 11:24:35 +07:00 committed by GitHub
parent f8b8d86d4e
commit cc1e75b3c6
9 changed files with 223 additions and 19 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, TypeVar from typing import TYPE_CHECKING, Any, Optional, TypeVar
from langchain.schema.messages import AIMessage
from llama_index.bridge.pydantic import Field from llama_index.bridge.pydantic import Field
from llama_index.schema import Document as BaseDocument from llama_index.schema import Document as BaseDocument
@ -82,6 +83,7 @@ class LLMInterface(Document):
total_tokens: int = -1 total_tokens: int = -1
prompt_tokens: int = -1 prompt_tokens: int = -1
logits: list[list[float]] = Field(default_factory=list) logits: list[list[float]] = Field(default_factory=list)
messages: list[AIMessage] = Field(default_factory=list)
class ExtractorOutput(Document): class ExtractorOutput(Document):

View File

@ -70,6 +70,7 @@ class LangchainChatLLM(ChatLLM):
pred = self.agent.generate(messages=[input_], **kwargs) pred = self.agent.generate(messages=[input_], **kwargs)
all_text = [each.text for each in pred.generations[0]] all_text = [each.text for each in pred.generations[0]]
all_messages = [each.message for each in pred.generations[0]]
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0 completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
try: try:
@ -88,6 +89,7 @@ class LangchainChatLLM(ChatLLM):
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=total_tokens, total_tokens=total_tokens,
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
messages=all_messages,
logits=[], logits=[],
) )

View File

@ -4,9 +4,10 @@ from typing import Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
from kotaemon.base.schema import Document
from kotaemon.llms import PromptTemplate from kotaemon.llms import PromptTemplate
from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import AgentAction, AgentFinish from ..output.base import AgentAction, AgentFinish
FINAL_ANSWER_ACTION = "Final Answer:" FINAL_ANSWER_ACTION = "Final Answer:"
@ -183,6 +184,11 @@ class ReactAgent(BaseAgent):
if is_finished_chain: if is_finished_chain:
break break
return AgentOutput( return Document(
output=response_text, cost=total_cost, token_usage=total_token text=response_text,
metadata={
"agent": "react",
"cost": total_cost,
"usage": total_token,
},
) )

View File

@ -5,7 +5,9 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
from kotaemon.base.schema import Document
from kotaemon.llms import LLM, ChatLLM, PromptTemplate from kotaemon.llms import LLM, ChatLLM, PromptTemplate
from kotaemon.pipelines.citation import CitationPipeline
from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import BaseScratchPad from ..output.base import BaseScratchPad
@ -28,7 +30,7 @@ class RewooAgent(BaseAgent):
plugins: List[BaseTool] = list() plugins: List[BaseTool] = list()
examples: Dict[str, Union[str, List[str]]] = dict() examples: Dict[str, Union[str, List[str]]] = dict()
args_schema: Optional[Type[BaseModel]] = create_model( args_schema: Optional[Type[BaseModel]] = create_model(
"ReactArgsSchema", instruction=(str, ...) "RewooArgsSchema", instruction=(str, ...)
) )
def _get_llms(self): def _get_llms(self):
@ -218,7 +220,7 @@ class RewooAgent(BaseAgent):
if p.name == name: if p.name == name:
return p return p
def _run_tool(self, instruction: str) -> AgentOutput: def _run_tool(self, instruction: str, use_citation: bool = False) -> Document:
""" """
Run the agent with a given instruction. Run the agent with a given instruction.
""" """
@ -262,7 +264,18 @@ class RewooAgent(BaseAgent):
# Solve # Solve
solver_output = solver(instruction, worker_log) solver_output = solver(instruction, worker_log)
solver_output_text = solver_output.text solver_output_text = solver_output.text
if use_citation:
citation_pipeline = CitationPipeline(llm=solver_llm)
citation = citation_pipeline(context=worker_log, question=instruction)
else:
citation = None
return AgentOutput( return Document(
output=solver_output_text, cost=total_cost, token_usage=total_token text=solver_output_text,
metadata={
"agent": "react",
"cost": total_cost,
"usage": total_token,
"citation": citation,
},
) )

View File

@ -0,0 +1,110 @@
from typing import Iterator, List, Union
from langchain.schema.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from kotaemon.base import BaseComponent
from ..llms.chats.base import ChatLLM
from ..llms.completions.base import LLM
BaseLLM = Union[ChatLLM, LLM]
class FactWithEvidence(BaseModel):
"""Class representing a single statement.
Each fact has a body and a list of sources.
If there are multiple facts make sure to break them apart
such that each one only uses a set of sources that are relevant to it.
"""
fact: str = Field(..., description="Body of the sentence, as part of a response")
substring_quote: List[str] = Field(
...,
description=(
"Each source should be a direct quote from the context, "
"as a substring of the original content"
),
)
def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
import regex
minor = quote
major = context
errs_ = 0
s = regex.search(f"({minor}){{e<={errs_}}}", major)
while s is None and errs_ <= errs:
errs_ += 1
s = regex.search(f"({minor}){{e<={errs_}}}", major)
if s is not None:
yield from s.spans()
def get_spans(self, context: str) -> Iterator[str]:
for quote in self.substring_quote:
yield from self._get_span(quote, context)
class QuestionAnswer(BaseModel):
"""A question and its answer as a list of facts each one should have a source.
each sentence contains a body and a list of sources."""
question: str = Field(..., description="Question that was asked")
answer: List[FactWithEvidence] = Field(
...,
description=(
"Body of the answer, each fact should be "
"its separate object with a body and a list of sources"
),
)
class CitationPipeline(BaseComponent):
"""Citation pipeline to extract cited evidences from source
(based on input question)"""
llm: BaseLLM
def run(
self,
context: str,
question: str,
) -> QuestionAnswer:
schema = QuestionAnswer.schema()
function = {
"name": schema["title"],
"description": schema["description"],
"parameters": schema,
}
llm_kwargs = {
"functions": [function],
"function_call": {"name": function["name"]},
}
messages = [
SystemMessage(
content=(
"You are a world class algorithm to answer "
"questions with correct and exact citations."
)
),
HumanMessage(content="Answer question using the following context"),
HumanMessage(content=context),
HumanMessage(content=f"Question: {question}"),
HumanMessage(
content=(
"Tips: Make sure to cite your sources, "
"and use the exact words from the context."
)
),
]
llm_output = self.llm(messages, **llm_kwargs)
function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments"
]
output = QuestionAnswer.parse_raw(function_output)
return output

View File

@ -6,11 +6,12 @@ from theflow import Node
from theflow.utils.modules import ObjectInitDeclaration as _ from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
from kotaemon.base.schema import RetrievedDocument from kotaemon.base.schema import Document, RetrievedDocument
from kotaemon.embeddings import AzureOpenAIEmbeddings from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.llms import PromptTemplate from kotaemon.llms import PromptTemplate
from kotaemon.llms.chats.openai import AzureChatOpenAI from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.agents import BaseAgent from kotaemon.pipelines.agents import BaseAgent
from kotaemon.pipelines.citation import CitationPipeline
from kotaemon.pipelines.reranking import BaseRerankingPipeline from kotaemon.pipelines.reranking import BaseRerankingPipeline
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
from kotaemon.pipelines.tools import ComponentTool from kotaemon.pipelines.tools import ComponentTool
@ -40,10 +41,10 @@ class QuestionAnsweringPipeline(BaseComponent):
) )
llm: AzureChatOpenAI = AzureChatOpenAI.withx( llm: AzureChatOpenAI = AzureChatOpenAI.withx(
azure_endpoint="https://bleh-dummy-2.openai.azure.com/", azure_endpoint="https://bleh-dummy.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""), openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
openai_api_version="2023-03-15-preview", openai_api_version="2023-07-01-preview",
deployment_name="dummy-q2-gpt35", deployment_name="dummy-q2-16k",
temperature=0, temperature=0,
request_timeout=60, request_timeout=60,
) )
@ -90,7 +91,7 @@ class QuestionAnsweringPipeline(BaseComponent):
] ]
return "\n\n".join(matched_texts) return "\n\n".join(matched_texts)
def run(self, question: str) -> str: def run(self, question: str, use_citation: bool = False) -> Document:
# retrieve relevant documents as context # retrieve relevant documents as context
documents = self.retrieving_pipeline(question, top_k=int(self.retrieval_top_k)) documents = self.retrieving_pipeline(question, top_k=int(self.retrieval_top_k))
context = self._format_retrieved_context(documents) context = self._format_retrieved_context(documents)
@ -102,7 +103,15 @@ class QuestionAnsweringPipeline(BaseComponent):
question=question, question=question,
) )
self.log_progress(".prompt", prompt=prompt) self.log_progress(".prompt", prompt=prompt)
answer = self.llm(prompt).text answer_text = self.llm(prompt).text
if use_citation:
# run citation pipeline
citation_pipeline = CitationPipeline(llm=self.llm)
citation = citation_pipeline(context=context, question=question)
else:
citation = None
answer = Document(text=answer_text, metadata={"citation": citation})
return answer return answer
@ -130,6 +139,6 @@ class AgentQAPipeline(QuestionAnsweringPipeline):
if search_tool not in self.agent.plugins: if search_tool not in self.agent.plugins:
self.agent.plugins.append(search_tool) self.agent.plugins.append(search_tool)
def run(self, question: str) -> str: def run(self, question: str, use_citation: bool = False) -> Document:
answer = self.agent(question).output answer = self.agent(question, use_citation=use_citation)
return answer return answer

View File

@ -103,7 +103,8 @@ class BaseTool(BaseComponent):
# TODO (verbose_): Add logging # TODO (verbose_): Add logging
try: try:
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = self._run_tool(*tool_args, **tool_kwargs) call_kwargs = {**kwargs, **tool_kwargs}
observation = self._run_tool(*tool_args, **call_kwargs)
except ToolException as e: except ToolException as e:
observation = self._handle_tool_error(e) observation = self._handle_tool_error(e)
return observation return observation

View File

@ -163,7 +163,7 @@ def test_rewoo_agent(openai_completion, llm, mock_google_search):
response = agent("Tell me about Cinnamon AI company") response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called() openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT assert response.text == FINAL_RESPONSE_TEXT
@patch( @patch(
@ -180,7 +180,7 @@ def test_react_agent(openai_completion, llm, mock_google_search):
response = agent("Tell me about Cinnamon AI company") response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called() openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT assert response.text == FINAL_RESPONSE_TEXT
@patch( @patch(
@ -224,4 +224,4 @@ def test_react_agent_with_langchain_tools(openai_completion, llm):
response = agent("Tell me about Cinnamon AI company") response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called() openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT assert response.text == FINAL_RESPONSE_TEXT

61
tests/test_citation.py Normal file
View File

@ -0,0 +1,61 @@
# flake8: noqa
from unittest.mock import patch
import pytest
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.citation import CitationPipeline
function_output = '{\n "question": "What is the provided _example_ benefits?",\n "answer": [\n {\n "fact": "特約死亡保険金: 被保険者がこの特約の保険期間中に死亡したときに支払います。",\n "substring_quote": ["特約死亡保険金"]\n },\n {\n "fact": "特約特定疾病保険金: 被保険者がこの特約の保険期間中に特定の疾病(悪性新生物(がん)、急性心筋梗塞または脳卒中)により所定の状態に該当したときに支払います。",\n "substring_quote": ["特約特定疾病保険金"]\n },\n {\n "fact": "特約障害保険金: 被保険者がこの特約の保険期間中に傷害もしくは疾病により所定の身体障害の状態に該当したとき、または不慮の事故により所定の身体障害の状態に該当したときに支払います。",\n "substring_quote": ["特約障害保険金"]\n },\n {\n "fact": "特約介護保険金: 被保険者がこの特約の保険期間中に傷害または疾病により所定の要介護状態に該当したときに支払います。",\n "substring_quote": ["特約介護保険金"]\n }\n ]\n}'
_openai_chat_completion_response = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "function_call",
"message": {
"role": "assistant",
"content": None,
"function_call": {
"arguments": function_output,
"name": "QuestionAnswer",
},
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
]
@pytest.fixture
def llm():
return AzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
temperature=0,
)
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_response,
)
def test_citation(openai_completion, llm):
question = "test query"
context = "document context"
citation = CitationPipeline(llm=llm)
result = citation(context, question)
assert len(result.answer) == 4