From a8b8fcea32a67412ba0de6ae516f36d1902462e2 Mon Sep 17 00:00:00 2001
From: eddprogrammer <119921381+eddprogrammer@users.noreply.github.com>
Date: Tue, 17 Dec 2024 05:02:06 -0500
Subject: [PATCH] fix: add retry to lightrag llm_func call (#572)

* add retry to lightrag llm_func call

* fix: update logic for nanographrag

---------

Co-authored-by: Song Lin
Co-authored-by: Tadashi
---
 .../index/file/graph/lightrag_pipelines.py   | 23 ++++++++++++++++++-
 .../ktem/index/file/graph/nano_pipelines.py  | 23 ++++++++++++++++++-
 2 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py b/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
index 6a374f4..ffc92b8 100644
--- a/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
+++ b/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
@@ -12,6 +12,12 @@ from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager as embeddings
 from ktem.llms.manager import llms
 from sqlalchemy.orm import Session
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 from theflow.settings import settings
 
 from kotaemon.base import Document, Param, RetrievedDocument
@@ -49,6 +55,17 @@ INDEX_BATCHSIZE = 4
 
 
 def get_llm_func(model):
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type((Exception,)),
+        after=lambda retry_state: logging.warning(
+            f"LLM API call attempt {retry_state.attempt_number} failed. Retrying..."
+        ),
+    )
+    async def _call_model(model, input_messages):
+        return (await model.ainvoke(input_messages)).text
+
     async def llm_func(
         prompt, system_prompt=None, history_messages=[], **kwargs
     ) -> str:
@@ -70,7 +87,11 @@ def get_llm_func(model):
         if if_cache_return is not None:
             return if_cache_return["return"]
 
-        output = (await model.ainvoke(input_messages)).text
+        try:
+            output = await _call_model(model, input_messages)
+        except Exception as e:
+            logging.error(f"Failed to call LLM API after 3 retries: {str(e)}")
+            raise
 
         print("-" * 50)
         print(output, "\n", "-" * 50)
diff --git a/libs/ktem/ktem/index/file/graph/nano_pipelines.py b/libs/ktem/ktem/index/file/graph/nano_pipelines.py
index bbfdf26..438877f 100644
--- a/libs/ktem/ktem/index/file/graph/nano_pipelines.py
+++ b/libs/ktem/ktem/index/file/graph/nano_pipelines.py
@@ -12,6 +12,12 @@ from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager as embeddings
 from ktem.llms.manager import llms
 from sqlalchemy.orm import Session
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 from theflow.settings import settings
 
 from kotaemon.base import Document, Param, RetrievedDocument
@@ -50,6 +56,17 @@ INDEX_BATCHSIZE = 4
 
 
 def get_llm_func(model):
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type((Exception,)),
+        after=lambda retry_state: logging.warning(
+            f"LLM API call attempt {retry_state.attempt_number} failed. Retrying..."
+        ),
+    )
+    async def _call_model(model, input_messages):
+        return (await model.ainvoke(input_messages)).text
+
     async def llm_func(
         prompt, system_prompt=None, history_messages=[], **kwargs
     ) -> str:
@@ -71,7 +88,11 @@ def get_llm_func(model):
         if if_cache_return is not None:
             return if_cache_return["return"]
 
-        output = (await model.ainvoke(input_messages)).text
+        try:
+            output = await _call_model(model, input_messages)
+        except Exception as e:
+            logging.error(f"Failed to call LLM API after 3 retries: {str(e)}")
+            raise
 
         print("-" * 50)
         print(output, "\n", "-" * 50)
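Below, for reference, is a minimal self-contained sketch of the tenacity retry pattern this patch introduces. FakeModel, FakeResponse, main, and the message payload are hypothetical stand-ins for kotaemon's real model object and are not part of the patch; only the decorator arguments and the try/except wrapper mirror the diff above.

import asyncio
import logging

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logging.basicConfig(level=logging.WARNING)


class FakeResponse:
    # Hypothetical response type: mimics the `.text` attribute the diff reads.
    def __init__(self, text):
        self.text = text


class FakeModel:
    # Hypothetical stand-in for the real model: fails twice, then succeeds.
    def __init__(self):
        self.calls = 0

    async def ainvoke(self, messages):
        self.calls += 1
        if self.calls < 3:
            raise ConnectionError("transient API error")
        return FakeResponse("ok")


@retry(
    stop=stop_after_attempt(3),  # give up after three attempts
    wait=wait_exponential(multiplier=1, min=4, max=10),  # backoff clamped to 4-10s
    retry=retry_if_exception_type((Exception,)),  # retry on any exception
    after=lambda retry_state: logging.warning(
        f"LLM API call attempt {retry_state.attempt_number} failed. Retrying..."
    ),
)
async def _call_model(model, input_messages):
    return (await model.ainvoke(input_messages)).text


async def main():
    model = FakeModel()
    try:
        output = await _call_model(model, [{"role": "user", "content": "hi"}])
    except Exception as e:
        logging.error(f"Failed to call LLM API after 3 retries: {str(e)}")
        raise
    print(output)  # "ok", reached on the third attempt after two warnings


if __name__ == "__main__":
    asyncio.run(main())

Note that retry_if_exception_type((Exception,)) retries on any error, including non-transient ones such as authentication failures; narrowing the tuple to connection or timeout errors is a common refinement.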