fix: add retry to lightrag llm_func call (#572)

* add retry to lightrag llm_func call

* fix: update logic for nanographrag

---------

Co-authored-by: Song Lin <song.lin@kirkland.com>
Co-authored-by: Tadashi <tadashi@cinnamon.is>
This commit is contained in:
eddprogrammer 2024-12-17 05:02:06 -05:00 committed by GitHub
parent 1d3c4f4433
commit a8b8fcea32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 2 deletions

View File

@@ -12,6 +12,12 @@ from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from sqlalchemy.orm import Session
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from theflow.settings import settings

from kotaemon.base import Document, Param, RetrievedDocument
@@ -49,6 +55,17 @@ INDEX_BATCHSIZE = 4
def get_llm_func(model):
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((Exception,)),
after=lambda retry_state: logging.warning(
f"LLM API call attempt {retry_state.attempt_number} failed. Retrying..."
),
)
async def _call_model(model, input_messages):
return (await model.ainvoke(input_messages)).text
    async def llm_func(
        prompt, system_prompt=None, history_messages=[], **kwargs
    ) -> str:
@@ -70,7 +87,11 @@ def get_llm_func(model):
        if if_cache_return is not None:
            return if_cache_return["return"]
        output = (await model.ainvoke(input_messages)).text
        try:
output = await _call_model(model, input_messages)
except Exception as e:
logging.error(f"Failed to call LLM API after 3 retries: {str(e)}")
raise
        print("-" * 50)
        print(output, "\n", "-" * 50)

View File

@@ -12,6 +12,12 @@ from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from sqlalchemy.orm import Session
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from theflow.settings import settings

from kotaemon.base import Document, Param, RetrievedDocument
@@ -50,6 +56,17 @@ INDEX_BATCHSIZE = 4
def get_llm_func(model):
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((Exception,)),
after=lambda retry_state: logging.warning(
f"LLM API call attempt {retry_state.attempt_number} failed. Retrying..."
),
)
async def _call_model(model, input_messages):
return (await model.ainvoke(input_messages)).text
    async def llm_func(
        prompt, system_prompt=None, history_messages=[], **kwargs
    ) -> str:
@@ -71,7 +88,11 @@ def get_llm_func(model):
        if if_cache_return is not None:
            return if_cache_return["return"]
        output = (await model.ainvoke(input_messages)).text
        try:
output = await _call_model(model, input_messages)
except Exception as e:
logging.error(f"Failed to call LLM API after 3 retries: {str(e)}")
raise
        print("-" * 50)
        print(output, "\n", "-" * 50)