[AUR-395, AUR-415] Adopt Example1 Injury pipeline; add .flow() for enabling bottom-up pipeline execution (#32)
* add example1/injury pipeline example * add dotenv * update various api
This commit is contained in:
@@ -16,6 +16,19 @@ class BaseComponent(Compose):
|
||||
- is_batch: check if input is batch
|
||||
"""
|
||||
|
||||
inflow = None
|
||||
|
||||
def flow(self):
|
||||
if self.inflow is None:
|
||||
raise ValueError("No inflow provided.")
|
||||
|
||||
if not isinstance(self.inflow, BaseComponent):
|
||||
raise ValueError(
|
||||
f"inflow must be a BaseComponent, found {type(self.inflow)}"
|
||||
)
|
||||
|
||||
return self.__call__(self.inflow.flow())
|
||||
|
||||
@abstractmethod
|
||||
def run_raw(self, *args, **kwargs):
|
||||
...
|
||||
|
@@ -1,25 +1,13 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from ..base import BaseComponent
|
||||
from pydantic import Field
|
||||
|
||||
from kotaemon.documents.base import Document
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInterface:
|
||||
text: List[str]
|
||||
class LLMInterface(Document):
|
||||
candidates: List[str]
|
||||
completion_tokens: int = -1
|
||||
total_tokens: int = -1
|
||||
prompt_tokens: int = -1
|
||||
logits: List[List[float]] = field(default_factory=list)
|
||||
|
||||
|
||||
class PromptTemplate(BaseComponent):
|
||||
pass
|
||||
|
||||
|
||||
class Extract(BaseComponent):
|
||||
pass
|
||||
|
||||
|
||||
class PromptNode(BaseComponent):
|
||||
pass
|
||||
logits: List[List[float]] = Field(default_factory=list)
|
||||
|
@@ -11,7 +11,17 @@ Message = TypeVar("Message", bound=BaseMessage)
|
||||
|
||||
|
||||
class ChatLLM(BaseComponent):
|
||||
...
|
||||
def flow(self):
|
||||
if self.inflow is None:
|
||||
raise ValueError("No inflow provided.")
|
||||
|
||||
if not isinstance(self.inflow, BaseComponent):
|
||||
raise ValueError(
|
||||
f"inflow must be a BaseComponent, found {type(self.inflow)}"
|
||||
)
|
||||
|
||||
text = self.inflow.flow().text
|
||||
return self.__call__(text)
|
||||
|
||||
|
||||
class LangchainChatLLM(ChatLLM):
|
||||
@@ -44,8 +54,10 @@ class LangchainChatLLM(ChatLLM):
|
||||
|
||||
def run_document(self, text: List[Message], **kwargs) -> LLMInterface:
|
||||
pred = self.agent.generate([text], **kwargs) # type: ignore
|
||||
all_text = [each.text for each in pred.generations[0]]
|
||||
return LLMInterface(
|
||||
text=[each.text for each in pred.generations[0]],
|
||||
text=all_text[0] if len(all_text) > 0 else "",
|
||||
candidates=all_text,
|
||||
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
|
||||
total_tokens=pred.llm_output["token_usage"]["total_tokens"],
|
||||
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],
|
||||
|
@@ -33,8 +33,10 @@ class LangchainLLM(LLM):
|
||||
|
||||
def run_raw(self, text: str) -> LLMInterface:
|
||||
pred = self.agent.generate([text])
|
||||
all_text = [each.text for each in pred.generations[0]]
|
||||
return LLMInterface(
|
||||
text=[each.text for each in pred.generations[0]],
|
||||
text=all_text[0] if len(all_text) > 0 else "",
|
||||
candidates=all_text,
|
||||
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
|
||||
total_tokens=pred.llm_output["token_usage"]["total_tokens"],
|
||||
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],
|
||||
|
@@ -162,7 +162,7 @@ class ReactAgent(BaseAgent):
|
||||
prompt = self._compose_prompt(instruction)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
response = self.llm(prompt, stop=["Observation:"]) # type: ignore
|
||||
response_text = response.text[0]
|
||||
response_text = response.text
|
||||
logging.info(f"Response: {response_text}")
|
||||
action_step = self._parse_output(response_text)
|
||||
if action_step is None:
|
||||
|
@@ -245,7 +245,7 @@ class RewooAgent(BaseAgent):
|
||||
|
||||
# Plan
|
||||
planner_output = planner(instruction)
|
||||
plannner_text_output = planner_output.text[0]
|
||||
plannner_text_output = planner_output.text
|
||||
plan_to_es, plans = self._parse_plan_map(plannner_text_output)
|
||||
planner_evidences, evidence_level = self._parse_planner_evidences(
|
||||
plannner_text_output
|
||||
@@ -263,7 +263,7 @@ class RewooAgent(BaseAgent):
|
||||
|
||||
# Solve
|
||||
solver_output = solver(instruction, worker_log)
|
||||
solver_output_text = solver_output.text[0]
|
||||
solver_output_text = solver_output.text
|
||||
|
||||
return AgentOutput(
|
||||
output=solver_output_text, cost=total_cost, token_usage=total_token
|
||||
|
@@ -50,9 +50,9 @@ class RegexExtractor(BaseComponent):
|
||||
if not output_map:
|
||||
return text
|
||||
|
||||
return output_map.get(text, text)
|
||||
return str(output_map.get(text, text))
|
||||
|
||||
def run_raw(self, text: str) -> List[str]:
|
||||
def run_raw(self, text: str) -> List[Document]:
|
||||
"""
|
||||
Runs the raw text through the static pattern and output mapping, returning a
|
||||
list of strings.
|
||||
@@ -66,9 +66,12 @@ class RegexExtractor(BaseComponent):
|
||||
output = self.run_raw_static(self.pattern, text)
|
||||
output = [self.map_output(text, self.output_map) for text in output]
|
||||
|
||||
return output
|
||||
return [
|
||||
Document(text=text, metadata={"origin": "RegexExtractor"})
|
||||
for text in output
|
||||
]
|
||||
|
||||
def run_batch_raw(self, text_batch: List[str]) -> List[List[str]]:
|
||||
def run_batch_raw(self, text_batch: List[str]) -> List[List[Document]]:
|
||||
"""
|
||||
Runs a batch of raw text inputs through the `run_raw()` method and returns the
|
||||
output for each input.
|
||||
@@ -95,13 +98,7 @@ class RegexExtractor(BaseComponent):
|
||||
Returns:
|
||||
List[Document]: A list of extracted documents.
|
||||
"""
|
||||
texts = self.run_raw(document.text)
|
||||
output = [
|
||||
Document(text=text, metadata={**document.metadata, "RegexExtractor": True})
|
||||
for text in texts
|
||||
]
|
||||
|
||||
return output
|
||||
return self.run_raw(document.text)
|
||||
|
||||
def run_batch_document(
|
||||
self, document_batch: List[Document]
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import warnings
|
||||
from typing import Union
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
@@ -5,7 +6,7 @@ from kotaemon.documents.base import Document
|
||||
from kotaemon.prompt.template import PromptTemplate
|
||||
|
||||
|
||||
class BasePrompt(BaseComponent):
|
||||
class BasePromptComponent(BaseComponent):
|
||||
"""
|
||||
Base class for prompt components.
|
||||
|
||||
@@ -15,6 +16,16 @@ class BasePrompt(BaseComponent):
|
||||
given template.
|
||||
"""
|
||||
|
||||
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
|
||||
super().__init__()
|
||||
self.template = (
|
||||
template
|
||||
if isinstance(template, PromptTemplate)
|
||||
else PromptTemplate(template)
|
||||
)
|
||||
|
||||
self.__set(**kwargs)
|
||||
|
||||
def __check_redundant_kwargs(self, **kwargs):
|
||||
"""
|
||||
Check for redundant keyword arguments.
|
||||
@@ -33,7 +44,9 @@ class BasePrompt(BaseComponent):
|
||||
|
||||
redundant_keys = provided_keys - expected_keys
|
||||
if redundant_keys:
|
||||
raise ValueError(f"\nKeys provided but not in template: {redundant_keys}")
|
||||
warnings.warn(
|
||||
f"Keys provided but not in template: {redundant_keys}", UserWarning
|
||||
)
|
||||
|
||||
def __check_unset_placeholders(self):
|
||||
"""
|
||||
@@ -111,27 +124,34 @@ class BasePrompt(BaseComponent):
|
||||
Returns:
|
||||
dict: A dictionary of keyword arguments.
|
||||
"""
|
||||
|
||||
def __prepare(key, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, (int, Document)):
|
||||
return str(value)
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported type {type(value)} for template value of key {key}"
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
for k in self.template.placeholders:
|
||||
v = getattr(self, k)
|
||||
if isinstance(v, (int, Document)):
|
||||
v = str(v)
|
||||
elif isinstance(v, BaseComponent):
|
||||
v = str(v())
|
||||
if isinstance(v, BaseComponent):
|
||||
v = v()
|
||||
if isinstance(v, list):
|
||||
v = str([__prepare(k, each) for each in v])
|
||||
elif isinstance(v, (str, int, Document)):
|
||||
v = __prepare(k, v)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type {type(v)} for template value of key {k}"
|
||||
)
|
||||
kwargs[k] = v
|
||||
|
||||
return kwargs
|
||||
|
||||
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
|
||||
super().__init__()
|
||||
self.template = (
|
||||
template
|
||||
if isinstance(template, PromptTemplate)
|
||||
else PromptTemplate(template)
|
||||
)
|
||||
|
||||
self.__set(**kwargs)
|
||||
|
||||
def set(self, **kwargs):
|
||||
"""
|
||||
Similar to `__set` but for external use.
|
||||
@@ -163,7 +183,8 @@ class BasePrompt(BaseComponent):
|
||||
self.__check_unset_placeholders()
|
||||
prepared_kwargs = self.__prepare_value()
|
||||
|
||||
return self.template.populate(**prepared_kwargs)
|
||||
text = self.template.populate(**prepared_kwargs)
|
||||
return Document(text=text, metadata={"origin": "PromptComponent"})
|
||||
|
||||
def run_raw(self, *args, **kwargs):
|
||||
pass
|
||||
@@ -182,3 +203,6 @@ class BasePrompt(BaseComponent):
|
||||
|
||||
def is_batch(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def flow(self):
|
||||
return self.__call__()
|
||||
|
Reference in New Issue
Block a user