fix: selecting search all does not work on LightRAG / NanoGraphRAG (#627)

* fix: base path

* fix: select all doesn't work

* fix: adding new documents should update the existing index within the file collection instead of creating a new one #561

* fix linter issues

* feat: update NanoGraphRAG with global collection search

---------

Co-authored-by: Tadashi <tadashi@cinnamon.is>
Varun Sharma, 2025-02-14 15:13:39 +01:00 (committed by GitHub)
parent 0b090896fd
commit e3921f7704
6 changed files with 173 additions and 29 deletions


@@ -317,6 +317,7 @@ SETTINGS_REASONING = {
},
}
USE_GLOBAL_GRAPHRAG = config("USE_GLOBAL_GRAPHRAG", default=True, cast=bool)
USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool)
USE_MS_GRAPHRAG = config("USE_MS_GRAPHRAG", default=True, cast=bool)
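
These flags follow the pattern already used in this settings block: the environment variable wins, otherwise the default applies, and cast=bool parses string values (the signature suggests python-decouple's config helper). As a hedged sketch, not part of this commit, such flags can gate which graph backends get registered; the enabled_graph_backends helper below is hypothetical:

    from decouple import config

    # Environment variables override the defaults; cast=bool parses
    # values like "true"/"false" into booleans.
    USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
    USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool)

    def enabled_graph_backends() -> list[str]:
        """Hypothetical helper: name the graph backends enabled by the flags."""
        backends = []
        if USE_NANO_GRAPHRAG:
            backends.append("nano-graphrag")
        if USE_LIGHTRAG:
            backends.append("lightrag")
        return backends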


@@ -25,7 +25,7 @@ class GraphRAGIndex(FileIndex):
def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
file_ids = self._selector_ui.get_selected_ids(selected)
retrievers = [
GraphRAGRetrieverPipeline(
file_ids=file_ids,
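
The change above replaces blind tuple unpacking (_, file_ids, _ = selected) with a call to the selector UI, so the "search all" case is handled in one place instead of breaking the unpack. A minimal sketch of what such a helper could do; the "all" sentinel and the tuple layout here are assumptions, not the project's actual implementation:

    from typing import Any, Optional

    def get_selected_ids(selected: Optional[Any], all_file_ids: list[str]) -> list[str]:
        """Hypothetical sketch: normalize a UI selection into file IDs."""
        if selected is None or selected == "all":
            # "Search all": fall back to every file in the collection
            # instead of crashing on tuple unpacking.
            return all_file_ids
        _, file_ids, _ = selected
        return list(file_ids)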


@@ -1,4 +1,8 @@
from typing import Any
from typing import Any, Optional
from uuid import uuid4
from ktem.db.engine import engine
from sqlalchemy.orm import Session
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
@@ -6,12 +10,35 @@ from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline
class LightRAGIndex(GraphRAGIndex):
def __init__(self, app, id: int, name: str, config: dict):
super().__init__(app, id, name, config)
self._collection_graph_id: Optional[str] = None
def _setup_indexing_cls(self):
self._indexing_pipeline_cls = LightRAGIndexingPipeline
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]
def _get_or_create_collection_graph_id(self):
if self._collection_graph_id:
return self._collection_graph_id
# Try to find existing graph ID for this collection
with Session(engine) as session:
result = (
session.query(self._resources["Index"].target_id) # type: ignore
.filter(
self._resources["Index"].relation_type == "graph" # type: ignore
)
.first()
)
if result:
self._collection_graph_id = result[0]
else:
self._collection_graph_id = str(uuid4())
return self._collection_graph_id
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
pipeline = super().get_indexing_pipeline(settings, user_id)
# indexing settings
@@ -23,12 +50,14 @@ class LightRAGIndex(GraphRAGIndex):
}
# set the prompts
pipeline.prompts = striped_settings
# set collection graph id
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
return pipeline
def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
file_ids = self._selector_ui.get_selected_ids(selected)
# retrieval settings
prefix = f"index.options.{self.id}."
search_type = settings.get(prefix + "search_type", "local")
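
The search_type read here ultimately decides the query mode handed to the LightRAG backend. As an illustration, assuming the standard LightRAG QueryParam API (the actual wiring lives in lightrag_pipelines.py):

    from lightrag import QueryParam  # standard LightRAG query API

    def make_query_param(search_type: str) -> QueryParam:
        # LightRAG accepts "naive", "local", "global", and "hybrid";
        # fall back to "local" to mirror the settings default above.
        if search_type not in ("naive", "local", "global", "hybrid"):
            search_type = "local"
        return QueryParam(mode=search_type)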


@@ -242,6 +242,40 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
"""GraphRAG specific indexing pipeline"""
prompts: dict[str, str] = {}
collection_graph_id: str
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
if not settings.USE_GLOBAL_GRAPHRAG:
return super().store_file_id_with_graph_id(file_ids)
# Use the collection-wide graph ID for LightRAG
graph_id = self.collection_graph_id
# Record all files under this graph_id
with Session(engine) as session:
for file_id in file_ids:
if not file_id:
continue
# Check if mapping already exists
existing = (
session.query(self.Index)
.filter(
self.Index.source_id == file_id,
self.Index.target_id == graph_id,
self.Index.relation_type == "graph",
)
.first()
)
if not existing:
node = self.Index(
source_id=file_id,
target_id=graph_id,
relation_type="graph",
)
session.add(node)
session.commit()
return graph_id
@classmethod
def get_user_settings(cls) -> dict:
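
With the file-to-graph mapping written by store_file_id_with_graph_id, any later lookup can recover the collection-wide graph for a given file. A sketch of the inverse query, under the same Index model and Session/engine assumptions as the code above:

    from typing import Optional
    from sqlalchemy.orm import Session
    from ktem.db.engine import engine

    def resolve_graph_id(Index, file_id: str) -> Optional[str]:
        """Sketch: find the graph a file was indexed into (None if unindexed)."""
        with Session(engine) as session:
            row = (
                session.query(Index.target_id)
                .filter(
                    Index.source_id == file_id,
                    Index.relation_type == "graph",
                )
                .first()
            )
        return row[0] if row else None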
@@ -295,46 +329,54 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
yield Document(
channel="debug",
text="[GraphRAG] Creating index... This can take a long time.",
text="[GraphRAG] Creating/Updating index... This can take a long time.",
)
# remove all .json files in the input_path directory (previous cache)
# Check if graph already exists
graph_file = input_path / "graph_chunk_entity_relation.graphml"
is_incremental = graph_file.exists()
# Only clear cache if it's a new graph
if not is_incremental:
json_files = glob.glob(f"{input_path}/*.json")
for json_file in json_files:
os.remove(json_file)
# indexing
# Initialize or load existing GraphRAG
graphrag_func = build_graphrag(
input_path,
llm_func=llm_func,
embedding_func=embedding_func,
)
# output must contain: Loaded graph from
# ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
total_docs = len(all_docs)
process_doc_count = 0
yield Document(
channel="debug",
text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
text=(
f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
f"{process_doc_count} / {total_docs} documents."
),
)
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
combined_doc = "\n".join(cur_docs)
# Use insert for incremental updates
graphrag_func.insert(combined_doc)
process_doc_count += len(cur_docs)
yield Document(
channel="debug",
text=(
f"[GraphRAG] Indexed {process_doc_count} "
f"/ {total_docs} documents."
f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
f"{process_doc_count} / {total_docs} documents."
),
)
yield Document(
channel="debug",
text="[GraphRAG] Indexing finished.",
text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
)
def stream(
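
The incremental path above relies on the backend's insert being additive: both LightRAG and nano-graphrag merge newly inserted text into an existing graph rather than rebuilding from scratch. Condensed into a standalone sketch, assuming a build_graphrag factory like the one used above and an INDEX_BATCHSIZE constant:

    import glob
    import os
    from pathlib import Path

    INDEX_BATCHSIZE = 4  # assumed value; the pipeline defines its own constant

    def index_documents(input_path: Path, all_docs: list[str], graphrag_func) -> None:
        """Sketch of the create-vs-update flow shown in the diff above."""
        graph_file = input_path / "graph_chunk_entity_relation.graphml"
        is_incremental = graph_file.exists()

        # A fresh graph starts from a clean cache; an update keeps the
        # cache so previously extracted entities and relations are reused.
        if not is_incremental:
            for json_file in glob.glob(f"{input_path}/*.json"):
                os.remove(json_file)

        for start in range(0, len(all_docs), INDEX_BATCHSIZE):
            batch = "\n".join(all_docs[start : start + INDEX_BATCHSIZE])
            graphrag_func.insert(batch)  # insert() merges into the existing graph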


@@ -1,4 +1,8 @@
from typing import Any
from typing import Any, Optional
from uuid import uuid4
from ktem.db.engine import engine
from sqlalchemy.orm import Session
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
@@ -6,12 +10,35 @@ from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline
class NanoGraphRAGIndex(GraphRAGIndex):
def __init__(self, app, id: int, name: str, config: dict):
super().__init__(app, id, name, config)
self._collection_graph_id: Optional[str] = None
def _setup_indexing_cls(self):
self._indexing_pipeline_cls = NanoGraphRAGIndexingPipeline
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]
def _get_or_create_collection_graph_id(self):
if self._collection_graph_id:
return self._collection_graph_id
# Try to find existing graph ID for this collection
with Session(engine) as session:
result = (
session.query(self._resources["Index"].target_id) # type: ignore
.filter(
self._resources["Index"].relation_type == "graph" # type: ignore
)
.first()
)
if result:
self._collection_graph_id = result[0]
else:
self._collection_graph_id = str(uuid4())
return self._collection_graph_id
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
pipeline = super().get_indexing_pipeline(settings, user_id)
# indexing settings
@@ -23,12 +50,14 @@ class NanoGraphRAGIndex(GraphRAGIndex):
}
# set the prompts
pipeline.prompts = striped_settings
# set collection graph id
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
return pipeline
def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
file_ids = self._selector_ui.get_selected_ids(selected)
# retrieval settings
prefix = f"index.options.{self.id}."
search_type = settings.get(prefix + "search_type", "local")
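
On the retrieval side, the commit's "global collection search" corresponds to nano-graphrag's query modes. A hedged sketch, assuming the standard nano-graphrag API rather than this index's actual retriever wiring:

    from nano_graphrag import GraphRAG, QueryParam

    def query_collection(working_dir: str, question: str, search_type: str = "local") -> str:
        # GraphRAG loads the existing graph artifacts from working_dir;
        # mode="global" answers from community summaries across the whole
        # collection, mode="local" stays near the matched entities.
        mode = search_type if search_type in ("local", "global", "naive") else "local"
        graphrag_func = GraphRAG(working_dir=working_dir)
        return graphrag_func.query(question, param=QueryParam(mode=mode))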


@@ -238,6 +238,40 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
"""GraphRAG specific indexing pipeline"""
prompts: dict[str, str] = {}
collection_graph_id: str
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
if not settings.USE_GLOBAL_GRAPHRAG:
return super().store_file_id_with_graph_id(file_ids)
# Use the collection-wide graph ID for NanoGraphRAG
graph_id = self.collection_graph_id
# Record all files under this graph_id
with Session(engine) as session:
for file_id in file_ids:
if not file_id:
continue
# Check if mapping already exists
existing = (
session.query(self.Index)
.filter(
self.Index.source_id == file_id,
self.Index.target_id == graph_id,
self.Index.relation_type == "graph",
)
.first()
)
if not existing:
node = self.Index(
source_id=file_id,
target_id=graph_id,
relation_type="graph",
)
session.add(node)
session.commit()
return graph_id
@classmethod
def get_user_settings(cls) -> dict:
@@ -291,45 +325,54 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
yield Document(
channel="debug",
text="[GraphRAG] Creating index... This can take a long time.",
text="[GraphRAG] Creating/Updating index... This can take a long time.",
)
# remove all .json files in the input_path directory (previous cache)
# Check if graph already exists
graph_file = input_path / "graph_chunk_entity_relation.graphml"
is_incremental = graph_file.exists()
# Only clear cache if it's a new graph
if not is_incremental:
json_files = glob.glob(f"{input_path}/*.json")
for json_file in json_files:
os.remove(json_file)
# indexing
# Initialize or load existing GraphRAG
graphrag_func = build_graphrag(
input_path,
llm_func=llm_func,
embedding_func=embedding_func,
)
# output must contain: Loaded graph from
# ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
total_docs = len(all_docs)
process_doc_count = 0
yield Document(
channel="debug",
text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
text=(
f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
f"{process_doc_count} / {total_docs} documents."
),
)
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
combined_doc = "\n".join(cur_docs)
# Use insert for incremental updates
graphrag_func.insert(combined_doc)
process_doc_count += len(cur_docs)
yield Document(
channel="debug",
text=(
f"[GraphRAG] Indexed {process_doc_count} "
f"/ {total_docs} documents."
f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
f"{process_doc_count} / {total_docs} documents."
),
)
yield Document(
channel="debug",
text="[GraphRAG] Indexing finished.",
text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
)
def stream(