feat: add web search (#580) bump:patch
* feat: add web search * feat: update requirements
This commit is contained in:
parent
4fe080737a
commit
95191f53d9
|
@ -81,6 +81,10 @@ KH_FEATURE_USER_MANAGEMENT_PASSWORD = str(
|
||||||
KH_ENABLE_ALEMBIC = False
|
KH_ENABLE_ALEMBIC = False
|
||||||
KH_DATABASE = f"sqlite:///{KH_USER_DATA_DIR / 'sql.db'}"
|
KH_DATABASE = f"sqlite:///{KH_USER_DATA_DIR / 'sql.db'}"
|
||||||
KH_FILESTORAGE_PATH = str(KH_USER_DATA_DIR / "files")
|
KH_FILESTORAGE_PATH = str(KH_USER_DATA_DIR / "files")
|
||||||
|
KH_WEB_SEARCH_BACKEND = (
|
||||||
|
"kotaemon.indices.retrievers.tavily_web_search.WebSearch"
|
||||||
|
# "kotaemon.indices.retrievers.jina_web_search.WebSearch"
|
||||||
|
)
|
||||||
|
|
||||||
KH_DOCSTORE = {
|
KH_DOCSTORE = {
|
||||||
# "__type__": "kotaemon.storages.ElasticsearchDocumentStore",
|
# "__type__": "kotaemon.storages.ElasticsearchDocumentStore",
|
||||||
|
|
60
libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py
Normal file
60
libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
import requests
|
||||||
|
from decouple import config
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent, RetrievedDocument
|
||||||
|
|
||||||
|
JINA_API_KEY = config("JINA_API_KEY", default="")
|
||||||
|
JINA_URL = config("JINA_URL", default="https://r.jina.ai/")
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearch(BaseComponent):
|
||||||
|
"""WebSearch component for fetching data from the web
|
||||||
|
using Jina API
|
||||||
|
"""
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> list[RetrievedDocument]:
|
||||||
|
if JINA_API_KEY == "":
|
||||||
|
raise ValueError(
|
||||||
|
"This feature requires JINA_API_KEY "
|
||||||
|
"(get free one from https://jina.ai/reader)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# setup the request
|
||||||
|
api_url = f"https://s.jina.ai/{text}"
|
||||||
|
headers = {"X-With-Generated-Alt": "true", "Accept": "application/json"}
|
||||||
|
if JINA_API_KEY:
|
||||||
|
headers["Authorization"] = f"Bearer {JINA_API_KEY}"
|
||||||
|
|
||||||
|
response = requests.get(api_url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
response_dict = response.json()
|
||||||
|
|
||||||
|
return [
|
||||||
|
RetrievedDocument(
|
||||||
|
text=(
|
||||||
|
"###URL: [{url}]({url})\n\n"
|
||||||
|
"####{title}\n\n"
|
||||||
|
"{description}\n"
|
||||||
|
"{content}"
|
||||||
|
).format(
|
||||||
|
url=item["url"],
|
||||||
|
title=item["title"],
|
||||||
|
description=item["description"],
|
||||||
|
content=item["content"],
|
||||||
|
),
|
||||||
|
metadata={
|
||||||
|
"file_name": "Web search",
|
||||||
|
"type": "table",
|
||||||
|
"llm_trulens_score": 1.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for item in response_dict["data"]
|
||||||
|
]
|
||||||
|
|
||||||
|
def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
|
||||||
|
return documents
|
|
@ -0,0 +1,57 @@
|
||||||
|
from decouple import config
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent, RetrievedDocument
|
||||||
|
|
||||||
|
TAVILY_API_KEY = config("TAVILY_API_KEY", default="")
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearch(BaseComponent):
|
||||||
|
"""WebSearch component for fetching data from the web
|
||||||
|
using Jina API
|
||||||
|
"""
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> list[RetrievedDocument]:
|
||||||
|
if TAVILY_API_KEY == "":
|
||||||
|
raise ValueError(
|
||||||
|
"This feature requires TAVILY_API_KEY "
|
||||||
|
"(get free one from https://app.tavily.com/)"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tavily import TavilyClient
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install `pip install tavily-python` to use this feature"
|
||||||
|
)
|
||||||
|
|
||||||
|
tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
|
||||||
|
results = tavily_client.search(
|
||||||
|
query=text,
|
||||||
|
search_depth="advanced",
|
||||||
|
)["results"]
|
||||||
|
context = "\n\n".join(
|
||||||
|
"###URL: [{url}]({url})\n\n{content}".format(
|
||||||
|
url=result["url"],
|
||||||
|
content=result["content"],
|
||||||
|
)
|
||||||
|
for result in results
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
RetrievedDocument(
|
||||||
|
text=context,
|
||||||
|
metadata={
|
||||||
|
"file_name": "Web search",
|
||||||
|
"type": "table",
|
||||||
|
"llm_trulens_score": 1.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
|
||||||
|
return documents
|
|
@ -55,6 +55,7 @@ dependencies = [
|
||||||
"theflow>=0.8.6,<0.9.0",
|
"theflow>=0.8.6,<0.9.0",
|
||||||
"trogon>=0.5.0,<0.6",
|
"trogon>=0.5.0,<0.6",
|
||||||
"umap-learn==0.5.5",
|
"umap-learn==0.5.5",
|
||||||
|
"tavily-python>=0.4.0",
|
||||||
]
|
]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
|
|
|
@ -19,6 +19,8 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from theflow.settings import settings as flowsettings
|
from theflow.settings import settings as flowsettings
|
||||||
|
|
||||||
|
from ...utils.commands import WEB_SEARCH_COMMAND
|
||||||
|
|
||||||
DOWNLOAD_MESSAGE = "Press again to download"
|
DOWNLOAD_MESSAGE = "Press again to download"
|
||||||
MAX_FILENAME_LENGTH = 20
|
MAX_FILENAME_LENGTH = 20
|
||||||
|
|
||||||
|
@ -38,6 +40,13 @@ function(file_list) {
|
||||||
value: '"' + file_list[i][0] + '"',
|
value: '"' + file_list[i][0] + '"',
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// manually push web search tag
|
||||||
|
values.push({
|
||||||
|
key: "web_search",
|
||||||
|
value: '"web_search"',
|
||||||
|
});
|
||||||
|
|
||||||
var tribute = new Tribute({
|
var tribute = new Tribute({
|
||||||
values: values,
|
values: values,
|
||||||
noMatchTemplate: "",
|
noMatchTemplate: "",
|
||||||
|
@ -46,7 +55,9 @@ function(file_list) {
|
||||||
input_box = document.querySelector('#chat-input textarea');
|
input_box = document.querySelector('#chat-input textarea');
|
||||||
tribute.attach(input_box);
|
tribute.attach(input_box);
|
||||||
}
|
}
|
||||||
"""
|
""".replace(
|
||||||
|
"web_search", WEB_SEARCH_COMMAND
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class File(gr.File):
|
class File(gr.File):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import importlib
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
@ -23,11 +24,22 @@ from kotaemon.base import Document
|
||||||
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
|
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
|
||||||
|
|
||||||
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
|
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
|
||||||
|
from ...utils.commands import WEB_SEARCH_COMMAND
|
||||||
from .chat_panel import ChatPanel
|
from .chat_panel import ChatPanel
|
||||||
from .common import STATE
|
from .common import STATE
|
||||||
from .control import ConversationControl
|
from .control import ConversationControl
|
||||||
from .report import ReportIssue
|
from .report import ReportIssue
|
||||||
|
|
||||||
|
KH_WEB_SEARCH_BACKEND = getattr(flowsettings, "KH_WEB_SEARCH_BACKEND", None)
|
||||||
|
WebSearch = None
|
||||||
|
if KH_WEB_SEARCH_BACKEND:
|
||||||
|
try:
|
||||||
|
module_name, class_name = KH_WEB_SEARCH_BACKEND.rsplit(".", 1)
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
WebSearch = getattr(module, class_name)
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
print(f"Error importing {class_name} from {module_name}: {e}")
|
||||||
|
|
||||||
DEFAULT_SETTING = "(default)"
|
DEFAULT_SETTING = "(default)"
|
||||||
INFO_PANEL_SCALES = {True: 8, False: 4}
|
INFO_PANEL_SCALES = {True: 8, False: 4}
|
||||||
|
|
||||||
|
@ -113,6 +125,7 @@ class ChatPage(BasePage):
|
||||||
value=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False)
|
value=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False)
|
||||||
)
|
)
|
||||||
self._info_panel_expanded = gr.State(value=True)
|
self._info_panel_expanded = gr.State(value=True)
|
||||||
|
self._command_state = gr.State(value=None)
|
||||||
|
|
||||||
def on_building_ui(self):
|
def on_building_ui(self):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -299,6 +312,7 @@ class ChatPage(BasePage):
|
||||||
# file selector from the first index
|
# file selector from the first index
|
||||||
self._indices_input[0],
|
self._indices_input[0],
|
||||||
self._indices_input[1],
|
self._indices_input[1],
|
||||||
|
self._command_state,
|
||||||
],
|
],
|
||||||
concurrency_limit=20,
|
concurrency_limit=20,
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
|
@ -315,6 +329,7 @@ class ChatPage(BasePage):
|
||||||
self.citation,
|
self.citation,
|
||||||
self.language,
|
self.language,
|
||||||
self.state_chat,
|
self.state_chat,
|
||||||
|
self._command_state,
|
||||||
self._app.user_id,
|
self._app.user_id,
|
||||||
]
|
]
|
||||||
+ self._indices_input,
|
+ self._indices_input,
|
||||||
|
@ -647,6 +662,7 @@ class ChatPage(BasePage):
|
||||||
|
|
||||||
chat_input_text = chat_input.get("text", "")
|
chat_input_text = chat_input.get("text", "")
|
||||||
file_ids = []
|
file_ids = []
|
||||||
|
used_command = None
|
||||||
|
|
||||||
first_selector_choices_map = {
|
first_selector_choices_map = {
|
||||||
item[0]: item[1] for item in first_selector_choices
|
item[0]: item[1] for item in first_selector_choices
|
||||||
|
@ -654,6 +670,11 @@ class ChatPage(BasePage):
|
||||||
|
|
||||||
# get all file names with pattern @"filename" in input_str
|
# get all file names with pattern @"filename" in input_str
|
||||||
file_names, chat_input_text = get_file_names_regex(chat_input_text)
|
file_names, chat_input_text = get_file_names_regex(chat_input_text)
|
||||||
|
|
||||||
|
# check if web search command is in file_names
|
||||||
|
if WEB_SEARCH_COMMAND in file_names:
|
||||||
|
used_command = WEB_SEARCH_COMMAND
|
||||||
|
|
||||||
# get all urls in input_str
|
# get all urls in input_str
|
||||||
urls, chat_input_text = get_urls(chat_input_text)
|
urls, chat_input_text = get_urls(chat_input_text)
|
||||||
|
|
||||||
|
@ -707,13 +728,17 @@ class ChatPage(BasePage):
|
||||||
conv_update = gr.update()
|
conv_update = gr.update()
|
||||||
new_conv_name = conv_name
|
new_conv_name = conv_name
|
||||||
|
|
||||||
return [
|
return (
|
||||||
{},
|
[
|
||||||
chat_history,
|
{},
|
||||||
new_conv_id,
|
chat_history,
|
||||||
conv_update,
|
new_conv_id,
|
||||||
new_conv_name,
|
conv_update,
|
||||||
] + selector_output
|
new_conv_name,
|
||||||
|
]
|
||||||
|
+ selector_output
|
||||||
|
+ [used_command]
|
||||||
|
)
|
||||||
|
|
||||||
def toggle_delete(self, conv_id):
|
def toggle_delete(self, conv_id):
|
||||||
if conv_id:
|
if conv_id:
|
||||||
|
@ -877,6 +902,7 @@ class ChatPage(BasePage):
|
||||||
session_use_citation: str,
|
session_use_citation: str,
|
||||||
session_language: str,
|
session_language: str,
|
||||||
state: dict,
|
state: dict,
|
||||||
|
command_state: str | None,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
*selecteds,
|
*selecteds,
|
||||||
):
|
):
|
||||||
|
@ -934,17 +960,26 @@ class ChatPage(BasePage):
|
||||||
|
|
||||||
# get retrievers
|
# get retrievers
|
||||||
retrievers = []
|
retrievers = []
|
||||||
for index in self._app.index_manager.indices:
|
|
||||||
index_selected = []
|
if command_state == WEB_SEARCH_COMMAND:
|
||||||
if isinstance(index.selector, int):
|
# set retriever for web search
|
||||||
index_selected = selecteds[index.selector]
|
if not WebSearch:
|
||||||
if isinstance(index.selector, tuple):
|
raise ValueError("Web search back-end is not available.")
|
||||||
for i in index.selector:
|
|
||||||
index_selected.append(selecteds[i])
|
web_search = WebSearch()
|
||||||
iretrievers = index.get_retriever_pipelines(
|
retrievers.append(web_search)
|
||||||
settings, user_id, index_selected
|
else:
|
||||||
)
|
for index in self._app.index_manager.indices:
|
||||||
retrievers += iretrievers
|
index_selected = []
|
||||||
|
if isinstance(index.selector, int):
|
||||||
|
index_selected = selecteds[index.selector]
|
||||||
|
if isinstance(index.selector, tuple):
|
||||||
|
for i in index.selector:
|
||||||
|
index_selected.append(selecteds[i])
|
||||||
|
iretrievers = index.get_retriever_pipelines(
|
||||||
|
settings, user_id, index_selected
|
||||||
|
)
|
||||||
|
retrievers += iretrievers
|
||||||
|
|
||||||
# prepare states
|
# prepare states
|
||||||
reasoning_state = {
|
reasoning_state = {
|
||||||
|
@ -966,7 +1001,8 @@ class ChatPage(BasePage):
|
||||||
use_mind_map,
|
use_mind_map,
|
||||||
use_citation,
|
use_citation,
|
||||||
language,
|
language,
|
||||||
state,
|
chat_state,
|
||||||
|
command_state,
|
||||||
user_id,
|
user_id,
|
||||||
*selecteds,
|
*selecteds,
|
||||||
):
|
):
|
||||||
|
@ -976,7 +1012,7 @@ class ChatPage(BasePage):
|
||||||
|
|
||||||
# if chat_input is empty, assume regen mode
|
# if chat_input is empty, assume regen mode
|
||||||
if chat_output:
|
if chat_output:
|
||||||
state["app"]["regen"] = True
|
chat_state["app"]["regen"] = True
|
||||||
|
|
||||||
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
|
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
|
||||||
|
|
||||||
|
@ -988,7 +1024,8 @@ class ChatPage(BasePage):
|
||||||
use_mind_map,
|
use_mind_map,
|
||||||
use_citation,
|
use_citation,
|
||||||
language,
|
language,
|
||||||
state,
|
chat_state,
|
||||||
|
command_state,
|
||||||
user_id,
|
user_id,
|
||||||
*selecteds,
|
*selecteds,
|
||||||
)
|
)
|
||||||
|
@ -1005,7 +1042,7 @@ class ChatPage(BasePage):
|
||||||
refs,
|
refs,
|
||||||
plot_gr,
|
plot_gr,
|
||||||
plot,
|
plot,
|
||||||
state,
|
chat_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
for response in pipeline.stream(chat_input, conversation_id, chat_history):
|
for response in pipeline.stream(chat_input, conversation_id, chat_history):
|
||||||
|
@ -1032,14 +1069,14 @@ class ChatPage(BasePage):
|
||||||
plot = response.content
|
plot = response.content
|
||||||
plot_gr = self._json_to_plot(plot)
|
plot_gr = self._json_to_plot(plot)
|
||||||
|
|
||||||
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
|
chat_state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
|
||||||
|
|
||||||
yield (
|
yield (
|
||||||
chat_history + [(chat_input, text or msg_placeholder)],
|
chat_history + [(chat_input, text or msg_placeholder)],
|
||||||
refs,
|
refs,
|
||||||
plot_gr,
|
plot_gr,
|
||||||
plot,
|
plot,
|
||||||
state,
|
chat_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not text:
|
if not text:
|
||||||
|
@ -1052,7 +1089,7 @@ class ChatPage(BasePage):
|
||||||
refs,
|
refs,
|
||||||
plot_gr,
|
plot_gr,
|
||||||
plot,
|
plot,
|
||||||
state,
|
chat_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_and_suggest_name_conv(self, chat_history):
|
def check_and_suggest_name_conv(self, chat_history):
|
||||||
|
|
|
@ -25,7 +25,9 @@ class ChatPanel(BasePage):
|
||||||
interactive=True,
|
interactive=True,
|
||||||
scale=20,
|
scale=20,
|
||||||
file_count="multiple",
|
file_count="multiple",
|
||||||
placeholder="Type a message (or tag a file with @filename)",
|
placeholder=(
|
||||||
|
"Type a message, or search the @web, " "tag a file with @filename"
|
||||||
|
),
|
||||||
container=False,
|
container=False,
|
||||||
show_label=False,
|
show_label=False,
|
||||||
elem_id="chat-input",
|
elem_id="chat-input",
|
||||||
|
|
1
libs/ktem/ktem/utils/commands.py
Normal file
1
libs/ktem/ktem/utils/commands.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
WEB_SEARCH_COMMAND = "web"
|
|
@ -59,6 +59,17 @@ class Render:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def table_preserve_linebreaks(text: str) -> str:
|
||||||
|
"""Render table from markdown format into HTML"""
|
||||||
|
return markdown.markdown(
|
||||||
|
text,
|
||||||
|
extensions=[
|
||||||
|
"markdown.extensions.tables",
|
||||||
|
"markdown.extensions.fenced_code",
|
||||||
|
],
|
||||||
|
).replace("\n", "<br>")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def preview(
|
def preview(
|
||||||
html_content: str,
|
html_content: str,
|
||||||
|
@ -134,6 +145,8 @@ class Render:
|
||||||
header = f"<i>{get_header(doc)}</i>"
|
header = f"<i>{get_header(doc)}</i>"
|
||||||
if doc.metadata.get("type", "") == "image":
|
if doc.metadata.get("type", "") == "image":
|
||||||
doc_content = Render.image(url=doc.metadata["image_origin"], text=doc.text)
|
doc_content = Render.image(url=doc.metadata["image_origin"], text=doc.text)
|
||||||
|
elif doc.metadata.get("type", "") == "table_raw":
|
||||||
|
doc_content = Render.table_preserve_linebreaks(doc.text)
|
||||||
else:
|
else:
|
||||||
doc_content = Render.table(doc.text)
|
doc_content = Render.table(doc.text)
|
||||||
|
|
||||||
|
@ -174,6 +187,9 @@ class Render:
|
||||||
if item_type_prefix:
|
if item_type_prefix:
|
||||||
item_type_prefix += " from "
|
item_type_prefix += " from "
|
||||||
|
|
||||||
|
if "raw" in item_type_prefix:
|
||||||
|
item_type_prefix = ""
|
||||||
|
|
||||||
if llm_reranking_score > 0:
|
if llm_reranking_score > 0:
|
||||||
relevant_score = llm_reranking_score
|
relevant_score = llm_reranking_score
|
||||||
elif reranking_score > 0:
|
elif reranking_score > 0:
|
||||||
|
@ -198,6 +214,8 @@ class Render:
|
||||||
url=doc.metadata["image_origin"],
|
url=doc.metadata["image_origin"],
|
||||||
text=text,
|
text=text,
|
||||||
)
|
)
|
||||||
|
elif doc.metadata.get("type", "") == "table_raw":
|
||||||
|
rendered_doc_content = Render.table_preserve_linebreaks(doc.text)
|
||||||
else:
|
else:
|
||||||
rendered_doc_content = Render.table(text)
|
rendered_doc_content = Render.table(text)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user