Update retrieving + agent pipeline (#71)
This commit is contained in:
committed by
GitHub
parent
693ed39de4
commit
640962e916
@@ -7,14 +7,14 @@ from theflow.utils.modules import ObjectInitDeclaration as _
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.base.schema import RetrievedDocument
|
||||
from kotaemon.docstores import InMemoryDocumentStore
|
||||
from kotaemon.docstores import BaseDocumentStore, InMemoryDocumentStore
|
||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||
from kotaemon.llms import PromptTemplate
|
||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||
from kotaemon.pipelines.agents import BaseAgent
|
||||
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
|
||||
from kotaemon.pipelines.tools import ComponentTool
|
||||
from kotaemon.vectorstores import InMemoryVectorStore
|
||||
from kotaemon.vectorstores import BaseVectorStore, InMemoryVectorStore
|
||||
|
||||
from .utils import file_names_to_collection_name
|
||||
|
||||
@@ -29,7 +29,7 @@ class QuestionAnsweringPipeline(BaseComponent):
|
||||
file_name_list: List[str]
|
||||
"""List of filename, incombination with storage_path to
|
||||
create persistent path of vectorstore"""
|
||||
prompt_template: PromptTemplate = PromptTemplate(
|
||||
qa_prompt_template: PromptTemplate = PromptTemplate(
|
||||
'Answer the following question: "{question}". '
|
||||
"The context is: \n{context}\nAnswer: "
|
||||
)
|
||||
@@ -43,8 +43,8 @@ class QuestionAnsweringPipeline(BaseComponent):
|
||||
request_timeout=60,
|
||||
)
|
||||
|
||||
vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
|
||||
doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
|
||||
vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
|
||||
doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
|
||||
|
||||
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
|
||||
model="text-embedding-ada-002",
|
||||
@@ -53,12 +53,21 @@ class QuestionAnsweringPipeline(BaseComponent):
|
||||
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
|
||||
)
|
||||
|
||||
@Node.default()
|
||||
@Node.auto(
|
||||
depends_on=[
|
||||
"vector_store",
|
||||
"doc_store",
|
||||
"embedding",
|
||||
"file_name_list",
|
||||
"retrieval_top_k",
|
||||
]
|
||||
)
|
||||
def retrieving_pipeline(self) -> RetrieveDocumentFromVectorStorePipeline:
|
||||
retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
|
||||
vector_store=self.vector_store,
|
||||
doc_store=self.doc_store,
|
||||
embedding=self.embedding,
|
||||
top_k=self.retrieval_top_k,
|
||||
)
|
||||
# load persistent from selected path
|
||||
collection_name = file_names_to_collection_name(self.file_name_list)
|
||||
@@ -81,7 +90,7 @@ class QuestionAnsweringPipeline(BaseComponent):
|
||||
self.log_progress(".context", context=context)
|
||||
|
||||
# generate the answer
|
||||
prompt = self.prompt_template.populate(
|
||||
prompt = self.qa_prompt_template.populate(
|
||||
context=context,
|
||||
question=question,
|
||||
)
|
||||
|
Reference in New Issue
Block a user