[AUR-395, AUR-415] Adopt Example1 Injury pipeline; add .flow() for enabling bottom-up pipeline execution (#32)
* add example1/injury pipeline example * add dotenv * update various api
This commit is contained in:
@@ -1,25 +1,13 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from ..base import BaseComponent
|
||||
from pydantic import Field
|
||||
|
||||
from kotaemon.documents.base import Document
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInterface:
|
||||
text: List[str]
|
||||
class LLMInterface(Document):
|
||||
candidates: List[str]
|
||||
completion_tokens: int = -1
|
||||
total_tokens: int = -1
|
||||
prompt_tokens: int = -1
|
||||
logits: List[List[float]] = field(default_factory=list)
|
||||
|
||||
|
||||
class PromptTemplate(BaseComponent):
|
||||
pass
|
||||
|
||||
|
||||
class Extract(BaseComponent):
|
||||
pass
|
||||
|
||||
|
||||
class PromptNode(BaseComponent):
|
||||
pass
|
||||
logits: List[List[float]] = Field(default_factory=list)
|
||||
|
@@ -11,7 +11,17 @@ Message = TypeVar("Message", bound=BaseMessage)
|
||||
|
||||
|
||||
class ChatLLM(BaseComponent):
|
||||
...
|
||||
def flow(self):
|
||||
if self.inflow is None:
|
||||
raise ValueError("No inflow provided.")
|
||||
|
||||
if not isinstance(self.inflow, BaseComponent):
|
||||
raise ValueError(
|
||||
f"inflow must be a BaseComponent, found {type(self.inflow)}"
|
||||
)
|
||||
|
||||
text = self.inflow.flow().text
|
||||
return self.__call__(text)
|
||||
|
||||
|
||||
class LangchainChatLLM(ChatLLM):
|
||||
@@ -44,8 +54,10 @@ class LangchainChatLLM(ChatLLM):
|
||||
|
||||
def run_document(self, text: List[Message], **kwargs) -> LLMInterface:
|
||||
pred = self.agent.generate([text], **kwargs) # type: ignore
|
||||
all_text = [each.text for each in pred.generations[0]]
|
||||
return LLMInterface(
|
||||
text=[each.text for each in pred.generations[0]],
|
||||
text=all_text[0] if len(all_text) > 0 else "",
|
||||
candidates=all_text,
|
||||
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
|
||||
total_tokens=pred.llm_output["token_usage"]["total_tokens"],
|
||||
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],
|
||||
|
@@ -33,8 +33,10 @@ class LangchainLLM(LLM):
|
||||
|
||||
def run_raw(self, text: str) -> LLMInterface:
|
||||
pred = self.agent.generate([text])
|
||||
all_text = [each.text for each in pred.generations[0]]
|
||||
return LLMInterface(
|
||||
text=[each.text for each in pred.generations[0]],
|
||||
text=all_text[0] if len(all_text) > 0 else "",
|
||||
candidates=all_text,
|
||||
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
|
||||
total_tokens=pred.llm_output["token_usage"]["total_tokens"],
|
||||
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],
|
||||
|
Reference in New Issue
Block a user