kotaemon/knowledgehub/agents/io/base.py
ian_Cin 797df5a69c refactor agents (#100)
* refactor agents

* minor cosmetic, add terminal ui for cli

* bump to 0.3.4

* Add temporary path

* fix unclosed files in tests

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
2023-12-06 17:06:29 +07:00

256 lines
5.6 KiB
Python

import json
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
from pydantic import Extra
from kotaemon.base import LLMInterface
def check_log():
    """
    Report whether logging is switched on for this process.

    Logging counts as enabled whenever the ``LOG_PATH`` environment variable
    is defined; its value is not inspected, so even an empty string enables it.

    :return: True if ``LOG_PATH`` is present in the environment, else False.
    :rtype: bool
    """
    # os.environ only ever holds str values, so membership is equivalent to
    # "lookup did not return None".
    return "LOG_PATH" in os.environ
class AgentType(Enum):
    """
    Enumerated type for agent types.

    Every member's value is a string identical to its name, so a member can
    be looked up from its string form, e.g. ``AgentType("react")``.
    """

    # Declaration order below is also the iteration order of ``AgentType``.
    openai = "openai"
    openai_multi = "openai_multi"
    openai_tool = "openai_tool"
    self_ask = "self_ask"
    react = "react"
    rewoo = "rewoo"
    vanilla = "vanilla"
class BaseScratchPad:
    """
    Base class for output handlers.

    The logging methods only emit records when logging has been enabled via
    the ``LOG_PATH`` environment variable (see ``check_log``). Independently
    of that switch, ``print`` and non-streaming ``panel_print`` calls record
    their content in ``log``.

    Attributes:
    -----------
    logger : logging.Logger
        The logger object to log messages.
    log : list
        Content recorded by ``print`` and non-streaming ``panel_print``.

    Methods:
    --------
    stop():
        Stop the output. No-op in this base class.
    update_status(output: str, **kwargs):
        Update the status of the output.
    thinking(name: str):
        Log that a process is thinking.
    done(_all=False):
        Log that the process is done.
    stream_print(item: str):
        Not implemented.
    json_print(item: Dict[str, Any]):
        Log a JSON object.
    panel_print(item: Any, title: str = "Output", stream: bool = False):
        Log a panel output.
    clear():
        Not implemented.
    print(content: str, **kwargs):
        Log arbitrary content.
    format_json(json_obj: Any):
        Format a JSON-serializable object as an indented string.
    debug(content: str, **kwargs):
        Log a debug message.
    info(content: str, **kwargs):
        Log an informational message.
    warning(content: str, **kwargs):
        Log a warning message.
    error(content: str, **kwargs):
        Log an error message.
    critical(content: str, **kwargs):
        Log a critical message.
    """

    def __init__(self):
        """
        Initialize the BaseScratchPad object.
        """
        # Use a named logger rather than the ``logging`` module itself: the
        # module-level helpers (``logging.info`` etc.) call
        # ``logging.basicConfig()`` when the root logger has no handlers,
        # silently reconfiguring logging for the whole host application.
        # Records from this logger still propagate to the root handlers.
        self.logger = logging.getLogger(__name__)
        self.log = []

    def stop(self) -> None:
        """
        Stop the output. Does nothing in this base class.
        """

    def update_status(self, output: str, **kwargs) -> None:
        """
        Update the status of the output.

        Args:
            output : str
                The status text to log.
            **kwargs :
                Accepted for interface compatibility; ignored here.
        """
        if check_log():
            self.logger.info(output)

    def thinking(self, name: str) -> None:
        """
        Log that a process is thinking.

        Args:
            name : str
                Name of the process, logged as "<name> is thinking...".
        """
        if check_log():
            self.logger.info("%s is thinking...", name)

    def done(self, _all=False) -> None:
        """
        Log that the process is done.

        Args:
            _all : bool, optional
                Accepted for interface compatibility; ignored here.
        """
        if check_log():
            self.logger.info("Done")

    def stream_print(self, item: str) -> None:
        """
        Stream print. Not implemented; does nothing in this base class.
        """

    def json_print(self, item: Dict[str, Any]) -> None:
        """
        Log a JSON object, serialized with a 2-space indent.

        Args:
            item : Dict[str, Any]
                A JSON-serializable object; ``json.dumps`` raises
                ``TypeError`` otherwise.
        """
        if check_log():
            self.logger.info(json.dumps(item, indent=2))

    def panel_print(self, item: Any, title: str = "Output", stream: bool = False) -> None:
        """
        Log a panel output.

        Args:
            item : Any
                The item to log.
            title : str, optional
                The title of the panel, defaults to "Output". Unused by this
                base implementation.
            stream : bool, optional
                If True, ``item`` is not recorded in ``log``. Defaults to False.
        """
        if not stream:
            self.log.append(item)
        if check_log():
            separator = "-" * 20
            self.logger.info(separator)
            self.logger.info(item)
            self.logger.info(separator)

    def clear(self) -> None:
        """
        Not implemented. Does nothing; in particular ``log`` is not emptied.
        """

    def print(self, content: str, **kwargs) -> None:
        """
        Record ``content`` in ``log`` and log it if logging is enabled.

        Args:
            content : str
                The content to record and log.
            **kwargs :
                Accepted for interface compatibility; ignored here.
        """
        self.log.append(content)
        if check_log():
            self.logger.info(content)

    def format_json(self, json_obj: Any) -> str:
        """
        Format a JSON-serializable object as a 2-space-indented string.

        Args:
            json_obj : Any
                Object to serialize (a ``str`` is serialized as a JSON string,
                not parsed).

        Returns:
            str: The formatted JSON text.
        """
        return json.dumps(json_obj, indent=2)

    def debug(self, content: str, **kwargs) -> None:
        """
        Log a debug message. ``kwargs`` (e.g. ``exc_info``, ``extra``) are
        forwarded to the logger.
        """
        if check_log():
            self.logger.debug(content, **kwargs)

    def info(self, content: str, **kwargs) -> None:
        """
        Log an informational message. ``kwargs`` are forwarded to the logger.
        """
        if check_log():
            self.logger.info(content, **kwargs)

    def warning(self, content: str, **kwargs) -> None:
        """
        Log a warning message. ``kwargs`` are forwarded to the logger.
        """
        if check_log():
            self.logger.warning(content, **kwargs)

    def error(self, content: str, **kwargs) -> None:
        """
        Log an error message. ``kwargs`` are forwarded to the logger.
        """
        if check_log():
            self.logger.error(content, **kwargs)

    def critical(self, content: str, **kwargs) -> None:
        """
        Log a critical message. ``kwargs`` are forwarded to the logger.
        """
        if check_log():
            self.logger.critical(content, **kwargs)
@dataclass
class AgentAction:
    """Agent's action to take.

    All three fields are required (none has a default), and they can be
    passed positionally in the declared order.

    Args:
        tool: The tool to invoke.
        tool_input: The input to the tool, either a plain string or a dict.
        log: The log message.
    """

    tool: str
    tool_input: Union[str, dict]
    log: str
class AgentFinish(NamedTuple):
    """Agent's return value when finishing execution.

    As a ``NamedTuple`` it is immutable and can also be unpacked
    positionally as ``(return_values, log)``.

    Args:
        return_values: The return values of the agent.
        log: The log message.
    """

    return_values: dict
    log: str
class AgentOutput(LLMInterface, extra=Extra.allow):  # type: ignore [call-arg]
    """Output from an agent.

    ``extra=Extra.allow`` is pydantic model configuration that keeps
    undeclared keyword arguments as extra attributes instead of rejecting
    them (assumes ``LLMInterface``, declared elsewhere, is a pydantic model
    — confirm).

    Args:
        text: The text output from the agent.
        type: Defaults to ``"agent"``.
        agent_type: The type of agent.
        status: The status after executing the agent; one of "finished",
            "stopped" or "failed".
        error: The error message if any; defaults to None.
    """

    text: str
    type: str = "agent"
    agent_type: AgentType
    status: Literal["finished", "stopped", "failed"]
    error: Optional[str] = None