feat: support for visualizing citation results (via embeddings) (#461)
* feat:support for visualizing citation results (via embeddings) Signed-off-by: Kennywu <jdlow@live.cn> * fix: remove ktem dependency in visualize_cited * fix: limit onnx version for fastembed * fix: test case of indexing * fix: minor update * fix: chroma req * fix: chroma req --------- Signed-off-by: Kennywu <jdlow@live.cn> Co-authored-by: Tadashi <tadashi@cinnamon.is>
This commit is contained in:
parent
bd2490bef1
commit
d127fec9f7
|
@ -53,7 +53,11 @@ class VectorIndexing(BaseIndexing):
|
||||||
def write_chunk_to_file(self, docs: list[Document]):
|
def write_chunk_to_file(self, docs: list[Document]):
|
||||||
# save the chunks content into markdown format
|
# save the chunks content into markdown format
|
||||||
if self.cache_dir:
|
if self.cache_dir:
|
||||||
file_name = Path(docs[0].metadata["file_name"])
|
file_name = docs[0].metadata.get("file_name")
|
||||||
|
if not file_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
file_name = Path(file_name)
|
||||||
for i in range(len(docs)):
|
for i in range(len(docs)):
|
||||||
markdown_content = ""
|
markdown_content = ""
|
||||||
if "page_label" in docs[i].metadata:
|
if "page_label" in docs[i].metadata:
|
||||||
|
|
|
@ -38,6 +38,7 @@ dependencies = [
|
||||||
"langchain-cohere>=0.2.4,<0.3.0",
|
"langchain-cohere>=0.2.4,<0.3.0",
|
||||||
"llama-hub>=0.0.79,<0.1.0",
|
"llama-hub>=0.0.79,<0.1.0",
|
||||||
"llama-index>=0.10.40,<0.11.0",
|
"llama-index>=0.10.40,<0.11.0",
|
||||||
|
"chromadb<=0.5.16",
|
||||||
"llama-index-vector-stores-chroma>=0.1.9",
|
"llama-index-vector-stores-chroma>=0.1.9",
|
||||||
"llama-index-vector-stores-lancedb",
|
"llama-index-vector-stores-lancedb",
|
||||||
"openai>=1.23.6,<2",
|
"openai>=1.23.6,<2",
|
||||||
|
@ -52,7 +53,8 @@ dependencies = [
|
||||||
"python-dotenv>=1.0.1,<1.1",
|
"python-dotenv>=1.0.1,<1.1",
|
||||||
"tenacity>=8.2.3,<8.3",
|
"tenacity>=8.2.3,<8.3",
|
||||||
"theflow>=0.8.6,<0.9.0",
|
"theflow>=0.8.6,<0.9.0",
|
||||||
"trogon>=0.5.0,<0.6"
|
"trogon>=0.5.0,<0.6",
|
||||||
|
"umap-learn==0.5.5",
|
||||||
]
|
]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
|
@ -71,6 +73,7 @@ adv = [
|
||||||
"duckduckgo-search>=6.1.0,<6.2",
|
"duckduckgo-search>=6.1.0,<6.2",
|
||||||
"elasticsearch>=8.13.0,<8.14",
|
"elasticsearch>=8.13.0,<8.14",
|
||||||
"fastembed",
|
"fastembed",
|
||||||
|
"onnxruntime<v1.20",
|
||||||
"googlesearch-python>=1.2.4,<1.3",
|
"googlesearch-python>=1.2.4,<1.3",
|
||||||
"llama-cpp-python<0.2.8",
|
"llama-cpp-python<0.2.8",
|
||||||
"llama-index>=0.10.40,<0.11.0",
|
"llama-index>=0.10.40,<0.11.0",
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Generator
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
from ktem.embeddings.manager import embedding_models_manager as embeddings
|
||||||
from ktem.llms.manager import llms
|
from ktem.llms.manager import llms
|
||||||
from ktem.reasoning.prompt_optimization import (
|
from ktem.reasoning.prompt_optimization import (
|
||||||
CreateMindmapPipeline,
|
CreateMindmapPipeline,
|
||||||
|
@ -16,6 +17,8 @@ from ktem.reasoning.prompt_optimization import (
|
||||||
)
|
)
|
||||||
from ktem.utils.plantuml import PlantUML
|
from ktem.utils.plantuml import PlantUML
|
||||||
from ktem.utils.render import Render
|
from ktem.utils.render import Render
|
||||||
|
from ktem.utils.visualize_cited import CreateCitationVizPipeline
|
||||||
|
from plotly.io import to_json
|
||||||
from theflow.settings import settings as flowsettings
|
from theflow.settings import settings as flowsettings
|
||||||
|
|
||||||
from kotaemon.base import (
|
from kotaemon.base import (
|
||||||
|
@ -240,6 +243,7 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||||
|
|
||||||
enable_citation: bool = False
|
enable_citation: bool = False
|
||||||
enable_mindmap: bool = False
|
enable_mindmap: bool = False
|
||||||
|
enable_citation_viz: bool = False
|
||||||
|
|
||||||
system_prompt: str = ""
|
system_prompt: str = ""
|
||||||
lang: str = "English" # support English and Japanese
|
lang: str = "English" # support English and Japanese
|
||||||
|
@ -409,7 +413,12 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||||
|
|
||||||
answer = Document(
|
answer = Document(
|
||||||
text=output,
|
text=output,
|
||||||
metadata={"mindmap": mindmap, "citation": citation, "qa_score": qa_score},
|
metadata={
|
||||||
|
"citation_viz": self.enable_citation_viz,
|
||||||
|
"mindmap": mindmap,
|
||||||
|
"citation": citation,
|
||||||
|
"qa_score": qa_score,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return answer
|
return answer
|
||||||
|
@ -474,6 +483,11 @@ class FullQAPipeline(BaseReasoning):
|
||||||
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
|
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
|
||||||
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
|
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
|
||||||
rewrite_pipeline: RewriteQuestionPipeline | None = None
|
rewrite_pipeline: RewriteQuestionPipeline | None = None
|
||||||
|
create_citation_viz_pipeline: CreateCitationVizPipeline = Node(
|
||||||
|
default_callback=lambda _: CreateCitationVizPipeline(
|
||||||
|
embedding=embeddings.get_default()
|
||||||
|
)
|
||||||
|
)
|
||||||
add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()
|
add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()
|
||||||
|
|
||||||
def retrieve(
|
def retrieve(
|
||||||
|
@ -641,10 +655,28 @@ class FullQAPipeline(BaseReasoning):
|
||||||
|
|
||||||
return mindmap_content
|
return mindmap_content
|
||||||
|
|
||||||
def show_citations_and_addons(self, answer, docs):
|
def prepare_citation_viz(self, answer, question, docs) -> Document | None:
|
||||||
|
doc_texts = [doc.text for doc in docs]
|
||||||
|
citation_plot = None
|
||||||
|
plot_content = None
|
||||||
|
|
||||||
|
if answer.metadata["citation_viz"] and len(docs) > 1:
|
||||||
|
try:
|
||||||
|
citation_plot = self.create_citation_viz_pipeline(doc_texts, question)
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to create citation plot:", e)
|
||||||
|
|
||||||
|
if citation_plot:
|
||||||
|
plot = to_json(citation_plot)
|
||||||
|
plot_content = Document(channel="plot", content=plot)
|
||||||
|
|
||||||
|
return plot_content
|
||||||
|
|
||||||
|
def show_citations_and_addons(self, answer, docs, question):
|
||||||
# show the evidence
|
# show the evidence
|
||||||
with_citation, without_citation = self.prepare_citations(answer, docs)
|
with_citation, without_citation = self.prepare_citations(answer, docs)
|
||||||
mindmap_output = self.prepare_mindmap(answer)
|
mindmap_output = self.prepare_mindmap(answer)
|
||||||
|
citation_plot_output = self.prepare_citation_viz(answer, question, docs)
|
||||||
|
|
||||||
if not with_citation and not without_citation:
|
if not with_citation and not without_citation:
|
||||||
yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
|
yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
|
||||||
|
@ -661,6 +693,10 @@ class FullQAPipeline(BaseReasoning):
|
||||||
if mindmap_output:
|
if mindmap_output:
|
||||||
yield mindmap_output
|
yield mindmap_output
|
||||||
|
|
||||||
|
# yield citation plot output
|
||||||
|
if citation_plot_output:
|
||||||
|
yield citation_plot_output
|
||||||
|
|
||||||
# yield warning message
|
# yield warning message
|
||||||
if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE:
|
if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE:
|
||||||
yield Document(
|
yield Document(
|
||||||
|
@ -733,7 +769,7 @@ class FullQAPipeline(BaseReasoning):
|
||||||
if scoring_thread:
|
if scoring_thread:
|
||||||
scoring_thread.join()
|
scoring_thread.join()
|
||||||
|
|
||||||
yield from self.show_citations_and_addons(answer, docs)
|
yield from self.show_citations_and_addons(answer, docs, message)
|
||||||
|
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
@ -767,6 +803,7 @@ class FullQAPipeline(BaseReasoning):
|
||||||
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
|
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
|
||||||
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
|
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
|
||||||
answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"]
|
answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"]
|
||||||
|
answer_pipeline.enable_citation_viz = settings[f"{prefix}.create_citation_viz"]
|
||||||
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
|
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
|
||||||
answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
|
answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
|
||||||
answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
|
answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
|
||||||
|
@ -820,6 +857,11 @@ class FullQAPipeline(BaseReasoning):
|
||||||
"value": False,
|
"value": False,
|
||||||
"component": "checkbox",
|
"component": "checkbox",
|
||||||
},
|
},
|
||||||
|
"create_citation_viz": {
|
||||||
|
"name": "Create Embeddings Visualization",
|
||||||
|
"value": False,
|
||||||
|
"component": "checkbox",
|
||||||
|
},
|
||||||
"system_prompt": {
|
"system_prompt": {
|
||||||
"name": "System Prompt",
|
"name": "System Prompt",
|
||||||
"value": "This is a question answering system",
|
"value": "This is a question answering system",
|
||||||
|
|
142
libs/ktem/ktem/utils/visualize_cited.py
Normal file
142
libs/ktem/ktem/utils/visualize_cited.py
Normal file
|
@ -0,0 +1,142 @@
|
||||||
|
"""
|
||||||
|
This module aims to project high-dimensional embeddings
|
||||||
|
into a lower-dimensional space for visualization.
|
||||||
|
|
||||||
|
Refs:
|
||||||
|
1. [RAGxplorer](https://github.com/gabrielchua/RAGxplorer)
|
||||||
|
2. [RAGVizExpander](https://github.com/KKenny0/RAGVizExpander)
|
||||||
|
"""
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.graph_objs as go
|
||||||
|
import umap
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
from kotaemon.embeddings import BaseEmbeddings
|
||||||
|
|
||||||
|
VISUALIZATION_SETTINGS = {
|
||||||
|
"Original Query": {"color": "red", "opacity": 1, "symbol": "cross", "size": 15},
|
||||||
|
"Retrieved": {"color": "green", "opacity": 1, "symbol": "circle", "size": 10},
|
||||||
|
"Chunks": {"color": "blue", "opacity": 0.4, "symbol": "circle", "size": 10},
|
||||||
|
"Sub-Questions": {"color": "purple", "opacity": 1, "symbol": "star", "size": 15},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CreateCitationVizPipeline(BaseComponent):
|
||||||
|
"""Creating PlotData for visualizing query results"""
|
||||||
|
|
||||||
|
embedding: BaseEmbeddings
|
||||||
|
projector: umap.UMAP = None
|
||||||
|
|
||||||
|
def _set_up_umap(self, embeddings: np.ndarray):
|
||||||
|
umap_transform = umap.UMAP().fit(embeddings)
|
||||||
|
return umap_transform
|
||||||
|
|
||||||
|
def _project_embeddings(self, embeddings, umap_transform) -> np.ndarray:
|
||||||
|
umap_embeddings = np.empty((len(embeddings), 2))
|
||||||
|
for i, embedding in enumerate(embeddings):
|
||||||
|
umap_embeddings[i] = umap_transform.transform([embedding])
|
||||||
|
return umap_embeddings
|
||||||
|
|
||||||
|
def _get_projections(self, embeddings, umap_transform):
|
||||||
|
projections = self._project_embeddings(embeddings, umap_transform)
|
||||||
|
x = projections[:, 0]
|
||||||
|
y = projections[:, 1]
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
def _prepare_projection_df(
|
||||||
|
self,
|
||||||
|
document_projections: Tuple[np.ndarray, np.ndarray],
|
||||||
|
document_text: List[str],
|
||||||
|
plot_size: int = 3,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Prepares a DataFrame for visualization from projections and texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_projections (Tuple[np.ndarray, np.ndarray]):
|
||||||
|
Tuple of X and Y coordinates of document projections.
|
||||||
|
document_text (List[str]): List of document texts.
|
||||||
|
"""
|
||||||
|
df = pd.DataFrame({"x": document_projections[0], "y": document_projections[1]})
|
||||||
|
df["document"] = document_text
|
||||||
|
df["document_cleaned"] = df.document.str.wrap(50).apply(
|
||||||
|
lambda x: x.replace("\n", "<br>")[:512] + "..."
|
||||||
|
)
|
||||||
|
df["size"] = plot_size
|
||||||
|
df["category"] = "Retrieved"
|
||||||
|
return df
|
||||||
|
|
||||||
|
def _plot_embeddings(self, df: pd.DataFrame) -> go.Figure:
|
||||||
|
"""
|
||||||
|
Creates a Plotly figure to visualize the embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df (pd.DataFrame): DataFrame containing the data to visualize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
go.Figure: A Plotly figure object for visualization.
|
||||||
|
"""
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
for category in df["category"].unique():
|
||||||
|
category_df = df[df["category"] == category]
|
||||||
|
settings = VISUALIZATION_SETTINGS.get(
|
||||||
|
category,
|
||||||
|
{"color": "grey", "opacity": 1, "symbol": "circle", "size": 10},
|
||||||
|
)
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=category_df["x"],
|
||||||
|
y=category_df["y"],
|
||||||
|
mode="markers",
|
||||||
|
name=category,
|
||||||
|
marker=dict(
|
||||||
|
color=settings["color"],
|
||||||
|
opacity=settings["opacity"],
|
||||||
|
symbol=settings["symbol"],
|
||||||
|
size=settings["size"],
|
||||||
|
line_width=0,
|
||||||
|
),
|
||||||
|
hoverinfo="text",
|
||||||
|
text=category_df["document_cleaned"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
height=500,
|
||||||
|
legend=dict(y=100, x=0.5, xanchor="center", yanchor="top", orientation="h"),
|
||||||
|
)
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def run(self, context: List[str], question: str):
|
||||||
|
embed_contexts = self.embedding(context)
|
||||||
|
context_embeddings = np.array([d.embedding for d in embed_contexts])
|
||||||
|
|
||||||
|
self.projector = self._set_up_umap(embeddings=context_embeddings)
|
||||||
|
|
||||||
|
embed_query = self.embedding(question)
|
||||||
|
query_projection = self._get_projections(
|
||||||
|
embeddings=[embed_query[0].embedding], umap_transform=self.projector
|
||||||
|
)
|
||||||
|
viz_query_df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"x": [query_projection[0][0]],
|
||||||
|
"y": [query_projection[1][0]],
|
||||||
|
"document_cleaned": question,
|
||||||
|
"category": "Original Query",
|
||||||
|
"size": 5,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
context_projections = self._get_projections(
|
||||||
|
embeddings=context_embeddings, umap_transform=self.projector
|
||||||
|
)
|
||||||
|
viz_base_df = self._prepare_projection_df(
|
||||||
|
document_projections=context_projections, document_text=context
|
||||||
|
)
|
||||||
|
|
||||||
|
visualization_df = pd.concat([viz_base_df, viz_query_df], axis=0)
|
||||||
|
fig = self._plot_embeddings(visualization_df)
|
||||||
|
return fig
|
Loading…
Reference in New Issue
Block a user