feat: add web search (#580) bump:patch

* feat: add web search

* feat: update requirements
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-12-23 09:28:24 +07:00 committed by GitHub
parent 4fe080737a
commit 95191f53d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 218 additions and 27 deletions

View File

@ -81,6 +81,10 @@ KH_FEATURE_USER_MANAGEMENT_PASSWORD = str(
KH_ENABLE_ALEMBIC = False KH_ENABLE_ALEMBIC = False
KH_DATABASE = f"sqlite:///{KH_USER_DATA_DIR / 'sql.db'}" KH_DATABASE = f"sqlite:///{KH_USER_DATA_DIR / 'sql.db'}"
KH_FILESTORAGE_PATH = str(KH_USER_DATA_DIR / "files") KH_FILESTORAGE_PATH = str(KH_USER_DATA_DIR / "files")
KH_WEB_SEARCH_BACKEND = (
"kotaemon.indices.retrievers.tavily_web_search.WebSearch"
# "kotaemon.indices.retrievers.jina_web_search.WebSearch"
)
KH_DOCSTORE = { KH_DOCSTORE = {
# "__type__": "kotaemon.storages.ElasticsearchDocumentStore", # "__type__": "kotaemon.storages.ElasticsearchDocumentStore",

View File

@ -0,0 +1,60 @@
from urllib.parse import quote

import requests
from decouple import config

from kotaemon.base import BaseComponent, RetrievedDocument
# Jina credentials/endpoint, read from the environment (.env) via python-decouple.
# An empty default lets the module import cleanly; WebSearch.run validates at call time.
JINA_API_KEY = config("JINA_API_KEY", default="")
# NOTE(review): JINA_URL is defined but never referenced in this file — confirm intended use.
JINA_URL = config("JINA_URL", default="https://r.jina.ai/")
class WebSearch(BaseComponent):
    """WebSearch component for fetching data from the web
    using the Jina search API (https://s.jina.ai)
    """

    def run(
        self,
        text: str,
        *args,
        **kwargs,
    ) -> list[RetrievedDocument]:
        """Search the web for `text` and return one document per result.

        Args:
            text: the search query.

        Returns:
            A list of RetrievedDocument, one per search hit, with the URL,
            title, description and content rendered as markdown.

        Raises:
            ValueError: if JINA_API_KEY is not configured.
            requests.HTTPError: if the Jina API responds with an error status.
        """
        if not JINA_API_KEY:
            raise ValueError(
                "This feature requires JINA_API_KEY "
                "(get free one from https://jina.ai/reader)"
            )

        # percent-encode the query: spaces / special characters would otherwise
        # produce an invalid URL path
        api_url = f"https://s.jina.ai/{quote(text)}"
        headers = {
            "X-With-Generated-Alt": "true",
            "Accept": "application/json",
            # the key is guaranteed non-empty here (checked above), so the
            # original `if JINA_API_KEY:` guard was dead code
            "Authorization": f"Bearer {JINA_API_KEY}",
        }

        # a timeout keeps a stalled request from hanging the caller forever
        response = requests.get(api_url, headers=headers, timeout=30)
        response.raise_for_status()
        response_dict = response.json()

        return [
            RetrievedDocument(
                text=(
                    "###URL: [{url}]({url})\n\n"
                    "####{title}\n\n"
                    "{description}\n"
                    "{content}"
                ).format(
                    # .get: tolerate partially-populated API items instead of
                    # failing the whole search with a KeyError
                    url=item.get("url", ""),
                    title=item.get("title", ""),
                    description=item.get("description", ""),
                    content=item.get("content", ""),
                ),
                metadata={
                    "file_name": "Web search",
                    "type": "table",
                    # pre-set relevance so downstream scoring keeps the doc
                    "llm_trulens_score": 1.0,
                },
            )
            for item in response_dict["data"]
        ]

    def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
        """Scores are already set in metadata; return documents unchanged."""
        return documents

View File

@ -0,0 +1,57 @@
from decouple import config
from kotaemon.base import BaseComponent, RetrievedDocument
TAVILY_API_KEY = config("TAVILY_API_KEY", default="")
class WebSearch(BaseComponent):
    """WebSearch component for fetching data from the web
    using the Tavily API
    """

    def run(
        self,
        text: str,
        *args,
        **kwargs,
    ) -> list[RetrievedDocument]:
        """Search the web for `text` and return a single merged document.

        Args:
            text: the search query.

        Returns:
            A one-element list of RetrievedDocument whose text concatenates
            every Tavily result as a markdown section.

        Raises:
            ValueError: if TAVILY_API_KEY is not configured.
            ImportError: if the tavily-python package is not installed.
        """
        if not TAVILY_API_KEY:
            raise ValueError(
                "This feature requires TAVILY_API_KEY "
                "(get free one from https://app.tavily.com/)"
            )

        try:
            from tavily import TavilyClient
        except ImportError as err:
            # chain the original error so the real import failure stays visible
            raise ImportError(
                "Please install `pip install tavily-python` to use this feature"
            ) from err

        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
        results = tavily_client.search(
            query=text,
            search_depth="advanced",
        )["results"]

        # merge all hits into one markdown blob so the UI shows a single source
        context = "\n\n".join(
            "###URL: [{url}]({url})\n\n{content}".format(
                url=result["url"],
                content=result["content"],
            )
            for result in results
        )

        return [
            RetrievedDocument(
                text=context,
                metadata={
                    "file_name": "Web search",
                    "type": "table",
                    # pre-set relevance so downstream scoring keeps the doc
                    "llm_trulens_score": 1.0,
                },
            )
        ]

    def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
        """Scores are already set in metadata; return documents unchanged."""
        return documents

View File

@ -55,6 +55,7 @@ dependencies = [
"theflow>=0.8.6,<0.9.0", "theflow>=0.8.6,<0.9.0",
"trogon>=0.5.0,<0.6", "trogon>=0.5.0,<0.6",
"umap-learn==0.5.5", "umap-learn==0.5.5",
"tavily-python>=0.4.0",
] ]
readme = "README.md" readme = "README.md"
authors = [ authors = [

View File

@ -19,6 +19,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings from theflow.settings import settings as flowsettings
from ...utils.commands import WEB_SEARCH_COMMAND
DOWNLOAD_MESSAGE = "Press again to download" DOWNLOAD_MESSAGE = "Press again to download"
MAX_FILENAME_LENGTH = 20 MAX_FILENAME_LENGTH = 20
@ -38,6 +40,13 @@ function(file_list) {
value: '"' + file_list[i][0] + '"', value: '"' + file_list[i][0] + '"',
}); });
} }
// manually push web search tag
values.push({
key: "web_search",
value: '"web_search"',
});
var tribute = new Tribute({ var tribute = new Tribute({
values: values, values: values,
noMatchTemplate: "", noMatchTemplate: "",
@ -46,7 +55,9 @@ function(file_list) {
input_box = document.querySelector('#chat-input textarea'); input_box = document.querySelector('#chat-input textarea');
tribute.attach(input_box); tribute.attach(input_box);
} }
""" """.replace(
"web_search", WEB_SEARCH_COMMAND
)
class File(gr.File): class File(gr.File):

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import importlib
import json import json
import re import re
from copy import deepcopy from copy import deepcopy
@ -23,11 +24,22 @@ from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
from ...utils.commands import WEB_SEARCH_COMMAND
from .chat_panel import ChatPanel from .chat_panel import ChatPanel
from .common import STATE from .common import STATE
from .control import ConversationControl from .control import ConversationControl
from .report import ReportIssue from .report import ReportIssue
KH_WEB_SEARCH_BACKEND = getattr(flowsettings, "KH_WEB_SEARCH_BACKEND", None)
WebSearch = None
if KH_WEB_SEARCH_BACKEND:
try:
module_name, class_name = KH_WEB_SEARCH_BACKEND.rsplit(".", 1)
module = importlib.import_module(module_name)
WebSearch = getattr(module, class_name)
except (ImportError, AttributeError) as e:
print(f"Error importing {class_name} from {module_name}: {e}")
DEFAULT_SETTING = "(default)" DEFAULT_SETTING = "(default)"
INFO_PANEL_SCALES = {True: 8, False: 4} INFO_PANEL_SCALES = {True: 8, False: 4}
@ -113,6 +125,7 @@ class ChatPage(BasePage):
value=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False) value=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False)
) )
self._info_panel_expanded = gr.State(value=True) self._info_panel_expanded = gr.State(value=True)
self._command_state = gr.State(value=None)
def on_building_ui(self): def on_building_ui(self):
with gr.Row(): with gr.Row():
@ -299,6 +312,7 @@ class ChatPage(BasePage):
# file selector from the first index # file selector from the first index
self._indices_input[0], self._indices_input[0],
self._indices_input[1], self._indices_input[1],
self._command_state,
], ],
concurrency_limit=20, concurrency_limit=20,
show_progress="hidden", show_progress="hidden",
@ -315,6 +329,7 @@ class ChatPage(BasePage):
self.citation, self.citation,
self.language, self.language,
self.state_chat, self.state_chat,
self._command_state,
self._app.user_id, self._app.user_id,
] ]
+ self._indices_input, + self._indices_input,
@ -647,6 +662,7 @@ class ChatPage(BasePage):
chat_input_text = chat_input.get("text", "") chat_input_text = chat_input.get("text", "")
file_ids = [] file_ids = []
used_command = None
first_selector_choices_map = { first_selector_choices_map = {
item[0]: item[1] for item in first_selector_choices item[0]: item[1] for item in first_selector_choices
@ -654,6 +670,11 @@ class ChatPage(BasePage):
# get all file names with pattern @"filename" in input_str # get all file names with pattern @"filename" in input_str
file_names, chat_input_text = get_file_names_regex(chat_input_text) file_names, chat_input_text = get_file_names_regex(chat_input_text)
# check if web search command is in file_names
if WEB_SEARCH_COMMAND in file_names:
used_command = WEB_SEARCH_COMMAND
# get all urls in input_str # get all urls in input_str
urls, chat_input_text = get_urls(chat_input_text) urls, chat_input_text = get_urls(chat_input_text)
@ -707,13 +728,17 @@ class ChatPage(BasePage):
conv_update = gr.update() conv_update = gr.update()
new_conv_name = conv_name new_conv_name = conv_name
return [ return (
{}, [
chat_history, {},
new_conv_id, chat_history,
conv_update, new_conv_id,
new_conv_name, conv_update,
] + selector_output new_conv_name,
]
+ selector_output
+ [used_command]
)
def toggle_delete(self, conv_id): def toggle_delete(self, conv_id):
if conv_id: if conv_id:
@ -877,6 +902,7 @@ class ChatPage(BasePage):
session_use_citation: str, session_use_citation: str,
session_language: str, session_language: str,
state: dict, state: dict,
command_state: str | None,
user_id: int, user_id: int,
*selecteds, *selecteds,
): ):
@ -934,17 +960,26 @@ class ChatPage(BasePage):
# get retrievers # get retrievers
retrievers = [] retrievers = []
for index in self._app.index_manager.indices:
index_selected = [] if command_state == WEB_SEARCH_COMMAND:
if isinstance(index.selector, int): # set retriever for web search
index_selected = selecteds[index.selector] if not WebSearch:
if isinstance(index.selector, tuple): raise ValueError("Web search back-end is not available.")
for i in index.selector:
index_selected.append(selecteds[i]) web_search = WebSearch()
iretrievers = index.get_retriever_pipelines( retrievers.append(web_search)
settings, user_id, index_selected else:
) for index in self._app.index_manager.indices:
retrievers += iretrievers index_selected = []
if isinstance(index.selector, int):
index_selected = selecteds[index.selector]
if isinstance(index.selector, tuple):
for i in index.selector:
index_selected.append(selecteds[i])
iretrievers = index.get_retriever_pipelines(
settings, user_id, index_selected
)
retrievers += iretrievers
# prepare states # prepare states
reasoning_state = { reasoning_state = {
@ -966,7 +1001,8 @@ class ChatPage(BasePage):
use_mind_map, use_mind_map,
use_citation, use_citation,
language, language,
state, chat_state,
command_state,
user_id, user_id,
*selecteds, *selecteds,
): ):
@ -976,7 +1012,7 @@ class ChatPage(BasePage):
# if chat_input is empty, assume regen mode # if chat_input is empty, assume regen mode
if chat_output: if chat_output:
state["app"]["regen"] = True chat_state["app"]["regen"] = True
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue() queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
@ -988,7 +1024,8 @@ class ChatPage(BasePage):
use_mind_map, use_mind_map,
use_citation, use_citation,
language, language,
state, chat_state,
command_state,
user_id, user_id,
*selecteds, *selecteds,
) )
@ -1005,7 +1042,7 @@ class ChatPage(BasePage):
refs, refs,
plot_gr, plot_gr,
plot, plot,
state, chat_state,
) )
for response in pipeline.stream(chat_input, conversation_id, chat_history): for response in pipeline.stream(chat_input, conversation_id, chat_history):
@ -1032,14 +1069,14 @@ class ChatPage(BasePage):
plot = response.content plot = response.content
plot_gr = self._json_to_plot(plot) plot_gr = self._json_to_plot(plot)
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] chat_state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield ( yield (
chat_history + [(chat_input, text or msg_placeholder)], chat_history + [(chat_input, text or msg_placeholder)],
refs, refs,
plot_gr, plot_gr,
plot, plot,
state, chat_state,
) )
if not text: if not text:
@ -1052,7 +1089,7 @@ class ChatPage(BasePage):
refs, refs,
plot_gr, plot_gr,
plot, plot,
state, chat_state,
) )
def check_and_suggest_name_conv(self, chat_history): def check_and_suggest_name_conv(self, chat_history):

View File

@ -25,7 +25,9 @@ class ChatPanel(BasePage):
interactive=True, interactive=True,
scale=20, scale=20,
file_count="multiple", file_count="multiple",
placeholder="Type a message (or tag a file with @filename)", placeholder=(
"Type a message, or search the @web, " "tag a file with @filename"
),
container=False, container=False,
show_label=False, show_label=False,
elem_id="chat-input", elem_id="chat-input",

View File

@ -0,0 +1 @@
WEB_SEARCH_COMMAND = "web"

View File

@ -59,6 +59,17 @@ class Render:
], ],
) )
@staticmethod
def table_preserve_linebreaks(text: str) -> str:
"""Render table from markdown format into HTML"""
return markdown.markdown(
text,
extensions=[
"markdown.extensions.tables",
"markdown.extensions.fenced_code",
],
).replace("\n", "<br>")
@staticmethod @staticmethod
def preview( def preview(
html_content: str, html_content: str,
@ -134,6 +145,8 @@ class Render:
header = f"<i>{get_header(doc)}</i>" header = f"<i>{get_header(doc)}</i>"
if doc.metadata.get("type", "") == "image": if doc.metadata.get("type", "") == "image":
doc_content = Render.image(url=doc.metadata["image_origin"], text=doc.text) doc_content = Render.image(url=doc.metadata["image_origin"], text=doc.text)
elif doc.metadata.get("type", "") == "table_raw":
doc_content = Render.table_preserve_linebreaks(doc.text)
else: else:
doc_content = Render.table(doc.text) doc_content = Render.table(doc.text)
@ -174,6 +187,9 @@ class Render:
if item_type_prefix: if item_type_prefix:
item_type_prefix += " from " item_type_prefix += " from "
if "raw" in item_type_prefix:
item_type_prefix = ""
if llm_reranking_score > 0: if llm_reranking_score > 0:
relevant_score = llm_reranking_score relevant_score = llm_reranking_score
elif reranking_score > 0: elif reranking_score > 0:
@ -198,6 +214,8 @@ class Render:
url=doc.metadata["image_origin"], url=doc.metadata["image_origin"],
text=text, text=text,
) )
elif doc.metadata.get("type", "") == "table_raw":
rendered_doc_content = Render.table_preserve_linebreaks(doc.text)
else: else:
rendered_doc_content = Render.table(text) rendered_doc_content = Render.table(text)