Provide dedicated page for login (#153)
This commit is contained in:
parent
9725d60791
commit
4f356f7f9a
|
@ -30,6 +30,8 @@ class BaseApp:
|
||||||
- Register events
|
- Register events
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
public_events: list[str] = []
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.dev_mode = getattr(settings, "KH_MODE", "") == "dev"
|
self.dev_mode = getattr(settings, "KH_MODE", "") == "dev"
|
||||||
self.f_user_management = getattr(settings, "KH_FEATURE_USER_MANAGEMENT", False)
|
self.f_user_management = getattr(settings, "KH_FEATURE_USER_MANAGEMENT", False)
|
||||||
|
@ -145,6 +147,15 @@ class BaseApp:
|
||||||
def ui(self):
|
def ui(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def on_subscribe_public_events(self):
|
||||||
|
"""Subscribe to the declared public event of the app"""
|
||||||
|
|
||||||
|
def on_register_events(self):
|
||||||
|
"""Register all events to the app"""
|
||||||
|
|
||||||
|
def _on_app_created(self):
|
||||||
|
"""Called when the app is created"""
|
||||||
|
|
||||||
def make(self):
|
def make(self):
|
||||||
with gr.Blocks(
|
with gr.Blocks(
|
||||||
theme=self._theme,
|
theme=self._theme,
|
||||||
|
@ -158,26 +169,44 @@ class BaseApp:
|
||||||
|
|
||||||
self.ui()
|
self.ui()
|
||||||
|
|
||||||
for value in self.__dict__.values():
|
self.declare_public_events()
|
||||||
if isinstance(value, BasePage):
|
self.subscribe_public_events()
|
||||||
value.declare_public_events()
|
self.register_events()
|
||||||
|
self.on_app_created()
|
||||||
for value in self.__dict__.values():
|
|
||||||
if isinstance(value, BasePage):
|
|
||||||
value.subscribe_public_events()
|
|
||||||
|
|
||||||
for value in self.__dict__.values():
|
|
||||||
if isinstance(value, BasePage):
|
|
||||||
value.register_events()
|
|
||||||
|
|
||||||
for value in self.__dict__.values():
|
|
||||||
if isinstance(value, BasePage):
|
|
||||||
value.on_app_created()
|
|
||||||
|
|
||||||
demo.load(lambda: None, None, None, js=f"() => {{{self._js}}}")
|
|
||||||
|
|
||||||
return demo
|
return demo
|
||||||
|
|
||||||
|
def declare_public_events(self):
|
||||||
|
"""Declare an event for the app"""
|
||||||
|
for event in self.public_events:
|
||||||
|
self.declare_event(event)
|
||||||
|
|
||||||
|
for value in self.__dict__.values():
|
||||||
|
if isinstance(value, BasePage):
|
||||||
|
value.declare_public_events()
|
||||||
|
|
||||||
|
def subscribe_public_events(self):
|
||||||
|
"""Subscribe to an event"""
|
||||||
|
self.on_subscribe_public_events()
|
||||||
|
for value in self.__dict__.values():
|
||||||
|
if isinstance(value, BasePage):
|
||||||
|
value.subscribe_public_events()
|
||||||
|
|
||||||
|
def register_events(self):
|
||||||
|
"""Register all events"""
|
||||||
|
self.on_register_events()
|
||||||
|
for value in self.__dict__.values():
|
||||||
|
if isinstance(value, BasePage):
|
||||||
|
value.register_events()
|
||||||
|
|
||||||
|
def on_app_created(self):
|
||||||
|
"""Execute on app created callbacks"""
|
||||||
|
self.app.load(lambda: None, None, None, js=f"() => {{{self._js}}}")
|
||||||
|
self._on_app_created()
|
||||||
|
for value in self.__dict__.values():
|
||||||
|
if isinstance(value, BasePage):
|
||||||
|
value.on_app_created()
|
||||||
|
|
||||||
|
|
||||||
class BasePage:
|
class BasePage:
|
||||||
"""The logic of the Kotaemon app"""
|
"""The logic of the Kotaemon app"""
|
||||||
|
|
|
@ -28,7 +28,7 @@ footer {
|
||||||
border: none !important;
|
border: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
#chat-tab, #settings-tab, #help-tab, #admin-tab {
|
#chat-tab, #settings-tab, #help-tab, #admin-tab, #login-tab {
|
||||||
border: none !important;
|
border: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,3 +17,14 @@ globalThis.clpseFn = (id) => {
|
||||||
content.style.display = "none";
|
content.style.display = "none";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// store info in local storage
|
||||||
|
globalThis.setStorage = (key, value) => {
|
||||||
|
localStorage.setItem(key, JSON.stringify(value))
|
||||||
|
}
|
||||||
|
globalThis.getStorage = (key, value) => {
|
||||||
|
return JSON.parse(localStorage.getItem(key))
|
||||||
|
}
|
||||||
|
globalThis.removeFromStorage = (key) => {
|
||||||
|
localStorage.removeItem(key)
|
||||||
|
}
|
||||||
|
|
|
@ -22,7 +22,17 @@ class App(BaseApp):
|
||||||
|
|
||||||
def ui(self):
|
def ui(self):
|
||||||
"""Render the UI"""
|
"""Render the UI"""
|
||||||
with gr.Tab("Chat", elem_id="chat-tab"):
|
self._tabs = {}
|
||||||
|
|
||||||
|
if self.f_user_management:
|
||||||
|
from ktem.pages.login import LoginPage
|
||||||
|
|
||||||
|
with gr.Tab("Login", elem_id="login-tab") as self._tabs["login-tab"]:
|
||||||
|
self.login_page = LoginPage(self)
|
||||||
|
|
||||||
|
with gr.Tab(
|
||||||
|
"Chat", elem_id="chat-tab", visible=not self.f_user_management
|
||||||
|
) as self._tabs["chat-tab"]:
|
||||||
self.chat_page = ChatPage(self)
|
self.chat_page = ChatPage(self)
|
||||||
|
|
||||||
for index in self.index_manager.indices:
|
for index in self.index_manager.indices:
|
||||||
|
@ -30,15 +40,64 @@ class App(BaseApp):
|
||||||
f"{index.name} Index",
|
f"{index.name} Index",
|
||||||
elem_id=f"{index.id}-tab",
|
elem_id=f"{index.id}-tab",
|
||||||
elem_classes="indices-tab",
|
elem_classes="indices-tab",
|
||||||
):
|
visible=not self.f_user_management,
|
||||||
|
) as self._tabs[f"{index.id}-tab"]:
|
||||||
page = index.get_index_page_ui()
|
page = index.get_index_page_ui()
|
||||||
setattr(self, f"_index_{index.id}", page)
|
setattr(self, f"_index_{index.id}", page)
|
||||||
|
|
||||||
with gr.Tab("Admin", elem_id="admin-tab"):
|
with gr.Tab(
|
||||||
|
"Admin", elem_id="admin-tab", visible=not self.f_user_management
|
||||||
|
) as self._tabs["admin-tab"]:
|
||||||
self.admin_page = AdminPage(self)
|
self.admin_page = AdminPage(self)
|
||||||
|
|
||||||
with gr.Tab("Settings", elem_id="settings-tab"):
|
with gr.Tab(
|
||||||
|
"Settings", elem_id="settings-tab", visible=not self.f_user_management
|
||||||
|
) as self._tabs["settings-tab"]:
|
||||||
self.settings_page = SettingsPage(self)
|
self.settings_page = SettingsPage(self)
|
||||||
|
|
||||||
with gr.Tab("Help", elem_id="help-tab"):
|
with gr.Tab(
|
||||||
|
"Help", elem_id="help-tab", visible=not self.f_user_management
|
||||||
|
) as self._tabs["help-tab"]:
|
||||||
self.help_page = HelpPage(self)
|
self.help_page = HelpPage(self)
|
||||||
|
|
||||||
|
def on_subscribe_public_events(self):
|
||||||
|
if self.f_user_management:
|
||||||
|
|
||||||
|
def signed_in_out(user_id):
|
||||||
|
if not user_id:
|
||||||
|
return list(
|
||||||
|
(
|
||||||
|
gr.update(visible=True)
|
||||||
|
if k == "login-tab"
|
||||||
|
else gr.update(visible=False)
|
||||||
|
)
|
||||||
|
for k in self._tabs.keys()
|
||||||
|
)
|
||||||
|
return list(
|
||||||
|
(
|
||||||
|
gr.update(visible=True)
|
||||||
|
if k != "login-tab"
|
||||||
|
else gr.update(visible=False)
|
||||||
|
)
|
||||||
|
for k in self._tabs.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.subscribe_event(
|
||||||
|
name="onSignIn",
|
||||||
|
definition={
|
||||||
|
"fn": signed_in_out,
|
||||||
|
"inputs": [self.user_id],
|
||||||
|
"outputs": list(self._tabs.values()),
|
||||||
|
"show_progress": "hidden",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.subscribe_event(
|
||||||
|
name="onSignOut",
|
||||||
|
definition={
|
||||||
|
"fn": signed_in_out,
|
||||||
|
"inputs": [self.user_id],
|
||||||
|
"outputs": list(self._tabs.values()),
|
||||||
|
"show_progress": "hidden",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
72
libs/ktem/ktem/pages/login.py
Normal file
72
libs/ktem/ktem/pages/login.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from ktem.app import BasePage
|
||||||
|
from ktem.db.models import User, engine
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
fetch_creds = """
|
||||||
|
function() {
|
||||||
|
const username = getStorage('username')
|
||||||
|
const password = getStorage('password')
|
||||||
|
return [username, password];
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
signin_js = """
|
||||||
|
function(usn, pwd) {
|
||||||
|
setStorage('username', usn);
|
||||||
|
setStorage('password', pwd);
|
||||||
|
return [usn, pwd];
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class LoginPage(BasePage):
|
||||||
|
|
||||||
|
public_events = ["onSignIn"]
|
||||||
|
|
||||||
|
def __init__(self, app):
|
||||||
|
self._app = app
|
||||||
|
self.on_building_ui()
|
||||||
|
|
||||||
|
def on_building_ui(self):
|
||||||
|
gr.Markdown("Welcome to Kotaemon")
|
||||||
|
self.usn = gr.Textbox(label="Username")
|
||||||
|
self.pwd = gr.Textbox(label="Password", type="password")
|
||||||
|
self.btn_login = gr.Button("Login")
|
||||||
|
self._dummy = gr.State()
|
||||||
|
|
||||||
|
def on_register_events(self):
|
||||||
|
onSignIn = gr.on(
|
||||||
|
triggers=[self.btn_login.click, self.pwd.submit],
|
||||||
|
fn=self.login,
|
||||||
|
inputs=[self.usn, self.pwd],
|
||||||
|
outputs=[self._app.user_id, self.usn, self.pwd],
|
||||||
|
show_progress="hidden",
|
||||||
|
js=signin_js,
|
||||||
|
)
|
||||||
|
for event in self._app.get_event("onSignIn"):
|
||||||
|
onSignIn = onSignIn.success(**event)
|
||||||
|
|
||||||
|
def _on_app_created(self):
|
||||||
|
self._app.app.load(
|
||||||
|
None,
|
||||||
|
inputs=None,
|
||||||
|
outputs=[self.usn, self.pwd],
|
||||||
|
js=fetch_creds,
|
||||||
|
)
|
||||||
|
|
||||||
|
def login(self, usn, pwd):
|
||||||
|
|
||||||
|
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
|
||||||
|
with Session(engine) as session:
|
||||||
|
stmt = select(User).where(
|
||||||
|
User.username_lower == usn.lower(), User.password == hashed_password
|
||||||
|
)
|
||||||
|
result = session.exec(stmt).all()
|
||||||
|
if result:
|
||||||
|
return result[0].id, "", ""
|
||||||
|
|
||||||
|
gr.Warning("Invalid username or password")
|
||||||
|
return None, usn, pwd
|
|
@ -5,6 +5,14 @@ from ktem.app import BasePage
|
||||||
from ktem.db.models import Settings, User, engine
|
from ktem.db.models import Settings, User, engine
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
signout_js = """
|
||||||
|
function() {
|
||||||
|
removeFromStorage('username');
|
||||||
|
removeFromStorage('password');
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
gr_cls_single_value = {
|
gr_cls_single_value = {
|
||||||
"text": gr.Textbox,
|
"text": gr.Textbox,
|
||||||
"number": gr.Number,
|
"number": gr.Number,
|
||||||
|
@ -49,7 +57,7 @@ class SettingsPage(BasePage):
|
||||||
name of the setting in the `app.default_settings`
|
name of the setting in the `app.default_settings`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
public_events = ["onSignIn", "onSignOut"]
|
public_events = ["onSignOut"]
|
||||||
|
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
"""Initiate the page and render the UI"""
|
"""Initiate the page and render the UI"""
|
||||||
|
@ -79,7 +87,36 @@ class SettingsPage(BasePage):
|
||||||
self.reasoning_tab()
|
self.reasoning_tab()
|
||||||
|
|
||||||
def on_subscribe_public_events(self):
|
def on_subscribe_public_events(self):
|
||||||
pass
|
if self._app.f_user_management:
|
||||||
|
self._app.subscribe_event(
|
||||||
|
name="onSignIn",
|
||||||
|
definition={
|
||||||
|
"fn": self.load_setting,
|
||||||
|
"inputs": self._user_id,
|
||||||
|
"outputs": [self._settings_state] + self.components(),
|
||||||
|
"show_progress": "hidden",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_name(user_id):
|
||||||
|
name = "Current user: "
|
||||||
|
if user_id:
|
||||||
|
with Session(engine) as session:
|
||||||
|
statement = select(User).where(User.id == user_id)
|
||||||
|
result = session.exec(statement).all()
|
||||||
|
if result:
|
||||||
|
return name + result[0].username
|
||||||
|
return name + "___"
|
||||||
|
|
||||||
|
self._app.subscribe_event(
|
||||||
|
name="onSignIn",
|
||||||
|
definition={
|
||||||
|
"fn": get_name,
|
||||||
|
"inputs": self._user_id,
|
||||||
|
"outputs": [self.current_name],
|
||||||
|
"show_progress": "hidden",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def on_register_events(self):
|
def on_register_events(self):
|
||||||
self.setting_save_btn.click(
|
self.setting_save_btn.click(
|
||||||
|
@ -101,49 +138,15 @@ class SettingsPage(BasePage):
|
||||||
self.password_change,
|
self.password_change,
|
||||||
self.password_change_confirm,
|
self.password_change_confirm,
|
||||||
],
|
],
|
||||||
outputs=None,
|
outputs=[self.password_change, self.password_change_confirm],
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
)
|
)
|
||||||
|
|
||||||
onSignInClick = self.signin.click(
|
|
||||||
self.sign_in,
|
|
||||||
inputs=[self.username, self.password],
|
|
||||||
outputs=[self._user_id, self.username, self.password]
|
|
||||||
+ self.signed_in_state()
|
|
||||||
+ [self.user_out_state],
|
|
||||||
show_progress="hidden",
|
|
||||||
).then(
|
|
||||||
self.load_setting,
|
|
||||||
inputs=self._user_id,
|
|
||||||
outputs=[self._settings_state] + self.components(),
|
|
||||||
show_progress="hidden",
|
|
||||||
)
|
|
||||||
for event in self._app.get_event("onSignIn"):
|
|
||||||
onSignInClick = onSignInClick.then(**event)
|
|
||||||
|
|
||||||
onSignInSubmit = self.password.submit(
|
|
||||||
self.sign_in,
|
|
||||||
inputs=[self.username, self.password],
|
|
||||||
outputs=[self._user_id, self.username, self.password]
|
|
||||||
+ self.signed_in_state()
|
|
||||||
+ [self.user_out_state],
|
|
||||||
show_progress="hidden",
|
|
||||||
).then(
|
|
||||||
self.load_setting,
|
|
||||||
inputs=self._user_id,
|
|
||||||
outputs=[self._settings_state] + self.components(),
|
|
||||||
show_progress="hidden",
|
|
||||||
)
|
|
||||||
for event in self._app.get_event("onSignIn"):
|
|
||||||
onSignInSubmit = onSignInSubmit.then(**event)
|
|
||||||
|
|
||||||
onSignOutClick = self.signout.click(
|
onSignOutClick = self.signout.click(
|
||||||
self.sign_out,
|
lambda: (None, "Current user: ___"),
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=[self._user_id]
|
outputs=[self._user_id, self.current_name],
|
||||||
+ self.signed_in_state()
|
|
||||||
+ [self.user_out_state],
|
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
|
js=signout_js,
|
||||||
).then(
|
).then(
|
||||||
self.load_setting,
|
self.load_setting,
|
||||||
inputs=self._user_id,
|
inputs=self._user_id,
|
||||||
|
@ -154,77 +157,22 @@ class SettingsPage(BasePage):
|
||||||
onSignOutClick = onSignOutClick.then(**event)
|
onSignOutClick = onSignOutClick.then(**event)
|
||||||
|
|
||||||
def user_tab(self):
|
def user_tab(self):
|
||||||
with gr.Column() as self.user_out_state:
|
|
||||||
gr.Markdown("Sign in")
|
|
||||||
self.username = gr.Textbox(label="Username", interactive=True)
|
|
||||||
self.password = gr.Textbox(
|
|
||||||
label="Password", type="password", interactive=True
|
|
||||||
)
|
|
||||||
self.signin = gr.Button("Login")
|
|
||||||
|
|
||||||
# user management
|
# user management
|
||||||
self.current_name = gr.Markdown("Current user: ___", visible=False)
|
self.current_name = gr.Markdown("Current user: ___")
|
||||||
self.signout = gr.Button("Logout", visible=False)
|
self.signout = gr.Button("Logout")
|
||||||
|
|
||||||
self.password_change = gr.Textbox(
|
self.password_change = gr.Textbox(
|
||||||
label="New password", interactive=True, type="password", visible=False
|
label="New password", interactive=True, type="password"
|
||||||
)
|
)
|
||||||
self.password_change_confirm = gr.Textbox(
|
self.password_change_confirm = gr.Textbox(
|
||||||
label="Confirm password", interactive=True, type="password", visible=False
|
label="Confirm password", interactive=True, type="password"
|
||||||
)
|
)
|
||||||
self.password_change_btn = gr.Button(
|
self.password_change_btn = gr.Button("Change password", interactive=True)
|
||||||
"Change password", interactive=True, visible=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def signed_in_state(self):
|
|
||||||
return [
|
|
||||||
self.current_name, # always the first one
|
|
||||||
self.signout,
|
|
||||||
self.password_change,
|
|
||||||
self.password_change_confirm,
|
|
||||||
self.password_change_btn,
|
|
||||||
]
|
|
||||||
|
|
||||||
def sign_in(self, username: str, password: str):
|
|
||||||
hashed_password = hashlib.sha256(password.encode()).hexdigest()
|
|
||||||
user_id, clear_username, clear_password = None, username, password
|
|
||||||
with Session(engine) as session:
|
|
||||||
statement = select(User).where(
|
|
||||||
User.username_lower == username.lower(),
|
|
||||||
User.password == hashed_password,
|
|
||||||
)
|
|
||||||
result = session.exec(statement).all()
|
|
||||||
if result:
|
|
||||||
user_id = result[0].id
|
|
||||||
clear_username, clear_password = "", ""
|
|
||||||
else:
|
|
||||||
gr.Warning("Username or password is incorrect")
|
|
||||||
|
|
||||||
output: list = [user_id, clear_username, clear_password]
|
|
||||||
if user_id is None:
|
|
||||||
output += [
|
|
||||||
gr.update(visible=False) for _ in range(len(self.signed_in_state()))
|
|
||||||
]
|
|
||||||
output.append(gr.update(visible=True))
|
|
||||||
else:
|
|
||||||
output.append(gr.update(visible=True, value=f"Current user: {username}"))
|
|
||||||
output += [
|
|
||||||
gr.update(visible=True) for _ in range(len(self.signed_in_state()) - 1)
|
|
||||||
]
|
|
||||||
output.append(gr.update(visible=False))
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def sign_out(self):
|
|
||||||
output = [None]
|
|
||||||
output += [gr.update(visible=False) for _ in range(len(self.signed_in_state()))]
|
|
||||||
output.append(gr.update(visible=True))
|
|
||||||
return output
|
|
||||||
|
|
||||||
def change_password(self, user_id, password, password_confirm):
|
def change_password(self, user_id, password, password_confirm):
|
||||||
if password != password_confirm:
|
if password != password_confirm:
|
||||||
gr.Warning("Password does not match")
|
gr.Warning("Password does not match")
|
||||||
return
|
return password, password_confirm
|
||||||
|
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
statement = select(User).where(User.id == user_id)
|
statement = select(User).where(User.id == user_id)
|
||||||
|
@ -239,6 +187,8 @@ class SettingsPage(BasePage):
|
||||||
else:
|
else:
|
||||||
gr.Warning("User not found")
|
gr.Warning("User not found")
|
||||||
|
|
||||||
|
return "", ""
|
||||||
|
|
||||||
def app_tab(self):
|
def app_tab(self):
|
||||||
for n, si in self._default_settings.application.settings.items():
|
for n, si in self._default_settings.application.settings.items():
|
||||||
obj = render_setting_item(si, si.value)
|
obj = render_setting_item(si, si.value)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user