182 lines
5.9 KiB
Python
182 lines
5.9 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, AsyncGenerator
|
|
|
|
import anyio
|
|
from gradio import ChatInterface
|
|
from gradio.components import Component, get_component_instance
|
|
from gradio.events import on
|
|
from gradio.helpers import special_args
|
|
from gradio.routes import Request
|
|
|
|
|
|
class ChatBlock(ChatInterface):
    """ChatInterface subclass with two extra features.

    - Additional outputs: when ``additional_outputs`` is given, ``fn`` must
      return a sequence ``(reply_text, *extra_values)``. The reply goes to the
      chatbot and each extra value goes to the matching additional output
      component.
    - Blank user messages are ignored: an empty message is not displayed,
      ``fn`` is not called, and the chat history and additional outputs are
      left unchanged.

    Generator (streaming) functions are not supported: submitting with one
    raises ``NotImplementedError``.
    """

    def __init__(
        self,
        *args,
        additional_outputs: str | Component | list[str | Component] | None = None,
        **kwargs,
    ):
        """Create the chat block.

        Args:
            *args: Positional arguments forwarded to ``ChatInterface``.
            additional_outputs: A component, a value accepted by
                ``get_component_instance``, or a list of these. Any falsy
                value means "no additional outputs".
            **kwargs: Keyword arguments forwarded to ``ChatInterface``.
        """
        if not additional_outputs:
            outputs = []
        elif isinstance(additional_outputs, list):
            outputs = additional_outputs
        else:
            outputs = [additional_outputs]
        # Assigned before super().__init__() because the overridden
        # _setup_events/_setup_api below read it. This assumes
        # ChatInterface.__init__ invokes them (verify against the installed
        # gradio version).
        self.additional_outputs = [
            get_component_instance(o) for o in outputs  # type: ignore
        ]

        super().__init__(*args, **kwargs)

    def _split_extra_args(
        self, args: tuple[Any, ...]
    ) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
        """Split the trailing arguments received by the submit handler.

        The submit event passes the values of ``additional_inputs`` followed
        by the current values of ``additional_outputs``.

        Returns:
            ``(input_values, current_output_values)``.
        """
        # Split from the front. Slicing with ``-len(self.additional_outputs)``
        # breaks when there are no additional outputs: ``args[:-0]`` is empty
        # (dropping every additional input) and ``args[-0:]`` is all of args.
        split = len(args) - len(self.additional_outputs)
        return args[:split], args[split:]

    async def _call_chat_fn(self, inputs: list[Any]) -> Any:
        """Invoke ``self.fn`` with ``inputs``.

        Async functions are awaited directly. Sync functions run in a worker
        thread bounded by ``self.limiter``.
        """
        if self.is_async:
            return await self.fn(*inputs)
        return await anyio.to_thread.run_sync(
            self.fn, *inputs, limiter=self.limiter
        )

    def _split_chat_response(self, response: Any) -> tuple[Any, tuple[Any, ...]]:
        """Split ``fn``'s return value into ``(reply_text, extra_values)``.

        Without additional outputs the whole response is the reply. With
        them, the response is indexed as a sequence: item 0 is the reply and
        the remaining items feed the additional outputs in order.
        """
        if not self.additional_outputs:
            return response, ()
        return response[0], tuple(response[1:])

    async def _submit_fn(
        self,
        message: str,
        history_with_input: list[list[str | None]],
        request: Request,
        *args,
    ) -> tuple[Any, ...]:
        """Submit handler for the UI chat flow.

        Args:
            message: The saved user message.
            history_with_input: The chat history as produced by the display
                step. For a non-blank message its last entry is presumably
                the pending user turn added by ``_display_input``.
            request: The gradio request, forwarded to ``fn`` via
                ``special_args``.
            *args: Additional input values followed by the current values of
                the additional output components.

        Returns:
            ``(history, history, *extra_values)`` for the chatbot, the chat
            state, and the additional outputs. For a blank message the
            incoming history and current output values are returned unchanged.
        """
        input_values, current_outputs = self._split_extra_args(args)

        if not message:
            # Blank message: echo everything back so nothing changes.
            return history_with_input, history_with_input, *current_outputs

        # Drop the last entry (the pending turn) before calling fn.
        history = history_with_input[:-1]
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *input_values], request=request
        )
        response = await self._call_chat_fn(inputs)

        text, extra_values = self._split_chat_response(response)
        history.append([message, text])
        return history, history, *extra_values

    async def _api_submit_fn(
        self,
        message: str,
        history: list[list[str | None]],
        request: Request,
        *args,
    ) -> tuple[Any, ...]:
        """Handler for the ``chat`` API endpoint registered in ``_setup_api``.

        Overridden so the returned values line up with the outputs registered
        there (textbox, chatbot state, then the additional outputs). ``fn``'s
        response is split into the reply text and the extra values.

        Args:
            message: The user message.
            history: The chat history. It is appended to in place.
            request: The gradio request, forwarded to ``fn`` via
                ``special_args``.
            *args: Values of the additional inputs.

        Returns:
            ``(reply_text, history, *extra_values)``.
        """
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )
        response = await self._call_chat_fn(inputs)

        text, extra_values = self._split_chat_response(response)
        history.append([message, text])
        return text, history, *extra_values

    async def _stream_fn(
        self,
        message: str,
        history_with_input: list[list[str | None]],
        *args,
    ) -> AsyncGenerator:
        """Streaming is not supported by ChatBlock; always raises."""
        raise NotImplementedError("Stream function not implemented for ChatBlock")

    def _display_input(
        self, message: str, history: list[list[str | None]]
    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
        """Stop displaying the input message if the message is a blank string"""
        if not message:
            return history, history
        return super()._display_input(message, history)

    def _setup_events(self) -> None:
        """Wire submit/retry/undo/clear events, including the additional outputs.

        The additional output components are passed to the submit handler both
        as trailing inputs and as trailing outputs. As inputs, their current
        values let a blank message echo them back unchanged.
        """
        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
        submit_inputs = (
            [self.saved_input, self.chatbot_state]
            + self.additional_inputs
            + self.additional_outputs
        )
        submit_outputs = [self.chatbot, self.chatbot_state] + self.additional_outputs

        submit_triggers = (
            [self.textbox.submit, self.submit_btn.click]
            if self.submit_btn
            else [self.textbox.submit]
        )
        submit_event = (
            on(
                submit_triggers,
                self._clear_and_save_textbox,
                [self.textbox],
                [self.textbox, self.saved_input],
                api_name=False,
                queue=False,
            )
            .then(
                self._display_input,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.chatbot_state],
                api_name=False,
                queue=False,
            )
            .then(
                submit_fn,
                list(submit_inputs),
                list(submit_outputs),
                api_name=False,
            )
        )
        self._setup_stop_events(submit_triggers, submit_event)

        if self.retry_btn:
            retry_event = (
                self.retry_btn.click(
                    self._delete_prev_fn,
                    [self.chatbot_state],
                    [self.chatbot, self.saved_input, self.chatbot_state],
                    api_name=False,
                    queue=False,
                )
                .then(
                    self._display_input,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.chatbot_state],
                    api_name=False,
                    queue=False,
                )
                .then(
                    submit_fn,
                    list(submit_inputs),
                    list(submit_outputs),
                    api_name=False,
                )
            )
            self._setup_stop_events([self.retry_btn.click], retry_event)

        if self.undo_btn:
            self.undo_btn.click(
                self._delete_prev_fn,
                [self.chatbot_state],
                [self.chatbot, self.saved_input, self.chatbot_state],
                api_name=False,
                queue=False,
            ).then(
                lambda x: x,
                [self.saved_input],
                [self.textbox],
                api_name=False,
                queue=False,
            )

        if self.clear_btn:
            self.clear_btn.click(
                lambda: ([], [], None),
                None,
                [self.chatbot, self.chatbot_state, self.saved_input],
                queue=False,
                api_name=False,
            )

    def _setup_api(self) -> None:
        """Register the ``chat`` API endpoint.

        The additional outputs are appended to the endpoint's outputs.
        """
        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn

        self.fake_api_btn.click(
            api_fn,
            [self.textbox, self.chatbot_state] + self.additional_inputs,
            [self.textbox, self.chatbot_state] + self.additional_outputs,
            api_name="chat",
        )
|