[AUR-395, AUR-415] Adopt Example1 Injury pipeline; add .flow() to enable bottom-up pipeline execution (#32)

* add example1/injury pipeline example
* add dotenv
* update various APIs
ian_Cin authored 2023-10-02 16:24:56 +07:00; committed by GitHub
parent 3cceec63ef
commit d83c22aa4e
16 changed files with 114 additions and 69 deletions

View File

@@ -16,6 +16,19 @@ class BaseComponent(Compose):
     - is_batch: check if input is batch
     """
 
+    inflow = None
+
+    def flow(self):
+        if self.inflow is None:
+            raise ValueError("No inflow provided.")
+
+        if not isinstance(self.inflow, BaseComponent):
+            raise ValueError(
+                f"inflow must be a BaseComponent, found {type(self.inflow)}"
+            )
+
+        return self.__call__(self.inflow.flow())
+
     @abstractmethod
     def run_raw(self, *args, **kwargs):
         ...
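
For orientation, here is a minimal, self-contained sketch of the bottom-up pattern that flow() enables. Node and Root are illustrative stand-ins, not classes from this commit; real kotaemon components also implement the other abstract run_* methods.

# Plain-Python stand-ins (not kotaemon classes) showing how .flow()
# walks the chain upstream before running the current component.
class Node:
    inflow = None

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args):
        return self.fn(*args)

    def flow(self):
        if self.inflow is None:
            raise ValueError("No inflow provided.")
        # Run the upstream component first, then feed its output into this one.
        return self(self.inflow.flow())

class Root(Node):
    def flow(self):
        # A source node has no inflow; it simply runs itself.
        return self()

prompt = Root(lambda: "What is the capital of France?")
llm = Node(lambda text: f"[answer to: {text}]")
llm.inflow = prompt
print(llm.flow())  # runs prompt first, then llm, i.e. bottom-up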

View File

@@ -1,25 +1,13 @@
-from dataclasses import dataclass, field
 from typing import List
 
-from ..base import BaseComponent
+from pydantic import Field
+
+from kotaemon.documents.base import Document
 
 
-@dataclass
-class LLMInterface:
-    text: List[str]
+class LLMInterface(Document):
+    candidates: List[str]
     completion_tokens: int = -1
     total_tokens: int = -1
     prompt_tokens: int = -1
-    logits: List[List[float]] = field(default_factory=list)
-
-
-class PromptTemplate(BaseComponent):
-    pass
-
-
-class Extract(BaseComponent):
-    pass
-
-
-class PromptNode(BaseComponent):
-    pass
+    logits: List[List[float]] = Field(default_factory=list)
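
With LLMInterface now a Document, text holds a single string (the top generation) and the full list of generations moves to candidates. A quick sketch with made-up values:

out = LLMInterface(
    text="Paris",
    candidates=["Paris", "The capital is Paris."],
    completion_tokens=5,
    total_tokens=12,
    prompt_tokens=7,
)
out.text           # "Paris" -- now a plain string, not a list
out.candidates[1]  # other generations, previously reachable via text[1]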

View File

@@ -11,7 +11,17 @@ Message = TypeVar("Message", bound=BaseMessage)
 class ChatLLM(BaseComponent):
     ...
 
+    def flow(self):
+        if self.inflow is None:
+            raise ValueError("No inflow provided.")
+
+        if not isinstance(self.inflow, BaseComponent):
+            raise ValueError(
+                f"inflow must be a BaseComponent, found {type(self.inflow)}"
+            )
+
+        text = self.inflow.flow().text
+        return self.__call__(text)
+
 
 class LangchainChatLLM(ChatLLM):
@@ -44,8 +54,10 @@ class LangchainChatLLM(ChatLLM):
     def run_document(self, text: List[Message], **kwargs) -> LLMInterface:
         pred = self.agent.generate([text], **kwargs)  # type: ignore
+        all_text = [each.text for each in pred.generations[0]]
         return LLMInterface(
-            text=[each.text for each in pred.generations[0]],
+            text=all_text[0] if len(all_text) > 0 else "",
+            candidates=all_text,
             completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
             total_tokens=pred.llm_output["token_usage"]["total_tokens"],
             prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],
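
Unlike BaseComponent.flow(), ChatLLM.flow() keeps only the .text of the upstream output before invoking the model. A rough wiring sketch; the concrete class and the prompt component here are placeholders, not names from this diff:

chat = SomeConcreteChatLLM()       # placeholder subclass of LangchainChatLLM
chat.inflow = prompt_component     # any BaseComponent whose flow() yields a Document
answer = chat.flow()               # equivalent to chat(prompt_component.flow().text)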

View File

@@ -33,8 +33,10 @@ class LangchainLLM(LLM):
     def run_raw(self, text: str) -> LLMInterface:
         pred = self.agent.generate([text])
+        all_text = [each.text for each in pred.generations[0]]
         return LLMInterface(
-            text=[each.text for each in pred.generations[0]],
+            text=all_text[0] if len(all_text) > 0 else "",
+            candidates=all_text,
             completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
             total_tokens=pred.llm_output["token_usage"]["total_tokens"],
             prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],

View File

@@ -162,7 +162,7 @@ class ReactAgent(BaseAgent):
             prompt = self._compose_prompt(instruction)
             logging.info(f"Prompt: {prompt}")
             response = self.llm(prompt, stop=["Observation:"])  # type: ignore
-            response_text = response.text[0]
+            response_text = response.text
             logging.info(f"Response: {response_text}")
             action_step = self._parse_output(response_text)
             if action_step is None:
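
This one-line change (and the matching ones in RewooAgent below) follows from the LLMInterface change above: text used to be a List[str], so text[0] was the first generation; now text is already that string, and keeping the index would silently return only its first character. Roughly, with a made-up response object:

response = LLMInterface(text="The capital is Paris.", candidates=["The capital is Paris."])
response.text      # "The capital is Paris."  (first generation, already a str)
response.text[0]   # "T" -- what the old [0] indexing would now return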

View File

@@ -245,7 +245,7 @@ class RewooAgent(BaseAgent):
         # Plan
         planner_output = planner(instruction)
-        plannner_text_output = planner_output.text[0]
+        plannner_text_output = planner_output.text
         plan_to_es, plans = self._parse_plan_map(plannner_text_output)
         planner_evidences, evidence_level = self._parse_planner_evidences(
             plannner_text_output
@@ -263,7 +263,7 @@ class RewooAgent(BaseAgent):
         # Solve
         solver_output = solver(instruction, worker_log)
-        solver_output_text = solver_output.text[0]
+        solver_output_text = solver_output.text
         return AgentOutput(
             output=solver_output_text, cost=total_cost, token_usage=total_token

View File

@@ -50,9 +50,9 @@ class RegexExtractor(BaseComponent):
         if not output_map:
             return text
 
-        return output_map.get(text, text)
+        return str(output_map.get(text, text))
 
-    def run_raw(self, text: str) -> List[str]:
+    def run_raw(self, text: str) -> List[Document]:
         """
         Runs the raw text through the static pattern and output mapping, returning a
         list of strings.
@@ -66,9 +66,12 @@ class RegexExtractor(BaseComponent):
         output = self.run_raw_static(self.pattern, text)
         output = [self.map_output(text, self.output_map) for text in output]
 
-        return output
+        return [
+            Document(text=text, metadata={"origin": "RegexExtractor"})
+            for text in output
+        ]
 
-    def run_batch_raw(self, text_batch: List[str]) -> List[List[str]]:
+    def run_batch_raw(self, text_batch: List[str]) -> List[List[Document]]:
         """
         Runs a batch of raw text inputs through the `run_raw()` method and returns the
         output for each input.
@@ -95,13 +98,7 @@ class RegexExtractor(BaseComponent):
         Returns:
             List[Document]: A list of extracted documents.
         """
-        texts = self.run_raw(document.text)
-        output = [
-            Document(text=text, metadata={**document.metadata, "RegexExtractor": True})
-            for text in texts
-        ]
-
-        return output
+        return self.run_raw(document.text)
 
     def run_batch_document(
         self, document_batch: List[Document]
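
run_raw() therefore now yields Document objects tagged with their origin rather than bare strings, so downstream components receive the same type everywhere. A sketch of the new return shape; the constructor arguments are assumed for illustration and are not shown in this diff:

extractor = RegexExtractor(pattern=r"\d+")   # assumed constructor, for illustration only
extractor.run_raw("call 42 or 7")
# before: ["42", "7"]
# now:    [Document(text="42", metadata={"origin": "RegexExtractor"}),
#          Document(text="7",  metadata={"origin": "RegexExtractor"})]

Note that run_document now simply delegates to run_raw, so the input document's own metadata is no longer copied onto the extracted documents.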

View File

@@ -1,3 +1,4 @@
+import warnings
 from typing import Union
 
 from kotaemon.base import BaseComponent
@@ -5,7 +6,7 @@ from kotaemon.documents.base import Document
 from kotaemon.prompt.template import PromptTemplate
 
 
-class BasePrompt(BaseComponent):
+class BasePromptComponent(BaseComponent):
     """
     Base class for prompt components.
@@ -15,6 +16,16 @@ class BasePrompt(BaseComponent):
     given template.
     """
 
+    def __init__(self, template: Union[str, PromptTemplate], **kwargs):
+        super().__init__()
+        self.template = (
+            template
+            if isinstance(template, PromptTemplate)
+            else PromptTemplate(template)
+        )
+        self.__set(**kwargs)
+
     def __check_redundant_kwargs(self, **kwargs):
         """
         Check for redundant keyword arguments.
@@ -33,7 +44,9 @@ class BasePrompt(BaseComponent):
         redundant_keys = provided_keys - expected_keys
 
         if redundant_keys:
-            raise ValueError(f"\nKeys provided but not in template: {redundant_keys}")
+            warnings.warn(
+                f"Keys provided but not in template: {redundant_keys}", UserWarning
+            )
 
     def __check_unset_placeholders(self):
         """
@@ -111,27 +124,34 @@ class BasePrompt(BaseComponent):
         Returns:
             dict: A dictionary of keyword arguments.
         """
+
+        def __prepare(key, value):
+            if isinstance(value, str):
+                return value
+            if isinstance(value, (int, Document)):
+                return str(value)
+
+            raise ValueError(
+                f"Unsupported type {type(value)} for template value of key {key}"
+            )
+
         kwargs = {}
         for k in self.template.placeholders:
             v = getattr(self, k)
-            if isinstance(v, (int, Document)):
-                v = str(v)
-            elif isinstance(v, BaseComponent):
-                v = str(v())
+
+            if isinstance(v, BaseComponent):
+                v = v()
+
+            if isinstance(v, list):
+                v = str([__prepare(k, each) for each in v])
+            elif isinstance(v, (str, int, Document)):
+                v = __prepare(k, v)
+            else:
+                raise ValueError(
+                    f"Unsupported type {type(v)} for template value of key {k}"
+                )
+
             kwargs[k] = v
 
         return kwargs
 
-    def __init__(self, template: Union[str, PromptTemplate], **kwargs):
-        super().__init__()
-        self.template = (
-            template
-            if isinstance(template, PromptTemplate)
-            else PromptTemplate(template)
-        )
-        self.__set(**kwargs)
-
     def set(self, **kwargs):
         """
         Similar to `__set` but for external use.
@@ -163,7 +183,8 @@ class BasePrompt(BaseComponent):
         self.__check_unset_placeholders()
         prepared_kwargs = self.__prepare_value()
 
-        return self.template.populate(**prepared_kwargs)
+        text = self.template.populate(**prepared_kwargs)
+        return Document(text=text, metadata={"origin": "PromptComponent"})
 
     def run_raw(self, *args, **kwargs):
         pass
@@ -182,3 +203,6 @@ class BasePrompt(BaseComponent):
     def is_batch(self, *args, **kwargs):
         pass
+
+    def flow(self):
+        return self.__call__()
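
Taken together, the renamed BasePromptComponent now produces a Document and can sit at the root of a flow() chain. A usage sketch; the template and values are made up, and the call path assumes __call__ dispatches to the populate logic shown above:

prompt = BasePromptComponent("Hello, {name}!", name="world")
doc = prompt()              # a Document now, not a plain string
doc.text                    # "Hello, world!"
doc.metadata["origin"]      # "PromptComponent"
prompt.flow()               # same as prompt(); lets it act as the chain's source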