kotaemon/libs/ktem/ktem/llms/manager.py
Duc Nguyen (john) a8725710af
Allow users to select reasoning pipeline. Fix small issues with user UI, cohere name (#50)
* Fix user page

* Allow changing LLM in reasoning pipeline

* Fix CohereEmbedding name
2024-04-25 17:18:12 +07:00

211 lines
6.3 KiB
Python

from typing import Optional, Type, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize
from kotaemon.llms import ChatLLM
from .db import LLMTable, engine
class LLMManager:
"""Represent a pool of models"""
def __init__(self):
self._models: dict[str, ChatLLM] = {}
self._info: dict[str, dict] = {}
self._default: str = ""
self._vendors: list[Type] = []
if hasattr(flowsettings, "KH_LLMS"):
for name, model in flowsettings.KH_LLMS.items():
with Session(engine) as session:
stmt = select(LLMTable).where(LLMTable.name == name)
result = session.execute(stmt)
if not result.first():
item = LLMTable(
name=name,
spec=model["spec"],
default=model.get("default", False),
)
session.add(item)
session.commit()
self.load()
self.load_vendors()
def load(self):
"""Load the model pool from database"""
self._models, self._info, self._defaut = {}, {}, ""
with Session(engine) as session:
stmt = select(LLMTable)
items = session.execute(stmt)
for (item,) in items:
self._models[item.name] = deserialize(item.spec, safe=False)
self._info[item.name] = {
"name": item.name,
"spec": item.spec,
"default": item.default,
}
if item.default:
self._default = item.name
def load_vendors(self):
from kotaemon.llms import (
AzureChatOpenAI,
ChatOpenAI,
EndpointChatLLM,
LlamaCppChat,
)
self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]
def __getitem__(self, key: str) -> ChatLLM:
"""Get model by name"""
return self._models[key]
def __contains__(self, key: str) -> bool:
"""Check if model exists"""
return key in self._models
@overload
def get(self, key: str, default: None) -> Optional[ChatLLM]:
...
@overload
def get(self, key: str, default: ChatLLM) -> ChatLLM:
...
def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
"""Get model by name with default value"""
return self._models.get(key, default)
def settings(self) -> dict:
"""Present model pools option for gradio"""
return {
"label": "LLM",
"choices": list(self._models.keys()),
"value": self.get_default_name(),
}
def options(self) -> dict:
"""Present a dict of models"""
return self._models
def get_random_name(self) -> str:
"""Get the name of random model
Returns:
str: random model name in the pool
"""
import random
if not self._models:
raise ValueError("No models in pool")
return random.choice(list(self._models.keys()))
def get_default_name(self) -> str:
"""Get the name of default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
str: model name
"""
if not self._models:
raise ValueError("No models in pool")
if not self._default:
return self.get_random_name()
return self._default
def get_random(self) -> ChatLLM:
"""Get random model"""
return self._models[self.get_random_name()]
def get_default(self) -> ChatLLM:
"""Get default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
ChatLLM: model
"""
return self._models[self.get_default_name()]
def info(self) -> dict:
"""List all models"""
return self._info
def add(self, name: str, spec: dict, default: bool):
"""Add a new model to the pool"""
name = name.strip()
if not name:
raise ValueError("Name must not be empty")
try:
with Session(engine) as session:
if default:
# turn all models to non-default
session.query(LLMTable).update({"default": False})
session.commit()
item = LLMTable(name=name, spec=spec, default=default)
session.add(item)
session.commit()
except Exception as e:
raise ValueError(f"Failed to add model {name}: {e}")
self.load()
def delete(self, name: str):
"""Delete a model from the pool"""
try:
with Session(engine) as session:
item = session.query(LLMTable).filter_by(name=name).first()
session.delete(item)
session.commit()
except Exception as e:
raise ValueError(f"Failed to delete model {name}: {e}")
self.load()
def update(self, name: str, spec: dict, default: bool):
"""Update a model in the pool"""
if not name:
raise ValueError("Name must not be empty")
try:
with Session(engine) as session:
if default:
# turn all models to non-default
session.query(LLMTable).update({"default": False})
session.commit()
item = session.query(LLMTable).filter_by(name=name).first()
if not item:
raise ValueError(f"Model {name} not found")
item.spec = spec
item.default = default
session.commit()
except Exception as e:
raise ValueError(f"Failed to update model {name}: {e}")
self.load()
def vendors(self) -> dict:
"""Return list of vendors"""
return {vendor.__qualname__: vendor for vendor in self._vendors}
llms = LLMManager()