feat: allow to use customized GraphRAG settings.yaml (#387) bump:patch
* allow to use customized GraphRAG settings.yaml * adjust import style * fix typo * Added GraphRAG original documentation reference. * feat: allow to use customized GraphRAG settings.yaml (#387) --------- Co-authored-by: Chen, Ron Gang <git@git.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
@@ -7,6 +8,8 @@ from uuid import uuid4
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
import yaml
|
||||
from decouple import config
|
||||
from ktem.db.models import engine
|
||||
from sqlalchemy.orm import Session
|
||||
from theflow.settings import settings
|
||||
@@ -116,6 +119,16 @@ class GraphRAGIndexingPipeline(IndexDocumentPipeline):
|
||||
print(result.stdout)
|
||||
command = command[:-1]
|
||||
|
||||
# copy customized GraphRAG config file if it exists
|
||||
if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="value").lower() == "true":
|
||||
setting_file_path = os.path.join(os.getcwd(), "settings.yaml.example")
|
||||
destination_file_path = os.path.join(input_path, "settings.yaml")
|
||||
try:
|
||||
shutil.copy(setting_file_path, destination_file_path)
|
||||
except shutil.Error:
|
||||
# Handle the error if the file copy fails
|
||||
print("failed to copy customized GraphRAG config file. ")
|
||||
|
||||
# Run the command and stream stdout
|
||||
with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
|
||||
if process.stdout:
|
||||
@@ -221,12 +234,28 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
|
||||
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
|
||||
text_units = read_indexer_text_units(text_unit_df)
|
||||
|
||||
# initialize default settings
|
||||
embedding_model = os.getenv(
|
||||
"GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small"
|
||||
)
|
||||
embedding_api_key = os.getenv("GRAPHRAG_API_KEY")
|
||||
embedding_api_base = None
|
||||
|
||||
# use customized GraphRAG settings if the flag is set
|
||||
if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="value").lower() == "true":
|
||||
settings_yaml_path = Path(root_path) / "settings.yaml"
|
||||
with open(settings_yaml_path, "r") as f:
|
||||
settings = yaml.safe_load(f)
|
||||
if settings["embeddings"]["llm"]["model"]:
|
||||
embedding_model = settings["embeddings"]["llm"]["model"]
|
||||
if settings["embeddings"]["llm"]["api_key"]:
|
||||
embedding_api_key = settings["embeddings"]["llm"]["api_key"]
|
||||
if settings["embeddings"]["llm"]["api_base"]:
|
||||
embedding_api_base = settings["embeddings"]["llm"]["api_base"]
|
||||
|
||||
text_embedder = OpenAIEmbedding(
|
||||
api_key=os.getenv("GRAPHRAG_API_KEY"),
|
||||
api_base=None,
|
||||
api_key=embedding_api_key,
|
||||
api_base=embedding_api_base,
|
||||
api_type=OpenaiApiType.OpenAI,
|
||||
model=embedding_model,
|
||||
deployment_name=embedding_model,
|
||||
|
Reference in New Issue
Block a user