[AUR-361] Setup pre-commit, pytest, GitHub actions, ssh-secret (#3)
Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
@@ -1,17 +1,12 @@
|
||||
from typing import Type, TypeVar
|
||||
from typing import List, Type, TypeVar
|
||||
|
||||
from theflow.base import Param
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from langchain.schema.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||
from theflow.base import Param
|
||||
|
||||
from ...components import BaseComponent
|
||||
from ..base import LLMInterface
|
||||
|
||||
|
||||
Message = TypeVar("Message", bound=BaseMessage)
|
||||
|
||||
|
||||
@@ -43,11 +38,11 @@ class LangchainChatLLM(ChatLLM):
|
||||
message = HumanMessage(content=text)
|
||||
return self.run_document([message])
|
||||
|
||||
def run_batch_raw(self, text: list[str]) -> list[LLMInterface]:
|
||||
def run_batch_raw(self, text: List[str]) -> List[LLMInterface]:
|
||||
inputs = [[HumanMessage(content=each)] for each in text]
|
||||
return self.run_batch_document(inputs)
|
||||
|
||||
def run_document(self, text: list[Message]) -> LLMInterface:
|
||||
def run_document(self, text: List[Message]) -> LLMInterface:
|
||||
pred = self.agent.generate([text])
|
||||
return LLMInterface(
|
||||
text=[each.text for each in pred.generations[0]],
|
||||
@@ -57,7 +52,7 @@ class LangchainChatLLM(ChatLLM):
|
||||
logits=[],
|
||||
)
|
||||
|
||||
def run_batch_document(self, text: list[list[Message]]) -> list[LLMInterface]:
|
||||
def run_batch_document(self, text: List[List[Message]]) -> List[LLMInterface]:
|
||||
outputs = []
|
||||
for each_text in text:
|
||||
outputs.append(self.run_document(each_text))
|
||||
@@ -66,14 +61,14 @@ class LangchainChatLLM(ChatLLM):
|
||||
def is_document(self, text) -> bool:
|
||||
if isinstance(text, str):
|
||||
return False
|
||||
elif isinstance(text, list) and isinstance(text[0], str):
|
||||
elif isinstance(text, List) and isinstance(text[0], str):
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_batch(self, text) -> bool:
|
||||
if isinstance(text, str):
|
||||
return False
|
||||
elif isinstance(text, list):
|
||||
elif isinstance(text, List):
|
||||
if isinstance(text[0], BaseMessage):
|
||||
return False
|
||||
return True
|
||||
|
Reference in New Issue
Block a user