# kotaemon/libs/ktem/ktem/index/file/graph/pipelines.py
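"""GraphRAG-specific indexing and retrieval pipelines for ktem's file index.

Indexing writes each file's text into a per-graph workspace and shells out to
`graphrag.index`; retrieval loads the resulting parquet artifacts and answers
queries with a local search over the knowledge graph.
"""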
import os
import shutil
import subprocess
from pathlib import Path
from shutil import rmtree
from typing import Generator
from uuid import uuid4
import pandas as pd
import tiktoken
import yaml
from decouple import config
from ktem.db.models import engine
from sqlalchemy.orm import Session
from theflow.settings import settings
from kotaemon.base import Document, Param, RetrievedDocument
from ..pipelines import BaseFileIndexRetriever, IndexDocumentPipeline, IndexPipeline
from .visualize import create_knowledge_graph, visualize_graph
try:
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
from graphrag.vector_stores.lancedb import LanceDBVectorStore
except ImportError:
print(
(
"GraphRAG dependencies not installed. "
"Try `pip install graphrag future` to install. "
"GraphRAG retriever pipeline will not work properly."
)
)
filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "graphrag"
filestorage_path.mkdir(parents=True, exist_ok=True)

GRAPHRAG_KEY_MISSING_MESSAGE = (
    "GRAPHRAG_API_KEY is not set. Please set it to use the GraphRAG retriever pipeline."
)


def check_graphrag_api_key():
    # GraphRAG reads GRAPHRAG_API_KEY from the environment; export it before
    # indexing or querying
    return len(os.getenv("GRAPHRAG_API_KEY", "")) > 0


def prepare_graph_index_path(graph_id: str):
    root_path = Path(filestorage_path) / graph_id
    input_path = root_path / "input"
    return root_path, input_path
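

# resulting layout: <KH_FILESTORAGE_PATH>/graphrag/<graph_id>/input holds the
# plain-text documents fed to the indexer; GraphRAG writes its artifacts under
# <graph_id>/output (consumed by GraphRAGRetrieverPipeline below)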
class GraphRAGIndexingPipeline(IndexDocumentPipeline):
"""GraphRAG specific indexing pipeline"""
def route(self, file_path: str | Path) -> IndexPipeline:
"""Simply disable the splitter (chunking) for this pipeline"""
pipeline = super().route(file_path)
pipeline.splitter = None
return pipeline
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
        # create a new graph_id and link each file_id to it in the Index table
graph_id = str(uuid4())
with Session(engine) as session:
nodes = []
for file_id in file_ids:
if not file_id:
continue
nodes.append(
self.Index(
source_id=file_id,
target_id=graph_id,
relation_type="graph",
)
)
session.add_all(nodes)
session.commit()
return graph_id
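
    # the graph_id recorded here is what `_build_graph_search` later resolves
    # from a file_id to locate the GraphRAG workspace on disk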
def write_docs_to_files(self, graph_id: str, docs: list[Document]):
root_path, input_path = prepare_graph_index_path(graph_id)
input_path.mkdir(parents=True, exist_ok=True)
for doc in docs:
            if doc.metadata.get("type", "text") == "text":
                file_path = input_path / f"{doc.doc_id}.txt"
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(doc.text)
return root_path
def call_graphrag_index(self, graph_id: str, all_docs: list[Document]):
if not check_graphrag_api_key():
raise ValueError(GRAPHRAG_KEY_MISSING_MESSAGE)
        # write the docs into the graph workspace, then run GraphRAG indexing
        root_path = self.write_docs_to_files(graph_id, all_docs)
        root_path = str(root_path.absolute())
# Construct the command
command = [
"python",
"-m",
"graphrag.index",
"--root",
            root_path,
"--reporter",
"rich",
"--init",
]
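        # equivalent shell invocation:
        #   python -m graphrag.index --root <root_path> --reporter rich --init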
        # first run with --init to scaffold the GraphRAG workspace
        yield Document(
            channel="debug",
            text="[GraphRAG] Creating index... This can take a long time.",
        )
        result = subprocess.run(command, capture_output=True, text=True)
        print(result.stdout)
        # drop --init: the second invocation performs the actual indexing
        command = command[:-1]
# copy customized GraphRAG config file if it exists
        if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="false").lower() == "true":
setting_file_path = os.path.join(os.getcwd(), "settings.yaml.example")
            destination_file_path = os.path.join(root_path, "settings.yaml")
            try:
                shutil.copy(setting_file_path, destination_file_path)
            except OSError:
                # covers shutil.Error and a missing settings.yaml.example
                print("Failed to copy customized GraphRAG config file.")
# Run the command and stream stdout
with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
if process.stdout:
for line in process.stdout:
yield Document(channel="debug", text=line)
def stream(
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
) -> Generator[
Document, None, tuple[list[str | None], list[str | None], list[Document]]
]:
file_ids, errors, all_docs = yield from super().stream(
file_paths, reindex=reindex, **kwargs
)
# assign graph_id to file_ids
graph_id = self.store_file_id_with_graph_id(file_ids)
# call GraphRAG index with docs and graph_id
yield from self.call_graphrag_index(graph_id, all_docs)
return file_ids, errors, all_docs
class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
"""GraphRAG specific retriever pipeline"""
Index = Param(help="The SQLAlchemy Index table")
file_ids: list[str] = []
@classmethod
def get_user_settings(cls) -> dict:
return {
"search_type": {
"name": "Search type",
"value": "local",
"choices": ["local", "global"],
"component": "dropdown",
"info": "Whether to use local or global search in the graph.",
}
}
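
    # NOTE: the dropdown surfaces in ktem's settings UI; `run()` below always
    # builds the local-search context via LocalSearchMixedContext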
def _build_graph_search(self):
        assert (
            len(self.file_ids) <= 1
        ), "GraphRAG retriever supports at most one file_id"
file_id = self.file_ids[0]
# retrieve the graph_id from the index
with Session(engine) as session:
graph_id = (
session.query(self.Index.target_id)
.filter(self.Index.source_id == file_id)
.filter(self.Index.relation_type == "graph")
.first()
)
graph_id = graph_id[0] if graph_id else None
assert graph_id, f"GraphRAG index not found for file_id: {file_id}"
root_path, _ = prepare_graph_index_path(graph_id)
output_path = root_path / "output"
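        # GraphRAG's query helpers read the parquet artifacts the indexer
        # wrote to the output folder; the constant names below mirror
        # graphrag's local-search examples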
INPUT_DIR = output_path
LANCEDB_URI = str(INPUT_DIR / "lancedb")
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
TEXT_UNIT_TABLE = "create_final_text_units"
COMMUNITY_LEVEL = 2
# read nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
entity_embedding_df = pd.read_parquet(
f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet"
)
entities = read_indexer_entities(
entity_df, entity_embedding_df, COMMUNITY_LEVEL
)
# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name="entity_description_embeddings",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)
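        # clear any stale LanceDB data so the embeddings below are rebuilt fresh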
if Path(LANCEDB_URI).is_dir():
rmtree(LANCEDB_URI)
_ = store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
print(f"Entity count: {len(entity_df)}")
# Read relationships
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
# Read community reports
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)
# Read text units
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
# initialize default settings
embedding_model = os.getenv(
"GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small"
)
embedding_api_key = os.getenv("GRAPHRAG_API_KEY")
embedding_api_base = None
# use customized GraphRAG settings if the flag is set
        if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="false").lower() == "true":
            settings_yaml_path = Path(root_path) / "settings.yaml"
            with open(settings_yaml_path, "r") as f:
                # avoid shadowing the module-level `settings` import
                graphrag_settings = yaml.safe_load(f)
            embeddings_llm = graphrag_settings.get("embeddings", {}).get("llm", {})
            if embeddings_llm.get("model"):
                embedding_model = embeddings_llm["model"]
            if embeddings_llm.get("api_key"):
                embedding_api_key = embeddings_llm["api_key"]
            if embeddings_llm.get("api_base"):
                embedding_api_base = embeddings_llm["api_base"]
text_embedder = OpenAIEmbedding(
api_key=embedding_api_key,
api_base=embedding_api_base,
api_type=OpenaiApiType.OpenAI,
model=embedding_model,
deployment_name=embedding_model,
max_retries=20,
)
token_encoder = tiktoken.get_encoding("cl100k_base")
context_builder = LocalSearchMixedContext(
community_reports=reports,
text_units=text_units,
entities=entities,
relationships=relationships,
covariates=None,
entity_text_embeddings=description_embedding_store,
embedding_vectorstore_key=EntityVectorStoreKey.ID,
# if the vectorstore uses entity title as ids,
# set this to EntityVectorStoreKey.TITLE
text_embedder=text_embedder,
token_encoder=token_encoder,
)
return context_builder
def _to_document(self, header: str, context_text: str) -> RetrievedDocument:
return RetrievedDocument(
text=context_text,
metadata={
"file_name": header,
"type": "table",
"llm_trulens_score": 1.0,
},
score=1.0,
)
def format_context_records(self, context_records) -> list[RetrievedDocument]:
entities = context_records.get("entities", [])
relationships = context_records.get("relationships", [])
reports = context_records.get("reports", [])
sources = context_records.get("sources", [])
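        # each section is a pandas DataFrame from GraphRAG's context builder:
        # entities/relationships render as markdown tables, reports/sources as
        # per-row HTML snippets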
        docs = []

        header = "<b>Entities</b>\n"
        context = entities[["entity", "description"]].to_markdown(index=False)
docs.append(self._to_document(header, context))
header = "\n<b>Relationships</b>\n"
context = relationships[["source", "target", "description"]].to_markdown(
index=False
)
docs.append(self._to_document(header, context))
header = "\n<b>Reports</b>\n"
context = ""
for idx, row in reports.iterrows():
title, content = row["title"], row["content"]
context += f"\n\n<h5>Report <b>{title}</b></h5>\n"
context += content
docs.append(self._to_document(header, context))
header = "\n<b>Sources</b>\n"
context = ""
for idx, row in sources.iterrows():
title, content = row["id"], row["text"]
context += f"\n\n<h5>Source <b>#{title}</b></h5>\n"
context += content
docs.append(self._to_document(header, context))
return docs
def plot_graph(self, context_records):
relationships = context_records.get("relationships", [])
G = create_knowledge_graph(relationships)
plot = visualize_graph(G)
return plot
def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
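        # documents already carry a fixed score of 1.0 (set in `_to_document`),
        # so no extra reranking is applied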
return documents
def run(
self,
text: str,
) -> list[RetrievedDocument]:
if not self.file_ids:
return []
if not check_graphrag_api_key():
raise ValueError(GRAPHRAG_KEY_MISSING_MESSAGE)
context_builder = self._build_graph_search()
local_context_params = {
"text_unit_prop": 0.5,
"community_prop": 0.1,
"conversation_history_max_turns": 5,
"conversation_history_user_turns_only": True,
"top_k_mapped_entities": 10,
"top_k_relationships": 10,
"include_entity_rank": False,
"include_relationship_weight": False,
"include_community_rank": False,
"return_candidate_context": False,
"embedding_vectorstore_key": EntityVectorStoreKey.ID,
            # set this to EntityVectorStoreKey.TITLE
            # if the vectorstore uses entity title as ids
"max_tokens": 12_000,
# change this based on the token limit you have on your model
# (if you are using a model with 8k limit, a good setting could be 5000)
}
context_text, context_records = context_builder.build_context(
query=text,
conversation_history=None,
**local_context_params,
)
documents = self.format_context_records(context_records)
plot = self.plot_graph(context_records)
return documents + [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
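

# Minimal usage sketch (hypothetical wiring; in practice ktem's index manager
# constructs and configures these pipelines):
#
#   retriever = GraphRAGRetrieverPipeline(file_ids=["<indexed-file-id>"])
#   docs = retriever.run("Which entities appear in the report?")
#
# `docs` holds markdown tables for entities and relationships, report/source
# excerpts, and a final Document whose metadata carries the graph plot.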