Add Citation pipeline (#78)
* add rerankers in retrieving pipeline
* update example MVP pipeline
* add citation pipeline and function call interface
* change return type of QA and AgentPipeline to Document
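A minimal usage sketch of the new call contract (hypothetical: pipeline construction details are assumed; only the run() signature, the Document return type, and the metadata["citation"] field come from this commit):

# Hypothetical usage of the updated QA pipeline; constructor arguments are assumed.
pipeline = QuestionAnsweringPipeline()

# run() now returns a Document instead of a plain string.
answer = pipeline.run("How are documents reranked?", use_citation=True)

print(answer.text)                  # generated answer text
print(answer.metadata["citation"])  # CitationPipeline output, or None when use_citation=False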
committed by GitHub
parent f8b8d86d4e
commit cc1e75b3c6
@@ -6,11 +6,12 @@ from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
 
 from kotaemon.base import BaseComponent
-from kotaemon.base.schema import RetrievedDocument
+from kotaemon.base.schema import Document, RetrievedDocument
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.pipelines.agents import BaseAgent
+from kotaemon.pipelines.citation import CitationPipeline
 from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.pipelines.tools import ComponentTool
@@ -40,10 +41,10 @@ class QuestionAnsweringPipeline(BaseComponent):
     )
 
     llm: AzureChatOpenAI = AzureChatOpenAI.withx(
-        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
-        openai_api_version="2023-03-15-preview",
-        deployment_name="dummy-q2-gpt35",
+        openai_api_version="2023-07-01-preview",
+        deployment_name="dummy-q2-16k",
         temperature=0,
         request_timeout=60,
     )
@@ -90,7 +91,7 @@ class QuestionAnsweringPipeline(BaseComponent):
         ]
         return "\n\n".join(matched_texts)
 
-    def run(self, question: str) -> str:
+    def run(self, question: str, use_citation: bool = False) -> Document:
         # retrieve relevant documents as context
         documents = self.retrieving_pipeline(question, top_k=int(self.retrieval_top_k))
         context = self._format_retrieved_context(documents)
@@ -102,7 +103,15 @@ class QuestionAnsweringPipeline(BaseComponent):
             question=question,
         )
         self.log_progress(".prompt", prompt=prompt)
-        answer = self.llm(prompt).text
+        answer_text = self.llm(prompt).text
+        if use_citation:
+            # run citation pipeline
+            citation_pipeline = CitationPipeline(llm=self.llm)
+            citation = citation_pipeline(context=context, question=question)
+        else:
+            citation = None
+
+        answer = Document(text=answer_text, metadata={"citation": citation})
         return answer
 
 
@@ -130,6 +139,6 @@ class AgentQAPipeline(QuestionAnsweringPipeline):
         if search_tool not in self.agent.plugins:
             self.agent.plugins.append(search_tool)
 
-    def run(self, question: str) -> str:
-        answer = self.agent(question).output
+    def run(self, question: str, use_citation: bool = False) -> Document:
+        answer = self.agent(question, use_citation=use_citation)
         return answer
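For reference, the citation step as wired above can also be called on its own; a sketch assuming llm, context, and question are already prepared as in QuestionAnsweringPipeline.run (the structure of the returned citation object is not shown in this diff):

# Sketch: standalone citation call, mirroring the wiring inside run().
# llm, context, and question are assumed to exist; the citation's shape is left unspecified.
from kotaemon.pipelines.citation import CitationPipeline

citation_pipeline = CitationPipeline(llm=llm)
citation = citation_pipeline(context=context, question=question)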