fix: selecting search all does not work on LightRAG / NanoGraphRAG (#627) #none
* fix: base path * fix: select all doesn't work * fix: adding new documents should update the existing index within the file collection instead of creating new one #561 * fix linter issues * feat: update NanoGraphRAG with global collection search --------- Co-authored-by: Tadashi <tadashi@cinnamon.is>
This commit is contained in:
parent
0b090896fd
commit
e3921f7704
|
@ -317,6 +317,7 @@ SETTINGS_REASONING = {
|
|||
},
|
||||
}
|
||||
|
||||
USE_GLOBAL_GRAPHRAG = config("USE_GLOBAL_GRAPHRAG", default=True, cast=bool)
|
||||
USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
|
||||
USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool)
|
||||
USE_MS_GRAPHRAG = config("USE_MS_GRAPHRAG", default=True, cast=bool)
|
||||
|
|
|
@ -25,7 +25,7 @@ class GraphRAGIndex(FileIndex):
|
|||
def get_retriever_pipelines(
|
||||
self, settings: dict, user_id: int, selected: Any = None
|
||||
) -> list["BaseFileIndexRetriever"]:
|
||||
_, file_ids, _ = selected
|
||||
file_ids = self._selector_ui.get_selected_ids(selected)
|
||||
retrievers = [
|
||||
GraphRAGRetrieverPipeline(
|
||||
file_ids=file_ids,
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from ktem.db.engine import engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
|
||||
from .graph_index import GraphRAGIndex
|
||||
|
@ -6,12 +10,35 @@ from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipel
|
|||
|
||||
|
||||
class LightRAGIndex(GraphRAGIndex):
|
||||
def __init__(self, app, id: int, name: str, config: dict):
|
||||
super().__init__(app, id, name, config)
|
||||
self._collection_graph_id: Optional[str] = None
|
||||
|
||||
def _setup_indexing_cls(self):
|
||||
self._indexing_pipeline_cls = LightRAGIndexingPipeline
|
||||
|
||||
def _setup_retriever_cls(self):
|
||||
self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]
|
||||
|
||||
def _get_or_create_collection_graph_id(self):
|
||||
if self._collection_graph_id:
|
||||
return self._collection_graph_id
|
||||
|
||||
# Try to find existing graph ID for this collection
|
||||
with Session(engine) as session:
|
||||
result = (
|
||||
session.query(self._resources["Index"].target_id) # type: ignore
|
||||
.filter(
|
||||
self._resources["Index"].relation_type == "graph" # type: ignore
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if result:
|
||||
self._collection_graph_id = result[0]
|
||||
else:
|
||||
self._collection_graph_id = str(uuid4())
|
||||
return self._collection_graph_id
|
||||
|
||||
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
|
||||
pipeline = super().get_indexing_pipeline(settings, user_id)
|
||||
# indexing settings
|
||||
|
@ -23,12 +50,14 @@ class LightRAGIndex(GraphRAGIndex):
|
|||
}
|
||||
# set the prompts
|
||||
pipeline.prompts = striped_settings
|
||||
# set collection graph id
|
||||
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
|
||||
return pipeline
|
||||
|
||||
def get_retriever_pipelines(
|
||||
self, settings: dict, user_id: int, selected: Any = None
|
||||
) -> list["BaseFileIndexRetriever"]:
|
||||
_, file_ids, _ = selected
|
||||
file_ids = self._selector_ui.get_selected_ids(selected)
|
||||
# retrieval settings
|
||||
prefix = f"index.options.{self.id}."
|
||||
search_type = settings.get(prefix + "search_type", "local")
|
||||
|
|
|
@ -242,6 +242,40 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
|||
"""GraphRAG specific indexing pipeline"""
|
||||
|
||||
prompts: dict[str, str] = {}
|
||||
collection_graph_id: str
|
||||
|
||||
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
|
||||
if not settings.USE_GLOBAL_GRAPHRAG:
|
||||
return super().store_file_id_with_graph_id(file_ids)
|
||||
|
||||
# Use the collection-wide graph ID for LightRAG
|
||||
graph_id = self.collection_graph_id
|
||||
|
||||
# Record all files under this graph_id
|
||||
with Session(engine) as session:
|
||||
for file_id in file_ids:
|
||||
if not file_id:
|
||||
continue
|
||||
# Check if mapping already exists
|
||||
existing = (
|
||||
session.query(self.Index)
|
||||
.filter(
|
||||
self.Index.source_id == file_id,
|
||||
self.Index.target_id == graph_id,
|
||||
self.Index.relation_type == "graph",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not existing:
|
||||
node = self.Index(
|
||||
source_id=file_id,
|
||||
target_id=graph_id,
|
||||
relation_type="graph",
|
||||
)
|
||||
session.add(node)
|
||||
session.commit()
|
||||
|
||||
return graph_id
|
||||
|
||||
@classmethod
|
||||
def get_user_settings(cls) -> dict:
|
||||
|
@ -295,46 +329,54 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
|||
|
||||
yield Document(
|
||||
channel="debug",
|
||||
text="[GraphRAG] Creating index... This can take a long time.",
|
||||
text="[GraphRAG] Creating/Updating index... This can take a long time.",
|
||||
)
|
||||
|
||||
# remove all .json files in the input_path directory (previous cache)
|
||||
json_files = glob.glob(f"{input_path}/*.json")
|
||||
for json_file in json_files:
|
||||
os.remove(json_file)
|
||||
# Check if graph already exists
|
||||
graph_file = input_path / "graph_chunk_entity_relation.graphml"
|
||||
is_incremental = graph_file.exists()
|
||||
|
||||
# indexing
|
||||
# Only clear cache if it's a new graph
|
||||
if not is_incremental:
|
||||
json_files = glob.glob(f"{input_path}/*.json")
|
||||
for json_file in json_files:
|
||||
os.remove(json_file)
|
||||
|
||||
# Initialize or load existing GraphRAG
|
||||
graphrag_func = build_graphrag(
|
||||
input_path,
|
||||
llm_func=llm_func,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
# output must be contain: Loaded graph from
|
||||
# ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
|
||||
|
||||
total_docs = len(all_docs)
|
||||
process_doc_count = 0
|
||||
yield Document(
|
||||
channel="debug",
|
||||
text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
|
||||
text=(
|
||||
f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
|
||||
f"{process_doc_count} / {total_docs} documents."
|
||||
),
|
||||
)
|
||||
|
||||
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
|
||||
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
|
||||
combined_doc = "\n".join(cur_docs)
|
||||
|
||||
# Use insert for incremental updates
|
||||
graphrag_func.insert(combined_doc)
|
||||
process_doc_count += len(cur_docs)
|
||||
yield Document(
|
||||
channel="debug",
|
||||
text=(
|
||||
f"[GraphRAG] Indexed {process_doc_count} "
|
||||
f"/ {total_docs} documents."
|
||||
f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
|
||||
f"{process_doc_count} / {total_docs} documents."
|
||||
),
|
||||
)
|
||||
|
||||
yield Document(
|
||||
channel="debug",
|
||||
text="[GraphRAG] Indexing finished.",
|
||||
text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
|
||||
)
|
||||
|
||||
def stream(
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from ktem.db.engine import engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
|
||||
from .graph_index import GraphRAGIndex
|
||||
|
@ -6,12 +10,35 @@ from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverP
|
|||
|
||||
|
||||
class NanoGraphRAGIndex(GraphRAGIndex):
|
||||
def __init__(self, app, id: int, name: str, config: dict):
|
||||
super().__init__(app, id, name, config)
|
||||
self._collection_graph_id: Optional[str] = None
|
||||
|
||||
def _setup_indexing_cls(self):
|
||||
self._indexing_pipeline_cls = NanoGraphRAGIndexingPipeline
|
||||
|
||||
def _setup_retriever_cls(self):
|
||||
self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]
|
||||
|
||||
def _get_or_create_collection_graph_id(self):
|
||||
if self._collection_graph_id:
|
||||
return self._collection_graph_id
|
||||
|
||||
# Try to find existing graph ID for this collection
|
||||
with Session(engine) as session:
|
||||
result = (
|
||||
session.query(self._resources["Index"].target_id) # type: ignore
|
||||
.filter(
|
||||
self._resources["Index"].relation_type == "graph" # type: ignore
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if result:
|
||||
self._collection_graph_id = result[0]
|
||||
else:
|
||||
self._collection_graph_id = str(uuid4())
|
||||
return self._collection_graph_id
|
||||
|
||||
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
|
||||
pipeline = super().get_indexing_pipeline(settings, user_id)
|
||||
# indexing settings
|
||||
|
@ -23,12 +50,14 @@ class NanoGraphRAGIndex(GraphRAGIndex):
|
|||
}
|
||||
# set the prompts
|
||||
pipeline.prompts = striped_settings
|
||||
# set collection graph id
|
||||
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
|
||||
return pipeline
|
||||
|
||||
def get_retriever_pipelines(
|
||||
self, settings: dict, user_id: int, selected: Any = None
|
||||
) -> list["BaseFileIndexRetriever"]:
|
||||
_, file_ids, _ = selected
|
||||
file_ids = self._selector_ui.get_selected_ids(selected)
|
||||
# retrieval settings
|
||||
prefix = f"index.options.{self.id}."
|
||||
search_type = settings.get(prefix + "search_type", "local")
|
||||
|
|
|
@ -238,6 +238,40 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
|||
"""GraphRAG specific indexing pipeline"""
|
||||
|
||||
prompts: dict[str, str] = {}
|
||||
collection_graph_id: str
|
||||
|
||||
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
|
||||
if not settings.USE_GLOBAL_GRAPHRAG:
|
||||
return super().store_file_id_with_graph_id(file_ids)
|
||||
|
||||
# Use the collection-wide graph ID for LightRAG
|
||||
graph_id = self.collection_graph_id
|
||||
|
||||
# Record all files under this graph_id
|
||||
with Session(engine) as session:
|
||||
for file_id in file_ids:
|
||||
if not file_id:
|
||||
continue
|
||||
# Check if mapping already exists
|
||||
existing = (
|
||||
session.query(self.Index)
|
||||
.filter(
|
||||
self.Index.source_id == file_id,
|
||||
self.Index.target_id == graph_id,
|
||||
self.Index.relation_type == "graph",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not existing:
|
||||
node = self.Index(
|
||||
source_id=file_id,
|
||||
target_id=graph_id,
|
||||
relation_type="graph",
|
||||
)
|
||||
session.add(node)
|
||||
session.commit()
|
||||
|
||||
return graph_id
|
||||
|
||||
@classmethod
|
||||
def get_user_settings(cls) -> dict:
|
||||
|
@ -291,45 +325,54 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
|||
|
||||
yield Document(
|
||||
channel="debug",
|
||||
text="[GraphRAG] Creating index... This can take a long time.",
|
||||
text="[GraphRAG] Creating/Updating index... This can take a long time.",
|
||||
)
|
||||
|
||||
# remove all .json files in the input_path directory (previous cache)
|
||||
json_files = glob.glob(f"{input_path}/*.json")
|
||||
for json_file in json_files:
|
||||
os.remove(json_file)
|
||||
# Check if graph already exists
|
||||
graph_file = input_path / "graph_chunk_entity_relation.graphml"
|
||||
is_incremental = graph_file.exists()
|
||||
|
||||
# indexing
|
||||
# Only clear cache if it's a new graph
|
||||
if not is_incremental:
|
||||
json_files = glob.glob(f"{input_path}/*.json")
|
||||
for json_file in json_files:
|
||||
os.remove(json_file)
|
||||
|
||||
# Initialize or load existing GraphRAG
|
||||
graphrag_func = build_graphrag(
|
||||
input_path,
|
||||
llm_func=llm_func,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
# output must be contain: Loaded graph from
|
||||
# ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
|
||||
|
||||
total_docs = len(all_docs)
|
||||
process_doc_count = 0
|
||||
yield Document(
|
||||
channel="debug",
|
||||
text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
|
||||
text=(
|
||||
f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
|
||||
f"{process_doc_count} / {total_docs} documents."
|
||||
),
|
||||
)
|
||||
|
||||
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
|
||||
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
|
||||
combined_doc = "\n".join(cur_docs)
|
||||
|
||||
# Use insert for incremental updates
|
||||
graphrag_func.insert(combined_doc)
|
||||
process_doc_count += len(cur_docs)
|
||||
yield Document(
|
||||
channel="debug",
|
||||
text=(
|
||||
f"[GraphRAG] Indexed {process_doc_count} "
|
||||
f"/ {total_docs} documents."
|
||||
f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
|
||||
f"{process_doc_count} / {total_docs} documents."
|
||||
),
|
||||
)
|
||||
|
||||
yield Document(
|
||||
channel="debug",
|
||||
text="[GraphRAG] Indexing finished.",
|
||||
text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
|
||||
)
|
||||
|
||||
def stream(
|
||||
|
|
Loading…
Reference in New Issue
Block a user