[AUR-361] Setup pre-commit, pytest, GitHub actions, ssh-secret (#3)

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin
2023-08-30 07:22:01 +07:00
committed by GitHub
parent c3c25db48c
commit 5241edbc46
19 changed files with 268 additions and 54 deletions

View File

@@ -1,17 +1,12 @@
from typing import Type, TypeVar
from typing import List, Type, TypeVar
from theflow.base import Param
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import (
BaseMessage,
HumanMessage,
)
from langchain.schema.messages import BaseMessage, HumanMessage
from theflow.base import Param
from ...components import BaseComponent
from ..base import LLMInterface
Message = TypeVar("Message", bound=BaseMessage)
@@ -43,11 +38,11 @@ class LangchainChatLLM(ChatLLM):
message = HumanMessage(content=text)
return self.run_document([message])
def run_batch_raw(self, text: list[str]) -> list[LLMInterface]:
def run_batch_raw(self, text: List[str]) -> List[LLMInterface]:
inputs = [[HumanMessage(content=each)] for each in text]
return self.run_batch_document(inputs)
def run_document(self, text: list[Message]) -> LLMInterface:
def run_document(self, text: List[Message]) -> LLMInterface:
pred = self.agent.generate([text])
return LLMInterface(
text=[each.text for each in pred.generations[0]],
@@ -57,7 +52,7 @@ class LangchainChatLLM(ChatLLM):
logits=[],
)
def run_batch_document(self, text: list[list[Message]]) -> list[LLMInterface]:
def run_batch_document(self, text: List[List[Message]]) -> List[LLMInterface]:
outputs = []
for each_text in text:
outputs.append(self.run_document(each_text))
@@ -66,14 +61,14 @@ class LangchainChatLLM(ChatLLM):
def is_document(self, text) -> bool:
if isinstance(text, str):
return False
elif isinstance(text, list) and isinstance(text[0], str):
elif isinstance(text, List) and isinstance(text[0], str):
return False
return True
def is_batch(self, text) -> bool:
if isinstance(text, str):
return False
elif isinstance(text, list):
elif isinstance(text, List):
if isinstance(text[0], BaseMessage):
return False
return True