diff --git a/libs/ktem/ktem/app.py b/libs/ktem/ktem/app.py index 09ef1d6..0a39fa6 100644 --- a/libs/ktem/ktem/app.py +++ b/libs/ktem/ktem/app.py @@ -30,6 +30,8 @@ class BaseApp: - Register events """ + public_events: list[str] = [] + def __init__(self): self.dev_mode = getattr(settings, "KH_MODE", "") == "dev" self.f_user_management = getattr(settings, "KH_FEATURE_USER_MANAGEMENT", False) @@ -145,6 +147,15 @@ class BaseApp: def ui(self): raise NotImplementedError + def on_subscribe_public_events(self): + """Subscribe to the declared public event of the app""" + + def on_register_events(self): + """Register all events to the app""" + + def _on_app_created(self): + """Called when the app is created""" + def make(self): with gr.Blocks( theme=self._theme, @@ -158,26 +169,44 @@ class BaseApp: self.ui() - for value in self.__dict__.values(): - if isinstance(value, BasePage): - value.declare_public_events() - - for value in self.__dict__.values(): - if isinstance(value, BasePage): - value.subscribe_public_events() - - for value in self.__dict__.values(): - if isinstance(value, BasePage): - value.register_events() - - for value in self.__dict__.values(): - if isinstance(value, BasePage): - value.on_app_created() - - demo.load(lambda: None, None, None, js=f"() => {{{self._js}}}") + self.declare_public_events() + self.subscribe_public_events() + self.register_events() + self.on_app_created() return demo + def declare_public_events(self): + """Declare an event for the app""" + for event in self.public_events: + self.declare_event(event) + + for value in self.__dict__.values(): + if isinstance(value, BasePage): + value.declare_public_events() + + def subscribe_public_events(self): + """Subscribe to an event""" + self.on_subscribe_public_events() + for value in self.__dict__.values(): + if isinstance(value, BasePage): + value.subscribe_public_events() + + def register_events(self): + """Register all events""" + self.on_register_events() + for value in self.__dict__.values(): + if isinstance(value, BasePage): + value.register_events() + + def on_app_created(self): + """Execute on app created callbacks""" + self.app.load(lambda: None, None, None, js=f"() => {{{self._js}}}") + self._on_app_created() + for value in self.__dict__.values(): + if isinstance(value, BasePage): + value.on_app_created() + class BasePage: """The logic of the Kotaemon app""" diff --git a/libs/ktem/ktem/assets/css/main.css b/libs/ktem/ktem/assets/css/main.css index 36c3bd2..0253478 100644 --- a/libs/ktem/ktem/assets/css/main.css +++ b/libs/ktem/ktem/assets/css/main.css @@ -28,7 +28,7 @@ footer { border: none !important; } -#chat-tab, #settings-tab, #help-tab, #admin-tab { +#chat-tab, #settings-tab, #help-tab, #admin-tab, #login-tab { border: none !important; } diff --git a/libs/ktem/ktem/assets/js/main.js b/libs/ktem/ktem/assets/js/main.js index 845ba0c..9ce6933 100644 --- a/libs/ktem/ktem/assets/js/main.js +++ b/libs/ktem/ktem/assets/js/main.js @@ -17,3 +17,14 @@ globalThis.clpseFn = (id) => { content.style.display = "none"; } } + +// store info in local storage +globalThis.setStorage = (key, value) => { + localStorage.setItem(key, JSON.stringify(value)) +} +globalThis.getStorage = (key, value) => { + return JSON.parse(localStorage.getItem(key)) +} +globalThis.removeFromStorage = (key) => { + localStorage.removeItem(key) +} diff --git a/libs/ktem/ktem/main.py b/libs/ktem/ktem/main.py index 6d2e3c4..c375ed7 100644 --- a/libs/ktem/ktem/main.py +++ b/libs/ktem/ktem/main.py @@ -22,7 +22,17 @@ class App(BaseApp): def ui(self): """Render the UI""" - with gr.Tab("Chat", elem_id="chat-tab"): + self._tabs = {} + + if self.f_user_management: + from ktem.pages.login import LoginPage + + with gr.Tab("Login", elem_id="login-tab") as self._tabs["login-tab"]: + self.login_page = LoginPage(self) + + with gr.Tab( + "Chat", elem_id="chat-tab", visible=not self.f_user_management + ) as self._tabs["chat-tab"]: self.chat_page = ChatPage(self) for index in self.index_manager.indices: @@ -30,15 +40,64 @@ class App(BaseApp): f"{index.name} Index", elem_id=f"{index.id}-tab", elem_classes="indices-tab", - ): + visible=not self.f_user_management, + ) as self._tabs[f"{index.id}-tab"]: page = index.get_index_page_ui() setattr(self, f"_index_{index.id}", page) - with gr.Tab("Admin", elem_id="admin-tab"): + with gr.Tab( + "Admin", elem_id="admin-tab", visible=not self.f_user_management + ) as self._tabs["admin-tab"]: self.admin_page = AdminPage(self) - with gr.Tab("Settings", elem_id="settings-tab"): + with gr.Tab( + "Settings", elem_id="settings-tab", visible=not self.f_user_management + ) as self._tabs["settings-tab"]: self.settings_page = SettingsPage(self) - with gr.Tab("Help", elem_id="help-tab"): + with gr.Tab( + "Help", elem_id="help-tab", visible=not self.f_user_management + ) as self._tabs["help-tab"]: self.help_page = HelpPage(self) + + def on_subscribe_public_events(self): + if self.f_user_management: + + def signed_in_out(user_id): + if not user_id: + return list( + ( + gr.update(visible=True) + if k == "login-tab" + else gr.update(visible=False) + ) + for k in self._tabs.keys() + ) + return list( + ( + gr.update(visible=True) + if k != "login-tab" + else gr.update(visible=False) + ) + for k in self._tabs.keys() + ) + + self.subscribe_event( + name="onSignIn", + definition={ + "fn": signed_in_out, + "inputs": [self.user_id], + "outputs": list(self._tabs.values()), + "show_progress": "hidden", + }, + ) + + self.subscribe_event( + name="onSignOut", + definition={ + "fn": signed_in_out, + "inputs": [self.user_id], + "outputs": list(self._tabs.values()), + "show_progress": "hidden", + }, + ) diff --git a/libs/ktem/ktem/pages/login.py b/libs/ktem/ktem/pages/login.py new file mode 100644 index 0000000..6fe15d0 --- /dev/null +++ b/libs/ktem/ktem/pages/login.py @@ -0,0 +1,72 @@ +import hashlib + +import gradio as gr +from ktem.app import BasePage +from ktem.db.models import User, engine +from sqlmodel import Session, select + +fetch_creds = """ +function() { + const username = getStorage('username') + const password = getStorage('password') + return [username, password]; +} +""" + +signin_js = """ +function(usn, pwd) { + setStorage('username', usn); + setStorage('password', pwd); + return [usn, pwd]; +} +""" + + +class LoginPage(BasePage): + + public_events = ["onSignIn"] + + def __init__(self, app): + self._app = app + self.on_building_ui() + + def on_building_ui(self): + gr.Markdown("Welcome to Kotaemon") + self.usn = gr.Textbox(label="Username") + self.pwd = gr.Textbox(label="Password", type="password") + self.btn_login = gr.Button("Login") + self._dummy = gr.State() + + def on_register_events(self): + onSignIn = gr.on( + triggers=[self.btn_login.click, self.pwd.submit], + fn=self.login, + inputs=[self.usn, self.pwd], + outputs=[self._app.user_id, self.usn, self.pwd], + show_progress="hidden", + js=signin_js, + ) + for event in self._app.get_event("onSignIn"): + onSignIn = onSignIn.success(**event) + + def _on_app_created(self): + self._app.app.load( + None, + inputs=None, + outputs=[self.usn, self.pwd], + js=fetch_creds, + ) + + def login(self, usn, pwd): + + hashed_password = hashlib.sha256(pwd.encode()).hexdigest() + with Session(engine) as session: + stmt = select(User).where( + User.username_lower == usn.lower(), User.password == hashed_password + ) + result = session.exec(stmt).all() + if result: + return result[0].id, "", "" + + gr.Warning("Invalid username or password") + return None, usn, pwd diff --git a/libs/ktem/ktem/pages/settings.py b/libs/ktem/ktem/pages/settings.py index 84c59a0..c396e33 100644 --- a/libs/ktem/ktem/pages/settings.py +++ b/libs/ktem/ktem/pages/settings.py @@ -5,6 +5,14 @@ from ktem.app import BasePage from ktem.db.models import Settings, User, engine from sqlmodel import Session, select +signout_js = """ +function() { + removeFromStorage('username'); + removeFromStorage('password'); +} +""" + + gr_cls_single_value = { "text": gr.Textbox, "number": gr.Number, @@ -49,7 +57,7 @@ class SettingsPage(BasePage): name of the setting in the `app.default_settings` """ - public_events = ["onSignIn", "onSignOut"] + public_events = ["onSignOut"] def __init__(self, app): """Initiate the page and render the UI""" @@ -79,7 +87,36 @@ class SettingsPage(BasePage): self.reasoning_tab() def on_subscribe_public_events(self): - pass + if self._app.f_user_management: + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": self.load_setting, + "inputs": self._user_id, + "outputs": [self._settings_state] + self.components(), + "show_progress": "hidden", + }, + ) + + def get_name(user_id): + name = "Current user: " + if user_id: + with Session(engine) as session: + statement = select(User).where(User.id == user_id) + result = session.exec(statement).all() + if result: + return name + result[0].username + return name + "___" + + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": get_name, + "inputs": self._user_id, + "outputs": [self.current_name], + "show_progress": "hidden", + }, + ) def on_register_events(self): self.setting_save_btn.click( @@ -101,49 +138,15 @@ class SettingsPage(BasePage): self.password_change, self.password_change_confirm, ], - outputs=None, + outputs=[self.password_change, self.password_change_confirm], show_progress="hidden", ) - - onSignInClick = self.signin.click( - self.sign_in, - inputs=[self.username, self.password], - outputs=[self._user_id, self.username, self.password] - + self.signed_in_state() - + [self.user_out_state], - show_progress="hidden", - ).then( - self.load_setting, - inputs=self._user_id, - outputs=[self._settings_state] + self.components(), - show_progress="hidden", - ) - for event in self._app.get_event("onSignIn"): - onSignInClick = onSignInClick.then(**event) - - onSignInSubmit = self.password.submit( - self.sign_in, - inputs=[self.username, self.password], - outputs=[self._user_id, self.username, self.password] - + self.signed_in_state() - + [self.user_out_state], - show_progress="hidden", - ).then( - self.load_setting, - inputs=self._user_id, - outputs=[self._settings_state] + self.components(), - show_progress="hidden", - ) - for event in self._app.get_event("onSignIn"): - onSignInSubmit = onSignInSubmit.then(**event) - onSignOutClick = self.signout.click( - self.sign_out, + lambda: (None, "Current user: ___"), inputs=None, - outputs=[self._user_id] - + self.signed_in_state() - + [self.user_out_state], + outputs=[self._user_id, self.current_name], show_progress="hidden", + js=signout_js, ).then( self.load_setting, inputs=self._user_id, @@ -154,77 +157,22 @@ class SettingsPage(BasePage): onSignOutClick = onSignOutClick.then(**event) def user_tab(self): - with gr.Column() as self.user_out_state: - gr.Markdown("Sign in") - self.username = gr.Textbox(label="Username", interactive=True) - self.password = gr.Textbox( - label="Password", type="password", interactive=True - ) - self.signin = gr.Button("Login") - # user management - self.current_name = gr.Markdown("Current user: ___", visible=False) - self.signout = gr.Button("Logout", visible=False) + self.current_name = gr.Markdown("Current user: ___") + self.signout = gr.Button("Logout") self.password_change = gr.Textbox( - label="New password", interactive=True, type="password", visible=False + label="New password", interactive=True, type="password" ) self.password_change_confirm = gr.Textbox( - label="Confirm password", interactive=True, type="password", visible=False + label="Confirm password", interactive=True, type="password" ) - self.password_change_btn = gr.Button( - "Change password", interactive=True, visible=False - ) - - def signed_in_state(self): - return [ - self.current_name, # always the first one - self.signout, - self.password_change, - self.password_change_confirm, - self.password_change_btn, - ] - - def sign_in(self, username: str, password: str): - hashed_password = hashlib.sha256(password.encode()).hexdigest() - user_id, clear_username, clear_password = None, username, password - with Session(engine) as session: - statement = select(User).where( - User.username_lower == username.lower(), - User.password == hashed_password, - ) - result = session.exec(statement).all() - if result: - user_id = result[0].id - clear_username, clear_password = "", "" - else: - gr.Warning("Username or password is incorrect") - - output: list = [user_id, clear_username, clear_password] - if user_id is None: - output += [ - gr.update(visible=False) for _ in range(len(self.signed_in_state())) - ] - output.append(gr.update(visible=True)) - else: - output.append(gr.update(visible=True, value=f"Current user: {username}")) - output += [ - gr.update(visible=True) for _ in range(len(self.signed_in_state()) - 1) - ] - output.append(gr.update(visible=False)) - - return output - - def sign_out(self): - output = [None] - output += [gr.update(visible=False) for _ in range(len(self.signed_in_state()))] - output.append(gr.update(visible=True)) - return output + self.password_change_btn = gr.Button("Change password", interactive=True) def change_password(self, user_id, password, password_confirm): if password != password_confirm: gr.Warning("Password does not match") - return + return password, password_confirm with Session(engine) as session: statement = select(User).where(User.id == user_id) @@ -239,6 +187,8 @@ class SettingsPage(BasePage): else: gr.Warning("User not found") + return "", "" + def app_tab(self): for n, si in self._default_settings.application.settings.items(): obj = render_setting_item(si, si.value)