diff --git a/flowsettings.py b/flowsettings.py
index 0647b71..0962eef 100644
--- a/flowsettings.py
+++ b/flowsettings.py
@@ -81,6 +81,10 @@ KH_FEATURE_USER_MANAGEMENT_PASSWORD = str(
KH_ENABLE_ALEMBIC = False
KH_DATABASE = f"sqlite:///{KH_USER_DATA_DIR / 'sql.db'}"
KH_FILESTORAGE_PATH = str(KH_USER_DATA_DIR / "files")
+KH_WEB_SEARCH_BACKEND = (
+ "kotaemon.indices.retrievers.tavily_web_search.WebSearch"
+ # "kotaemon.indices.retrievers.jina_web_search.WebSearch"
+)
KH_DOCSTORE = {
# "__type__": "kotaemon.storages.ElasticsearchDocumentStore",
diff --git a/libs/kotaemon/kotaemon/indices/retrievers/__init__.py b/libs/kotaemon/kotaemon/indices/retrievers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py b/libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py
new file mode 100644
index 0000000..48fa60f
--- /dev/null
+++ b/libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py
@@ -0,0 +1,90 @@
+from urllib.parse import quote
+
+import requests
+from decouple import config
+
+from kotaemon.base import BaseComponent, RetrievedDocument
+
+JINA_API_KEY = config("JINA_API_KEY", default="")
+JINA_URL = config("JINA_URL", default="https://r.jina.ai/")
+
+# Jina search endpoint; the percent-encoded query is appended as the path.
+JINA_SEARCH_URL = "https://s.jina.ai/"
+# Seconds to wait for the search call, so a stalled request cannot hang
+# the caller indefinitely.
+JINA_REQUEST_TIMEOUT = 60
+
+
+class WebSearch(BaseComponent):
+    """Web search retriever backed by the Jina search API.
+
+    Each search hit becomes one ``RetrievedDocument`` whose text is a small
+    markdown snippet (URL, title, description, content).
+    """
+
+    def run(
+        self,
+        text: str,
+        *args,
+        **kwargs,
+    ) -> list[RetrievedDocument]:
+        """Search the web for ``text`` and wrap every hit in a document.
+
+        Args:
+            text: the search query.
+
+        Returns:
+            One ``RetrievedDocument`` per item in the response's ``data`` list.
+
+        Raises:
+            ValueError: if ``JINA_API_KEY`` is not configured.
+            requests.HTTPError: if the API responds with an error status.
+        """
+        if not JINA_API_KEY:
+            raise ValueError(
+                "This feature requires JINA_API_KEY "
+                "(get free one from https://jina.ai/reader)"
+            )
+
+        # The query travels in the URL path, so percent-encode it; otherwise
+        # characters such as "?", "#" or "/" would truncate or alter it.
+        api_url = JINA_SEARCH_URL + quote(text, safe="")
+        headers = {
+            "X-With-Generated-Alt": "true",
+            "Accept": "application/json",
+            # the key is known to be non-empty after the check above
+            "Authorization": f"Bearer {JINA_API_KEY}",
+        }
+
+        response = requests.get(
+            api_url, headers=headers, timeout=JINA_REQUEST_TIMEOUT
+        )
+        response.raise_for_status()
+        response_dict = response.json()
+
+        return [
+            RetrievedDocument(
+                text=(
+                    "###URL: [{url}]({url})\n\n"
+                    "####{title}\n\n"
+                    "{description}\n"
+                    "{content}"
+                ).format(
+                    url=item["url"],
+                    # a hit lacking one of these yields "" rather than KeyError
+                    title=item.get("title", ""),
+                    description=item.get("description", ""),
+                    content=item.get("content", ""),
+                ),
+                metadata={
+                    "file_name": "Web search",
+                    "type": "table",
+                    "llm_trulens_score": 1.0,
+                },
+            )
+            for item in response_dict["data"]
+        ]
+
+    def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
+        """Return ``documents`` unchanged; no relevance scoring is applied."""
+        return documents
diff --git a/libs/kotaemon/kotaemon/indices/retrievers/tavily_web_search.py b/libs/kotaemon/kotaemon/indices/retrievers/tavily_web_search.py
new file mode 100644
index 0000000..f6087d5
--- /dev/null
+++ b/libs/kotaemon/kotaemon/indices/retrievers/tavily_web_search.py
@@ -0,0 +1,73 @@
+from decouple import config
+
+from kotaemon.base import BaseComponent, RetrievedDocument
+
+TAVILY_API_KEY = config("TAVILY_API_KEY", default="")
+
+
+class WebSearch(BaseComponent):
+    """Web search retriever backed by the Tavily search API.
+
+    All search hits are merged into a single ``RetrievedDocument``.
+    """
+
+    def run(
+        self,
+        text: str,
+        *args,
+        **kwargs,
+    ) -> list[RetrievedDocument]:
+        """Search the web for ``text`` and return the merged results.
+
+        Args:
+            text: the search query.
+
+        Returns:
+            A one-element list holding a document whose text concatenates the
+            URL and content of every Tavily result.
+
+        Raises:
+            ValueError: if ``TAVILY_API_KEY`` is not configured.
+            ImportError: if the ``tavily-python`` package is not installed.
+        """
+        if not TAVILY_API_KEY:
+            raise ValueError(
+                "This feature requires TAVILY_API_KEY "
+                "(get free one from https://app.tavily.com/)"
+            )
+
+        # imported here so a missing package surfaces as a clear error
+        try:
+            from tavily import TavilyClient
+        except ImportError as exc:
+            raise ImportError(
+                "Please install `pip install tavily-python` to use this feature"
+            ) from exc
+
+        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
+        results = tavily_client.search(
+            query=text,
+            search_depth="advanced",
+        )["results"]
+        context = "\n\n".join(
+            "###URL: [{url}]({url})\n\n{content}".format(
+                url=result["url"],
+                content=result["content"],
+            )
+            for result in results
+        )
+
+        return [
+            RetrievedDocument(
+                text=context,
+                metadata={
+                    "file_name": "Web search",
+                    "type": "table",
+                    "llm_trulens_score": 1.0,
+                },
+            )
+        ]
+
+    def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
+        """Return ``documents`` unchanged; no relevance scoring is applied."""
+        return documents
diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml
index ee5c468..2e60886 100644
--- a/libs/kotaemon/pyproject.toml
+++ b/libs/kotaemon/pyproject.toml
@@ -55,6 +55,7 @@ dependencies = [
"theflow>=0.8.6,<0.9.0",
"trogon>=0.5.0,<0.6",
"umap-learn==0.5.5",
+ "tavily-python>=0.4.0",
]
readme = "README.md"
authors = [
diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py
index 417391c..aa2da36 100644
--- a/libs/ktem/ktem/index/file/ui.py
+++ b/libs/ktem/ktem/index/file/ui.py
@@ -19,6 +19,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
+from ...utils.commands import WEB_SEARCH_COMMAND
+
DOWNLOAD_MESSAGE = "Press again to download"
MAX_FILENAME_LENGTH = 20
@@ -38,6 +40,13 @@ function(file_list) {
value: '"' + file_list[i][0] + '"',
});
}
+
+ // manually push web search tag
+ values.push({
+ key: "web_search",
+ value: '"web_search"',
+ });
+
var tribute = new Tribute({
values: values,
noMatchTemplate: "",
@@ -46,7 +55,9 @@ function(file_list) {
input_box = document.querySelector('#chat-input textarea');
tribute.attach(input_box);
}
-"""
+""".replace(
+ "web_search", WEB_SEARCH_COMMAND
+)
class File(gr.File):
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index 8aec594..7696027 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -1,4 +1,5 @@
import asyncio
+import importlib
import json
import re
from copy import deepcopy
@@ -23,11 +24,22 @@ from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
+from ...utils.commands import WEB_SEARCH_COMMAND
from .chat_panel import ChatPanel
from .common import STATE
from .control import ConversationControl
from .report import ReportIssue
+# Resolve the optional web-search backend class from its dotted path.
+KH_WEB_SEARCH_BACKEND = getattr(flowsettings, "KH_WEB_SEARCH_BACKEND", None)
+WebSearch = None
+if KH_WEB_SEARCH_BACKEND:
+    module_name, _sep, class_name = KH_WEB_SEARCH_BACKEND.rpartition(".")
+    try:
+        WebSearch = getattr(importlib.import_module(module_name), class_name)
+    except (ImportError, AttributeError, ValueError) as e:
+        print(f"Error importing {class_name} from {module_name}: {e}")
+
DEFAULT_SETTING = "(default)"
INFO_PANEL_SCALES = {True: 8, False: 4}
@@ -113,6 +125,7 @@ class ChatPage(BasePage):
value=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False)
)
self._info_panel_expanded = gr.State(value=True)
+ self._command_state = gr.State(value=None)
def on_building_ui(self):
with gr.Row():
@@ -299,6 +312,7 @@ class ChatPage(BasePage):
# file selector from the first index
self._indices_input[0],
self._indices_input[1],
+ self._command_state,
],
concurrency_limit=20,
show_progress="hidden",
@@ -315,6 +329,7 @@ class ChatPage(BasePage):
self.citation,
self.language,
self.state_chat,
+ self._command_state,
self._app.user_id,
]
+ self._indices_input,
@@ -647,6 +662,7 @@ class ChatPage(BasePage):
chat_input_text = chat_input.get("text", "")
file_ids = []
+ used_command = None
first_selector_choices_map = {
item[0]: item[1] for item in first_selector_choices
@@ -654,6 +670,11 @@ class ChatPage(BasePage):
# get all file names with pattern @"filename" in input_str
file_names, chat_input_text = get_file_names_regex(chat_input_text)
+
+ # check if web search command is in file_names
+ if WEB_SEARCH_COMMAND in file_names:
+ used_command = WEB_SEARCH_COMMAND
+
# get all urls in input_str
urls, chat_input_text = get_urls(chat_input_text)
@@ -707,13 +728,17 @@ class ChatPage(BasePage):
conv_update = gr.update()
new_conv_name = conv_name
- return [
- {},
- chat_history,
- new_conv_id,
- conv_update,
- new_conv_name,
- ] + selector_output
+ return (
+ [
+ {},
+ chat_history,
+ new_conv_id,
+ conv_update,
+ new_conv_name,
+ ]
+ + selector_output
+ + [used_command]
+ )
def toggle_delete(self, conv_id):
if conv_id:
@@ -877,6 +902,7 @@ class ChatPage(BasePage):
session_use_citation: str,
session_language: str,
state: dict,
+ command_state: str | None,
user_id: int,
*selecteds,
):
@@ -934,17 +960,26 @@ class ChatPage(BasePage):
# get retrievers
retrievers = []
- for index in self._app.index_manager.indices:
- index_selected = []
- if isinstance(index.selector, int):
- index_selected = selecteds[index.selector]
- if isinstance(index.selector, tuple):
- for i in index.selector:
- index_selected.append(selecteds[i])
- iretrievers = index.get_retriever_pipelines(
- settings, user_id, index_selected
- )
- retrievers += iretrievers
+
+ if command_state == WEB_SEARCH_COMMAND:
+ # set retriever for web search
+ if not WebSearch:
+ raise ValueError("Web search back-end is not available.")
+
+ web_search = WebSearch()
+ retrievers.append(web_search)
+ else:
+ for index in self._app.index_manager.indices:
+ index_selected = []
+ if isinstance(index.selector, int):
+ index_selected = selecteds[index.selector]
+ if isinstance(index.selector, tuple):
+ for i in index.selector:
+ index_selected.append(selecteds[i])
+ iretrievers = index.get_retriever_pipelines(
+ settings, user_id, index_selected
+ )
+ retrievers += iretrievers
# prepare states
reasoning_state = {
@@ -966,7 +1001,8 @@ class ChatPage(BasePage):
use_mind_map,
use_citation,
language,
- state,
+ chat_state,
+ command_state,
user_id,
*selecteds,
):
@@ -976,7 +1012,7 @@ class ChatPage(BasePage):
# if chat_input is empty, assume regen mode
if chat_output:
- state["app"]["regen"] = True
+ chat_state["app"]["regen"] = True
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
@@ -988,7 +1024,8 @@ class ChatPage(BasePage):
use_mind_map,
use_citation,
language,
- state,
+ chat_state,
+ command_state,
user_id,
*selecteds,
)
@@ -1005,7 +1042,7 @@ class ChatPage(BasePage):
refs,
plot_gr,
plot,
- state,
+ chat_state,
)
for response in pipeline.stream(chat_input, conversation_id, chat_history):
@@ -1032,14 +1069,14 @@ class ChatPage(BasePage):
plot = response.content
plot_gr = self._json_to_plot(plot)
- state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+ chat_state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield (
chat_history + [(chat_input, text or msg_placeholder)],
refs,
plot_gr,
plot,
- state,
+ chat_state,
)
if not text:
@@ -1052,7 +1089,7 @@ class ChatPage(BasePage):
refs,
plot_gr,
plot,
- state,
+ chat_state,
)
def check_and_suggest_name_conv(self, chat_history):
diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py
index 2adc52f..3db13ed 100644
--- a/libs/ktem/ktem/pages/chat/chat_panel.py
+++ b/libs/ktem/ktem/pages/chat/chat_panel.py
@@ -25,7 +25,9 @@ class ChatPanel(BasePage):
interactive=True,
scale=20,
file_count="multiple",
- placeholder="Type a message (or tag a file with @filename)",
+ placeholder=(
+ "Type a message, or search the @web, " "tag a file with @filename"
+ ),
container=False,
show_label=False,
elem_id="chat-input",
diff --git a/libs/ktem/ktem/utils/commands.py b/libs/ktem/ktem/utils/commands.py
new file mode 100644
index 0000000..48f7a70
--- /dev/null
+++ b/libs/ktem/ktem/utils/commands.py
@@ -0,0 +1 @@
+WEB_SEARCH_COMMAND = "web"  # @-tag in the chat input that triggers web search
diff --git a/libs/ktem/ktem/utils/render.py b/libs/ktem/ktem/utils/render.py
index 63e1ab5..49c2f79 100644
--- a/libs/ktem/ktem/utils/render.py
+++ b/libs/ktem/ktem/utils/render.py
@@ -59,6 +59,17 @@ class Render:
],
)
+    @staticmethod
+    def table_preserve_linebreaks(text: str) -> str:
+        """Render markdown (tables, fenced code) to HTML, newlines as <br>"""
+        return markdown.markdown(
+            text,
+            extensions=[
+                "markdown.extensions.tables",
+                "markdown.extensions.fenced_code",
+            ],
+        ).replace("\n", "<br>")
+
@staticmethod
def preview(
html_content: str,
@@ -134,6 +145,8 @@ class Render:
header = f"{get_header(doc)}"
if doc.metadata.get("type", "") == "image":
doc_content = Render.image(url=doc.metadata["image_origin"], text=doc.text)
+ elif doc.metadata.get("type", "") == "table_raw":
+ doc_content = Render.table_preserve_linebreaks(doc.text)
else:
doc_content = Render.table(doc.text)
@@ -174,6 +187,9 @@ class Render:
if item_type_prefix:
item_type_prefix += " from "
+ if "raw" in item_type_prefix:
+ item_type_prefix = ""
+
if llm_reranking_score > 0:
relevant_score = llm_reranking_score
elif reranking_score > 0:
@@ -198,6 +214,8 @@ class Render:
url=doc.metadata["image_origin"],
text=text,
)
+ elif doc.metadata.get("type", "") == "table_raw":
+ rendered_doc_content = Render.table_preserve_linebreaks(doc.text)
else:
rendered_doc_content = Render.table(text)