Add Citation pipeline (#78)
* add rerankers in retrieving pipeline * update example MVP pipeline * add citation pipeline and function call interface * change return type of QA and AgentPipeline to Document
This commit is contained in:
parent
f8b8d86d4e
commit
cc1e75b3c6
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar
|
from typing import TYPE_CHECKING, Any, Optional, TypeVar
|
||||||
|
|
||||||
|
from langchain.schema.messages import AIMessage
|
||||||
from llama_index.bridge.pydantic import Field
|
from llama_index.bridge.pydantic import Field
|
||||||
from llama_index.schema import Document as BaseDocument
|
from llama_index.schema import Document as BaseDocument
|
||||||
|
|
||||||
|
@ -82,6 +83,7 @@ class LLMInterface(Document):
|
||||||
total_tokens: int = -1
|
total_tokens: int = -1
|
||||||
prompt_tokens: int = -1
|
prompt_tokens: int = -1
|
||||||
logits: list[list[float]] = Field(default_factory=list)
|
logits: list[list[float]] = Field(default_factory=list)
|
||||||
|
messages: list[AIMessage] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ExtractorOutput(Document):
|
class ExtractorOutput(Document):
|
||||||
|
|
|
@ -70,6 +70,7 @@ class LangchainChatLLM(ChatLLM):
|
||||||
|
|
||||||
pred = self.agent.generate(messages=[input_], **kwargs)
|
pred = self.agent.generate(messages=[input_], **kwargs)
|
||||||
all_text = [each.text for each in pred.generations[0]]
|
all_text = [each.text for each in pred.generations[0]]
|
||||||
|
all_messages = [each.message for each in pred.generations[0]]
|
||||||
|
|
||||||
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
|
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
|
||||||
try:
|
try:
|
||||||
|
@ -88,6 +89,7 @@ class LangchainChatLLM(ChatLLM):
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
|
messages=all_messages,
|
||||||
logits=[],
|
logits=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,10 @@ from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, create_model
|
from pydantic import BaseModel, create_model
|
||||||
|
|
||||||
|
from kotaemon.base.schema import Document
|
||||||
from kotaemon.llms import PromptTemplate
|
from kotaemon.llms import PromptTemplate
|
||||||
|
|
||||||
from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
|
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
|
||||||
from ..output.base import AgentAction, AgentFinish
|
from ..output.base import AgentAction, AgentFinish
|
||||||
|
|
||||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||||
|
@ -183,6 +184,11 @@ class ReactAgent(BaseAgent):
|
||||||
if is_finished_chain:
|
if is_finished_chain:
|
||||||
break
|
break
|
||||||
|
|
||||||
return AgentOutput(
|
return Document(
|
||||||
output=response_text, cost=total_cost, token_usage=total_token
|
text=response_text,
|
||||||
|
metadata={
|
||||||
|
"agent": "react",
|
||||||
|
"cost": total_cost,
|
||||||
|
"usage": total_token,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,7 +5,9 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, create_model
|
from pydantic import BaseModel, create_model
|
||||||
|
|
||||||
|
from kotaemon.base.schema import Document
|
||||||
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
|
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
|
||||||
|
from kotaemon.pipelines.citation import CitationPipeline
|
||||||
|
|
||||||
from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
|
from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
|
||||||
from ..output.base import BaseScratchPad
|
from ..output.base import BaseScratchPad
|
||||||
|
@ -28,7 +30,7 @@ class RewooAgent(BaseAgent):
|
||||||
plugins: List[BaseTool] = list()
|
plugins: List[BaseTool] = list()
|
||||||
examples: Dict[str, Union[str, List[str]]] = dict()
|
examples: Dict[str, Union[str, List[str]]] = dict()
|
||||||
args_schema: Optional[Type[BaseModel]] = create_model(
|
args_schema: Optional[Type[BaseModel]] = create_model(
|
||||||
"ReactArgsSchema", instruction=(str, ...)
|
"RewooArgsSchema", instruction=(str, ...)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_llms(self):
|
def _get_llms(self):
|
||||||
|
@ -218,7 +220,7 @@ class RewooAgent(BaseAgent):
|
||||||
if p.name == name:
|
if p.name == name:
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def _run_tool(self, instruction: str) -> AgentOutput:
|
def _run_tool(self, instruction: str, use_citation: bool = False) -> Document:
|
||||||
"""
|
"""
|
||||||
Run the agent with a given instruction.
|
Run the agent with a given instruction.
|
||||||
"""
|
"""
|
||||||
|
@ -262,7 +264,18 @@ class RewooAgent(BaseAgent):
|
||||||
# Solve
|
# Solve
|
||||||
solver_output = solver(instruction, worker_log)
|
solver_output = solver(instruction, worker_log)
|
||||||
solver_output_text = solver_output.text
|
solver_output_text = solver_output.text
|
||||||
|
if use_citation:
|
||||||
|
citation_pipeline = CitationPipeline(llm=solver_llm)
|
||||||
|
citation = citation_pipeline(context=worker_log, question=instruction)
|
||||||
|
else:
|
||||||
|
citation = None
|
||||||
|
|
||||||
return AgentOutput(
|
return Document(
|
||||||
output=solver_output_text, cost=total_cost, token_usage=total_token
|
text=solver_output_text,
|
||||||
|
metadata={
|
||||||
|
"agent": "react",
|
||||||
|
"cost": total_cost,
|
||||||
|
"usage": total_token,
|
||||||
|
"citation": citation,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
110
knowledgehub/pipelines/citation.py
Normal file
110
knowledgehub/pipelines/citation.py
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
from typing import Iterator, List, Union
|
||||||
|
|
||||||
|
from langchain.schema.messages import HumanMessage, SystemMessage
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
|
||||||
|
from ..llms.chats.base import ChatLLM
|
||||||
|
from ..llms.completions.base import LLM
|
||||||
|
|
||||||
|
BaseLLM = Union[ChatLLM, LLM]
|
||||||
|
|
||||||
|
|
||||||
|
class FactWithEvidence(BaseModel):
|
||||||
|
"""Class representing a single statement.
|
||||||
|
|
||||||
|
Each fact has a body and a list of sources.
|
||||||
|
If there are multiple facts make sure to break them apart
|
||||||
|
such that each one only uses a set of sources that are relevant to it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fact: str = Field(..., description="Body of the sentence, as part of a response")
|
||||||
|
substring_quote: List[str] = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Each source should be a direct quote from the context, "
|
||||||
|
"as a substring of the original content"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
|
||||||
|
import regex
|
||||||
|
|
||||||
|
minor = quote
|
||||||
|
major = context
|
||||||
|
|
||||||
|
errs_ = 0
|
||||||
|
s = regex.search(f"({minor}){{e<={errs_}}}", major)
|
||||||
|
while s is None and errs_ <= errs:
|
||||||
|
errs_ += 1
|
||||||
|
s = regex.search(f"({minor}){{e<={errs_}}}", major)
|
||||||
|
|
||||||
|
if s is not None:
|
||||||
|
yield from s.spans()
|
||||||
|
|
||||||
|
def get_spans(self, context: str) -> Iterator[str]:
|
||||||
|
for quote in self.substring_quote:
|
||||||
|
yield from self._get_span(quote, context)
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionAnswer(BaseModel):
|
||||||
|
"""A question and its answer as a list of facts each one should have a source.
|
||||||
|
each sentence contains a body and a list of sources."""
|
||||||
|
|
||||||
|
question: str = Field(..., description="Question that was asked")
|
||||||
|
answer: List[FactWithEvidence] = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Body of the answer, each fact should be "
|
||||||
|
"its separate object with a body and a list of sources"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CitationPipeline(BaseComponent):
|
||||||
|
"""Citation pipeline to extract cited evidences from source
|
||||||
|
(based on input question)"""
|
||||||
|
|
||||||
|
llm: BaseLLM
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
context: str,
|
||||||
|
question: str,
|
||||||
|
) -> QuestionAnswer:
|
||||||
|
schema = QuestionAnswer.schema()
|
||||||
|
function = {
|
||||||
|
"name": schema["title"],
|
||||||
|
"description": schema["description"],
|
||||||
|
"parameters": schema,
|
||||||
|
}
|
||||||
|
llm_kwargs = {
|
||||||
|
"functions": [function],
|
||||||
|
"function_call": {"name": function["name"]},
|
||||||
|
}
|
||||||
|
messages = [
|
||||||
|
SystemMessage(
|
||||||
|
content=(
|
||||||
|
"You are a world class algorithm to answer "
|
||||||
|
"questions with correct and exact citations."
|
||||||
|
)
|
||||||
|
),
|
||||||
|
HumanMessage(content="Answer question using the following context"),
|
||||||
|
HumanMessage(content=context),
|
||||||
|
HumanMessage(content=f"Question: {question}"),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
"Tips: Make sure to cite your sources, "
|
||||||
|
"and use the exact words from the context."
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
llm_output = self.llm(messages, **llm_kwargs)
|
||||||
|
function_output = llm_output.messages[0].additional_kwargs["function_call"][
|
||||||
|
"arguments"
|
||||||
|
]
|
||||||
|
output = QuestionAnswer.parse_raw(function_output)
|
||||||
|
|
||||||
|
return output
|
|
@ -6,11 +6,12 @@ from theflow import Node
|
||||||
from theflow.utils.modules import ObjectInitDeclaration as _
|
from theflow.utils.modules import ObjectInitDeclaration as _
|
||||||
|
|
||||||
from kotaemon.base import BaseComponent
|
from kotaemon.base import BaseComponent
|
||||||
from kotaemon.base.schema import RetrievedDocument
|
from kotaemon.base.schema import Document, RetrievedDocument
|
||||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||||
from kotaemon.llms import PromptTemplate
|
from kotaemon.llms import PromptTemplate
|
||||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
from kotaemon.pipelines.agents import BaseAgent
|
from kotaemon.pipelines.agents import BaseAgent
|
||||||
|
from kotaemon.pipelines.citation import CitationPipeline
|
||||||
from kotaemon.pipelines.reranking import BaseRerankingPipeline
|
from kotaemon.pipelines.reranking import BaseRerankingPipeline
|
||||||
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
|
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
|
||||||
from kotaemon.pipelines.tools import ComponentTool
|
from kotaemon.pipelines.tools import ComponentTool
|
||||||
|
@ -40,10 +41,10 @@ class QuestionAnsweringPipeline(BaseComponent):
|
||||||
)
|
)
|
||||||
|
|
||||||
llm: AzureChatOpenAI = AzureChatOpenAI.withx(
|
llm: AzureChatOpenAI = AzureChatOpenAI.withx(
|
||||||
azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
|
azure_endpoint="https://bleh-dummy.openai.azure.com/",
|
||||||
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
|
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
|
||||||
openai_api_version="2023-03-15-preview",
|
openai_api_version="2023-07-01-preview",
|
||||||
deployment_name="dummy-q2-gpt35",
|
deployment_name="dummy-q2-16k",
|
||||||
temperature=0,
|
temperature=0,
|
||||||
request_timeout=60,
|
request_timeout=60,
|
||||||
)
|
)
|
||||||
|
@ -90,7 +91,7 @@ class QuestionAnsweringPipeline(BaseComponent):
|
||||||
]
|
]
|
||||||
return "\n\n".join(matched_texts)
|
return "\n\n".join(matched_texts)
|
||||||
|
|
||||||
def run(self, question: str) -> str:
|
def run(self, question: str, use_citation: bool = False) -> Document:
|
||||||
# retrieve relevant documents as context
|
# retrieve relevant documents as context
|
||||||
documents = self.retrieving_pipeline(question, top_k=int(self.retrieval_top_k))
|
documents = self.retrieving_pipeline(question, top_k=int(self.retrieval_top_k))
|
||||||
context = self._format_retrieved_context(documents)
|
context = self._format_retrieved_context(documents)
|
||||||
|
@ -102,7 +103,15 @@ class QuestionAnsweringPipeline(BaseComponent):
|
||||||
question=question,
|
question=question,
|
||||||
)
|
)
|
||||||
self.log_progress(".prompt", prompt=prompt)
|
self.log_progress(".prompt", prompt=prompt)
|
||||||
answer = self.llm(prompt).text
|
answer_text = self.llm(prompt).text
|
||||||
|
if use_citation:
|
||||||
|
# run citation pipeline
|
||||||
|
citation_pipeline = CitationPipeline(llm=self.llm)
|
||||||
|
citation = citation_pipeline(context=context, question=question)
|
||||||
|
else:
|
||||||
|
citation = None
|
||||||
|
|
||||||
|
answer = Document(text=answer_text, metadata={"citation": citation})
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
|
||||||
|
@ -130,6 +139,6 @@ class AgentQAPipeline(QuestionAnsweringPipeline):
|
||||||
if search_tool not in self.agent.plugins:
|
if search_tool not in self.agent.plugins:
|
||||||
self.agent.plugins.append(search_tool)
|
self.agent.plugins.append(search_tool)
|
||||||
|
|
||||||
def run(self, question: str) -> str:
|
def run(self, question: str, use_citation: bool = False) -> Document:
|
||||||
answer = self.agent(question).output
|
answer = self.agent(question, use_citation=use_citation)
|
||||||
return answer
|
return answer
|
||||||
|
|
|
@ -103,7 +103,8 @@ class BaseTool(BaseComponent):
|
||||||
# TODO (verbose_): Add logging
|
# TODO (verbose_): Add logging
|
||||||
try:
|
try:
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||||
observation = self._run_tool(*tool_args, **tool_kwargs)
|
call_kwargs = {**kwargs, **tool_kwargs}
|
||||||
|
observation = self._run_tool(*tool_args, **call_kwargs)
|
||||||
except ToolException as e:
|
except ToolException as e:
|
||||||
observation = self._handle_tool_error(e)
|
observation = self._handle_tool_error(e)
|
||||||
return observation
|
return observation
|
||||||
|
|
|
@ -163,7 +163,7 @@ def test_rewoo_agent(openai_completion, llm, mock_google_search):
|
||||||
|
|
||||||
response = agent("Tell me about Cinnamon AI company")
|
response = agent("Tell me about Cinnamon AI company")
|
||||||
openai_completion.assert_called()
|
openai_completion.assert_called()
|
||||||
assert response.output == FINAL_RESPONSE_TEXT
|
assert response.text == FINAL_RESPONSE_TEXT
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
|
@ -180,7 +180,7 @@ def test_react_agent(openai_completion, llm, mock_google_search):
|
||||||
|
|
||||||
response = agent("Tell me about Cinnamon AI company")
|
response = agent("Tell me about Cinnamon AI company")
|
||||||
openai_completion.assert_called()
|
openai_completion.assert_called()
|
||||||
assert response.output == FINAL_RESPONSE_TEXT
|
assert response.text == FINAL_RESPONSE_TEXT
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
|
@ -224,4 +224,4 @@ def test_react_agent_with_langchain_tools(openai_completion, llm):
|
||||||
|
|
||||||
response = agent("Tell me about Cinnamon AI company")
|
response = agent("Tell me about Cinnamon AI company")
|
||||||
openai_completion.assert_called()
|
openai_completion.assert_called()
|
||||||
assert response.output == FINAL_RESPONSE_TEXT
|
assert response.text == FINAL_RESPONSE_TEXT
|
||||||
|
|
61
tests/test_citation.py
Normal file
61
tests/test_citation.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
# flake8: noqa
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.pipelines.citation import CitationPipeline
|
||||||
|
|
||||||
|
function_output = '{\n "question": "What is the provided _example_ benefits?",\n "answer": [\n {\n "fact": "特約死亡保険金: 被保険者がこの特約の保険期間中に死亡したときに支払います。",\n "substring_quote": ["特約死亡保険金"]\n },\n {\n "fact": "特約特定疾病保険金: 被保険者がこの特約の保険期間中に特定の疾病(悪性新生物(がん)、急性心筋梗塞または脳卒中)により所定の状態に該当したときに支払います。",\n "substring_quote": ["特約特定疾病保険金"]\n },\n {\n "fact": "特約障害保険金: 被保険者がこの特約の保険期間中に傷害もしくは疾病により所定の身体障害の状態に該当したとき、または不慮の事故により所定の身体障害の状態に該当したときに支払います。",\n "substring_quote": ["特約障害保険金"]\n },\n {\n "fact": "特約介護保険金: 被保険者がこの特約の保険期間中に傷害または疾病により所定の要介護状態に該当したときに支払います。",\n "substring_quote": ["特約介護保険金"]\n }\n ]\n}'
|
||||||
|
|
||||||
|
_openai_chat_completion_response = [
|
||||||
|
ChatCompletion.parse_obj(
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1692338378,
|
||||||
|
"model": "gpt-35-turbo",
|
||||||
|
"system_fingerprint": None,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "function_call",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"function_call": {
|
||||||
|
"arguments": function_output,
|
||||||
|
"name": "QuestionAnswer",
|
||||||
|
},
|
||||||
|
"tool_calls": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def llm():
|
||||||
|
return AzureChatOpenAI(
|
||||||
|
azure_endpoint="https://dummy.openai.azure.com/",
|
||||||
|
openai_api_key="dummy",
|
||||||
|
openai_api_version="2023-03-15-preview",
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"openai.resources.chat.completions.Completions.create",
|
||||||
|
side_effect=_openai_chat_completion_response,
|
||||||
|
)
|
||||||
|
def test_citation(openai_completion, llm):
|
||||||
|
question = "test query"
|
||||||
|
context = "document context"
|
||||||
|
|
||||||
|
citation = CitationPipeline(llm=llm)
|
||||||
|
result = citation(context, question)
|
||||||
|
assert len(result.answer) == 4
|
Loading…
Reference in New Issue
Block a user