fix: openai async (#585) bump:patch

This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-12-24 09:45:15 +07:00 committed by GitHub
parent 95191f53d9
commit 5343d0d3ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 7 deletions

View File

@ -196,6 +196,10 @@ class BaseChatOpenAI(ChatLLM):
"""Get the openai response""" """Get the openai response"""
raise NotImplementedError raise NotImplementedError
async def aopenai_response(self, client, **kwargs):
"""Get the openai response"""
raise NotImplementedError
def invoke( def invoke(
self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
) -> LLMInterface: ) -> LLMInterface:
@ -211,8 +215,10 @@ class BaseChatOpenAI(ChatLLM):
) -> LLMInterface: ) -> LLMInterface:
client = self.prepare_client(async_version=True) client = self.prepare_client(async_version=True)
input_messages = self.prepare_message(messages) input_messages = self.prepare_message(messages)
resp = await self.openai_response( resp = (
await self.aopenai_response(
client, messages=input_messages, stream=False, **kwargs client, messages=input_messages, stream=False, **kwargs
)
).dict() ).dict()
return self.prepare_output(resp) return self.prepare_output(resp)
@ -290,8 +296,7 @@ class ChatOpenAI(BaseChatOpenAI):
return OpenAI(**params) return OpenAI(**params)
def openai_response(self, client, **kwargs): def prepare_params(self, **kwargs):
"""Get the openai response"""
if "tools_pydantic" in kwargs: if "tools_pydantic" in kwargs:
kwargs.pop("tools_pydantic") kwargs.pop("tools_pydantic")
@ -313,8 +318,17 @@ class ChatOpenAI(BaseChatOpenAI):
params = {k: v for k, v in params_.items() if v is not None} params = {k: v for k, v in params_.items() if v is not None}
params.update(kwargs) params.update(kwargs)
return params
def openai_response(self, client, **kwargs):
"""Get the openai response"""
params = self.prepare_params(**kwargs)
return client.chat.completions.create(**params) return client.chat.completions.create(**params)
async def aopenai_response(self, client, **kwargs):
params = self.prepare_params(**kwargs)
return await client.chat.completions.create(**params)
class AzureChatOpenAI(BaseChatOpenAI): class AzureChatOpenAI(BaseChatOpenAI):
"""OpenAI chat model provided by Microsoft Azure""" """OpenAI chat model provided by Microsoft Azure"""
@ -361,8 +375,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
return AzureOpenAI(**params) return AzureOpenAI(**params)
def openai_response(self, client, **kwargs): def prepare_params(self, **kwargs):
"""Get the openai response"""
if "tools_pydantic" in kwargs: if "tools_pydantic" in kwargs:
kwargs.pop("tools_pydantic") kwargs.pop("tools_pydantic")
@ -384,4 +397,13 @@ class AzureChatOpenAI(BaseChatOpenAI):
params = {k: v for k, v in params_.items() if v is not None} params = {k: v for k, v in params_.items() if v is not None}
params.update(kwargs) params.update(kwargs)
return params
def openai_response(self, client, **kwargs):
"""Get the openai response"""
params = self.prepare_params(**kwargs)
return client.chat.completions.create(**params) return client.chat.completions.create(**params)
async def aopenai_response(self, client, **kwargs):
params = self.prepare_params(**kwargs)
return await client.chat.completions.create(**params)

View File

@ -26,7 +26,7 @@ class ChatPanel(BasePage):
scale=20, scale=20,
file_count="multiple", file_count="multiple",
placeholder=( placeholder=(
"Type a message, or search the @web, " "tag a file with @filename" "Type a message, search the @web, or tag a file with @filename"
), ),
container=False, container=False,
show_label=False, show_label=False,