diff --git a/flowsettings.py b/flowsettings.py
index 6abeea2..350d16d 100644
--- a/flowsettings.py
+++ b/flowsettings.py
@@ -317,6 +317,7 @@ SETTINGS_REASONING = {
     },
 }
 
+USE_GLOBAL_GRAPHRAG = config("USE_GLOBAL_GRAPHRAG", default=True, cast=bool)
 USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
 USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool)
 USE_MS_GRAPHRAG = config("USE_MS_GRAPHRAG", default=True, cast=bool)
diff --git a/libs/ktem/ktem/index/file/graph/graph_index.py b/libs/ktem/ktem/index/file/graph/graph_index.py
index 797fd3e..cf5dee2 100644
--- a/libs/ktem/ktem/index/file/graph/graph_index.py
+++ b/libs/ktem/ktem/index/file/graph/graph_index.py
@@ -25,7 +25,7 @@ class GraphRAGIndex(FileIndex):
     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
-        _, file_ids, _ = selected
+        file_ids = self._selector_ui.get_selected_ids(selected)
         retrievers = [
             GraphRAGRetrieverPipeline(
                 file_ids=file_ids,
diff --git a/libs/ktem/ktem/index/file/graph/light_graph_index.py b/libs/ktem/ktem/index/file/graph/light_graph_index.py
index 0238ff8..aae864e 100644
--- a/libs/ktem/ktem/index/file/graph/light_graph_index.py
+++ b/libs/ktem/ktem/index/file/graph/light_graph_index.py
@@ -1,4 +1,8 @@
-from typing import Any
+from typing import Any, Optional
+from uuid import uuid4
+
+from ktem.db.engine import engine
+from sqlalchemy.orm import Session
 
 from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
 from .graph_index import GraphRAGIndex
@@ -6,12 +10,35 @@ from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipel
 
 
 class LightRAGIndex(GraphRAGIndex):
+    def __init__(self, app, id: int, name: str, config: dict):
+        super().__init__(app, id, name, config)
+        self._collection_graph_id: Optional[str] = None
+
     def _setup_indexing_cls(self):
         self._indexing_pipeline_cls = LightRAGIndexingPipeline
 
     def _setup_retriever_cls(self):
         self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]
 
+    def _get_or_create_collection_graph_id(self):
+        if self._collection_graph_id:
+            return self._collection_graph_id
+
+        # Try to find existing graph ID for this collection
+        with Session(engine) as session:
+            result = (
+                session.query(self._resources["Index"].target_id)  # type: ignore
+                .filter(
+                    self._resources["Index"].relation_type == "graph"  # type: ignore
+                )
+                .first()
+            )
+            if result:
+                self._collection_graph_id = result[0]
+            else:
+                self._collection_graph_id = str(uuid4())
+        return self._collection_graph_id
+
     def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
         pipeline = super().get_indexing_pipeline(settings, user_id)
         # indexing settings
@@ -23,12 +50,14 @@ class LightRAGIndex(GraphRAGIndex):
         }
         # set the prompts
         pipeline.prompts = striped_settings
+        # set collection graph id
+        pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
         return pipeline
 
     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
-        _, file_ids, _ = selected
+        file_ids = self._selector_ui.get_selected_ids(selected)
        # retrieval settings
         prefix = f"index.options.{self.id}."
         search_type = settings.get(prefix + "search_type", "local")
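Review note on `light_graph_index.py`: `_get_or_create_collection_graph_id()` caches a single graph ID per collection, reusing the first row with `relation_type == "graph"` if one exists and minting a fresh UUID otherwise. A minimal, self-contained sketch of that get-or-create pattern follows; the `IndexRow` model and in-memory engine are illustrative stand-ins, not the actual ktem models:

```python
from uuid import uuid4

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class IndexRow(Base):  # stand-in for the per-collection "Index" resource
    __tablename__ = "index_rows"
    id = Column(Integer, primary_key=True)
    source_id = Column(String)  # file ID
    target_id = Column(String)  # graph ID
    relation_type = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

_cached_graph_id = None


def get_or_create_graph_id() -> str:
    """Return the collection-wide graph ID, creating it on first use."""
    global _cached_graph_id
    if _cached_graph_id:
        return _cached_graph_id
    with Session(engine) as session:
        row = (
            session.query(IndexRow.target_id)
            .filter(IndexRow.relation_type == "graph")
            .first()
        )
        _cached_graph_id = row[0] if row else str(uuid4())
    return _cached_graph_id
```

One consequence of this design: because the lookup ignores `source_id`, every file in the collection shares one graph, which is exactly what the global-graph mode below relies on.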
diff --git a/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py b/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
index 7fe1ad9..02dbb90 100644
--- a/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
+++ b/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
@@ -242,6 +242,40 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
     """GraphRAG specific indexing pipeline"""
 
     prompts: dict[str, str] = {}
+    collection_graph_id: str
+
+    def store_file_id_with_graph_id(self, file_ids: list[str | None]):
+        if not settings.USE_GLOBAL_GRAPHRAG:
+            return super().store_file_id_with_graph_id(file_ids)
+
+        # Use the collection-wide graph ID for LightRAG
+        graph_id = self.collection_graph_id
+
+        # Record all files under this graph_id
+        with Session(engine) as session:
+            for file_id in file_ids:
+                if not file_id:
+                    continue
+                # Check if mapping already exists
+                existing = (
+                    session.query(self.Index)
+                    .filter(
+                        self.Index.source_id == file_id,
+                        self.Index.target_id == graph_id,
+                        self.Index.relation_type == "graph",
+                    )
+                    .first()
+                )
+                if not existing:
+                    node = self.Index(
+                        source_id=file_id,
+                        target_id=graph_id,
+                        relation_type="graph",
+                    )
+                    session.add(node)
+            session.commit()
+
+        return graph_id
 
     @classmethod
     def get_user_settings(cls) -> dict:
@@ -295,46 +329,54 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
         yield Document(
             channel="debug",
-            text="[GraphRAG] Creating index... This can take a long time.",
+            text="[GraphRAG] Creating/Updating index... This can take a long time.",
         )
 
-        # remove all .json files in the input_path directory (previous cache)
-        json_files = glob.glob(f"{input_path}/*.json")
-        for json_file in json_files:
-            os.remove(json_file)
+        # Check if graph already exists
+        graph_file = input_path / "graph_chunk_entity_relation.graphml"
+        is_incremental = graph_file.exists()
 
-        # indexing
+        # Only clear cache if it's a new graph
+        if not is_incremental:
+            json_files = glob.glob(f"{input_path}/*.json")
+            for json_file in json_files:
+                os.remove(json_file)
+
+        # Initialize or load existing GraphRAG
         graphrag_func = build_graphrag(
             input_path,
             llm_func=llm_func,
             embedding_func=embedding_func,
         )
-        # output must be contain: Loaded graph from
-        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
+
         total_docs = len(all_docs)
         process_doc_count = 0
         yield Document(
             channel="debug",
-            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+            text=(
+                f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
+                f"{process_doc_count} / {total_docs} documents."
+            ),
         )
 
         for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
             cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
             combined_doc = "\n".join(cur_docs)
 
+            # Use insert for incremental updates
             graphrag_func.insert(combined_doc)
             process_doc_count += len(cur_docs)
             yield Document(
                 channel="debug",
                 text=(
-                    f"[GraphRAG] Indexed {process_doc_count} "
-                    f"/ {total_docs} documents."
+                    f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
+                    f"{process_doc_count} / {total_docs} documents."
                 ),
             )
 
         yield Document(
             channel="debug",
-            text="[GraphRAG] Indexing finished.",
+            text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
         )
 
     def stream(
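Review note on `lightrag_pipelines.py`: the incremental path hinges on whether `graph_chunk_entity_relation.graphml` already exists in the working directory; only a brand-new graph gets its `*.json` caches cleared, since an existing graph's caches hold state that subsequent `insert()` calls depend on. A sketch of that guard in isolation (assuming, as the diff does, that this file name is the on-disk convention for the persisted graph):

```python
import glob
import os
from pathlib import Path


def prepare_working_dir(input_path: Path) -> bool:
    """Return True when an existing graph should be updated in place.

    Clearing the *.json key-value caches is only safe for a fresh graph;
    for an existing graph they carry the state incremental inserts need.
    """
    is_incremental = (input_path / "graph_chunk_entity_relation.graphml").exists()
    if not is_incremental:
        for json_file in glob.glob(f"{input_path}/*.json"):
            os.remove(json_file)
    return is_incremental
```

Usage would mirror the pipeline: call it once before `build_graphrag(...)` and branch the progress messages on the returned flag.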
diff --git a/libs/ktem/ktem/index/file/graph/nano_graph_index.py b/libs/ktem/ktem/index/file/graph/nano_graph_index.py
index 064c460..6fc70a4 100644
--- a/libs/ktem/ktem/index/file/graph/nano_graph_index.py
+++ b/libs/ktem/ktem/index/file/graph/nano_graph_index.py
@@ -1,4 +1,8 @@
-from typing import Any
+from typing import Any, Optional
+from uuid import uuid4
+
+from ktem.db.engine import engine
+from sqlalchemy.orm import Session
 
 from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
 from .graph_index import GraphRAGIndex
@@ -6,12 +10,35 @@ from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverP
 
 
 class NanoGraphRAGIndex(GraphRAGIndex):
+    def __init__(self, app, id: int, name: str, config: dict):
+        super().__init__(app, id, name, config)
+        self._collection_graph_id: Optional[str] = None
+
     def _setup_indexing_cls(self):
         self._indexing_pipeline_cls = NanoGraphRAGIndexingPipeline
 
     def _setup_retriever_cls(self):
         self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]
 
+    def _get_or_create_collection_graph_id(self):
+        if self._collection_graph_id:
+            return self._collection_graph_id
+
+        # Try to find existing graph ID for this collection
+        with Session(engine) as session:
+            result = (
+                session.query(self._resources["Index"].target_id)  # type: ignore
+                .filter(
+                    self._resources["Index"].relation_type == "graph"  # type: ignore
+                )
+                .first()
+            )
+            if result:
+                self._collection_graph_id = result[0]
+            else:
+                self._collection_graph_id = str(uuid4())
+        return self._collection_graph_id
+
     def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
         pipeline = super().get_indexing_pipeline(settings, user_id)
         # indexing settings
@@ -23,12 +50,14 @@ class NanoGraphRAGIndex(GraphRAGIndex):
         }
         # set the prompts
         pipeline.prompts = striped_settings
+        # set collection graph id
+        pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
         return pipeline
 
     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
-        _, file_ids, _ = selected
+        file_ids = self._selector_ui.get_selected_ids(selected)
         # retrieval settings
         prefix = f"index.options.{self.id}."
         search_type = settings.get(prefix + "search_type", "local")
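Review note on `nano_graph_index.py`: `__init__` and `_get_or_create_collection_graph_id` here are verbatim copies of the LightRAG versions. A possible follow-up is hoisting both into the shared `GraphRAGIndex` base; the sketch below uses stand-in stubs so it runs on its own, and whether MS GraphRAG should inherit this behavior is an open question, so treat it as a suggestion rather than the implemented design:

```python
from typing import Optional
from uuid import uuid4


# Minimal stand-ins so the sketch is self-contained; in ktem these would be
# FileIndex, the real SQLAlchemy engine, and the per-index "Index" model.
class FileIndex:
    def __init__(self, app, id: int, name: str, config: dict):
        self._resources: dict = {}


class GraphRAGIndexBase(FileIndex):
    """Hypothetical shared base: LightRAGIndex and NanoGraphRAGIndex would
    inherit this instead of duplicating __init__ and the helper verbatim."""

    def __init__(self, app, id: int, name: str, config: dict):
        super().__init__(app, id, name, config)
        self._collection_graph_id: Optional[str] = None

    def _lookup_existing_graph_id(self) -> Optional[str]:
        # Placeholder for the SQLAlchemy query over Index.relation_type == "graph"
        return None

    def _get_or_create_collection_graph_id(self) -> str:
        if self._collection_graph_id:
            return self._collection_graph_id
        self._collection_graph_id = self._lookup_existing_graph_id() or str(uuid4())
        return self._collection_graph_id
```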
diff --git a/libs/ktem/ktem/index/file/graph/nano_pipelines.py b/libs/ktem/ktem/index/file/graph/nano_pipelines.py
index 0a4b4ab..84f9d0d 100644
--- a/libs/ktem/ktem/index/file/graph/nano_pipelines.py
+++ b/libs/ktem/ktem/index/file/graph/nano_pipelines.py
@@ -238,6 +238,40 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
     """GraphRAG specific indexing pipeline"""
 
     prompts: dict[str, str] = {}
+    collection_graph_id: str
+
+    def store_file_id_with_graph_id(self, file_ids: list[str | None]):
+        if not settings.USE_GLOBAL_GRAPHRAG:
+            return super().store_file_id_with_graph_id(file_ids)
+
+        # Use the collection-wide graph ID for NanoGraphRAG
+        graph_id = self.collection_graph_id
+
+        # Record all files under this graph_id
+        with Session(engine) as session:
+            for file_id in file_ids:
+                if not file_id:
+                    continue
+                # Check if mapping already exists
+                existing = (
+                    session.query(self.Index)
+                    .filter(
+                        self.Index.source_id == file_id,
+                        self.Index.target_id == graph_id,
+                        self.Index.relation_type == "graph",
+                    )
+                    .first()
+                )
+                if not existing:
+                    node = self.Index(
+                        source_id=file_id,
+                        target_id=graph_id,
+                        relation_type="graph",
+                    )
+                    session.add(node)
+            session.commit()
+
+        return graph_id
 
     @classmethod
     def get_user_settings(cls) -> dict:
@@ -291,45 +325,54 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
         yield Document(
             channel="debug",
-            text="[GraphRAG] Creating index... This can take a long time.",
+            text="[GraphRAG] Creating/Updating index... This can take a long time.",
         )
 
-        # remove all .json files in the input_path directory (previous cache)
-        json_files = glob.glob(f"{input_path}/*.json")
-        for json_file in json_files:
-            os.remove(json_file)
+        # Check if graph already exists
+        graph_file = input_path / "graph_chunk_entity_relation.graphml"
+        is_incremental = graph_file.exists()
 
-        # indexing
+        # Only clear cache if it's a new graph
+        if not is_incremental:
+            json_files = glob.glob(f"{input_path}/*.json")
+            for json_file in json_files:
+                os.remove(json_file)
+
+        # Initialize or load existing GraphRAG
        graphrag_func = build_graphrag(
             input_path,
             llm_func=llm_func,
             embedding_func=embedding_func,
         )
-        # output must be contain: Loaded graph from
-        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
+
         total_docs = len(all_docs)
         process_doc_count = 0
         yield Document(
             channel="debug",
-            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+            text=(
+                f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
+                f"{process_doc_count} / {total_docs} documents."
+            ),
         )
+
         for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
             cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
             combined_doc = "\n".join(cur_docs)
 
+            # Use insert for incremental updates
             graphrag_func.insert(combined_doc)
             process_doc_count += len(cur_docs)
             yield Document(
                 channel="debug",
                 text=(
-                    f"[GraphRAG] Indexed {process_doc_count} "
-                    f"/ {total_docs} documents."
+                    f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
+                    f"{process_doc_count} / {total_docs} documents."
                 ),
             )
 
         yield Document(
             channel="debug",
-            text="[GraphRAG] Indexing finished.",
+            text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
        )
 
     def stream(
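Review note on `store_file_id_with_graph_id` (both pipeline files): the per-row existence check makes re-indexing idempotent, but nothing at the database level enforces uniqueness. A self-contained sketch of the same mapping with an optional `UniqueConstraint` as a DB-level guard; model and table names below are illustrative, not the ktem schema:

```python
from uuid import uuid4

from sqlalchemy import Column, Integer, String, UniqueConstraint, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class FileGraphLink(Base):  # illustrative stand-in for the ktem Index model
    __tablename__ = "file_graph_links"
    id = Column(Integer, primary_key=True)
    source_id = Column(String, nullable=False)  # file ID
    target_id = Column(String, nullable=False)  # graph ID
    relation_type = Column(String, nullable=False)
    # A DB-level guard that would make the in-loop existence check redundant:
    __table_args__ = (UniqueConstraint("source_id", "target_id", "relation_type"),)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)


def link_files_to_graph(file_ids: list[str | None], graph_id: str) -> str:
    """Idempotently map every file to the collection-wide graph."""
    with Session(engine) as session:
        for file_id in file_ids:
            if not file_id:
                continue
            exists = (
                session.query(FileGraphLink)
                .filter_by(
                    source_id=file_id, target_id=graph_id, relation_type="graph"
                )
                .first()
            )
            if not exists:
                session.add(
                    FileGraphLink(
                        source_id=file_id, target_id=graph_id, relation_type="graph"
                    )
                )
        session.commit()
    return graph_id


# Duplicate IDs in one batch are also deduplicated, since the query
# autoflushes pending inserts before the second lookup.
link_files_to_graph(["file-a", "file-b", "file-a"], str(uuid4()))
```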