Add Citation pipeline (#78)

* add rerankers in retrieving pipeline

* update example MVP pipeline

* add citation pipeline and function call interface

* change return type of QA and AgentPipeline to Document
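
A minimal usage sketch of the changed interface (hedged: the qa_pipeline instance and its configuration are assumed to already be set up as in the example MVP pipeline; only the use_citation flag, the Document return type, and the "citation" metadata key come from this commit):

from kotaemon.base.schema import Document

# qa_pipeline: an already-configured QuestionAnsweringPipeline (setup not shown)
answer: Document = qa_pipeline.run("What does the paper conclude?", use_citation=True)

print(answer.text)                    # generated answer text
print(answer.metadata["citation"])    # CitationPipeline result, or None if use_citation=False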
Author: Tuan Anh Nguyen Dang (Tadashi_Cin)
Date: 2023-11-16 11:24:35 +07:00
Committed by: GitHub
Parent: f8b8d86d4e
Commit: cc1e75b3c6
9 changed files with 223 additions and 19 deletions


@@ -6,11 +6,12 @@ from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
 from kotaemon.base import BaseComponent
-from kotaemon.base.schema import RetrievedDocument
+from kotaemon.base.schema import Document, RetrievedDocument
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.pipelines.agents import BaseAgent
+from kotaemon.pipelines.citation import CitationPipeline
 from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.pipelines.tools import ComponentTool
@@ -40,10 +41,10 @@ class QuestionAnsweringPipeline(BaseComponent):
     )
     llm: AzureChatOpenAI = AzureChatOpenAI.withx(
-        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
-        openai_api_version="2023-03-15-preview",
-        deployment_name="dummy-q2-gpt35",
+        openai_api_version="2023-07-01-preview",
+        deployment_name="dummy-q2-16k",
         temperature=0,
         request_timeout=60,
     )
@@ -90,7 +91,7 @@ class QuestionAnsweringPipeline(BaseComponent):
         ]
         return "\n\n".join(matched_texts)

-    def run(self, question: str) -> str:
+    def run(self, question: str, use_citation: bool = False) -> Document:
         # retrieve relevant documents as context
         documents = self.retrieving_pipeline(question, top_k=int(self.retrieval_top_k))
         context = self._format_retrieved_context(documents)
@@ -102,7 +103,15 @@ class QuestionAnsweringPipeline(BaseComponent):
             question=question,
         )
         self.log_progress(".prompt", prompt=prompt)
-        answer = self.llm(prompt).text
+        answer_text = self.llm(prompt).text
+        if use_citation:
+            # run citation pipeline
+            citation_pipeline = CitationPipeline(llm=self.llm)
+            citation = citation_pipeline(context=context, question=question)
+        else:
+            citation = None
+        answer = Document(text=answer_text, metadata={"citation": citation})
         return answer
@@ -130,6 +139,6 @@ class AgentQAPipeline(QuestionAnsweringPipeline):
         if search_tool not in self.agent.plugins:
             self.agent.plugins.append(search_tool)

-    def run(self, question: str) -> str:
-        answer = self.agent(question).output
+    def run(self, question: str, use_citation: bool = False) -> Document:
+        answer = self.agent(question, use_citation=use_citation)
         return answer
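
Because run() now returns a Document rather than a str, downstream code that consumed the answer as a plain string needs to read the text field instead. A hedged sketch of such a caller (render_answer is a hypothetical helper; it also assumes the agent path stores its citation under the same "citation" metadata key, which this hunk does not show):

from kotaemon.base.schema import Document

def render_answer(result: Document) -> str:
    # result.text holds the generated answer; metadata["citation"] holds the
    # CitationPipeline output when use_citation=True, otherwise None
    citation = result.metadata.get("citation")
    if citation is None:
        return result.text
    return f"{result.text}\n\n[citation] {citation}"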