from unittest.mock import patch import pytest from openai.types.chat.chat_completion import ChatCompletion from kotaemon.base import Document from kotaemon.indices.rankings import LLMReranking from kotaemon.llms import AzureChatOpenAI _openai_chat_completion_responses = [ ChatCompletion.parse_obj( { "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", "object": "chat.completion", "created": 1692338378, "model": "gpt-35-turbo", "system_fingerprint": None, "choices": [ { "index": 0, "finish_reason": "stop", "message": { "role": "assistant", "content": text, "function_call": None, "tool_calls": None, }, } ], "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, } ) for text in [ "YES", "NO", "YES", ] ] @pytest.fixture def llm(): return AzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", temperature=0, ) @patch( "openai.resources.chat.completions.Completions.create", side_effect=_openai_chat_completion_responses, ) def test_reranking(openai_completion, llm): documents = [Document(text=f"test {idx}") for idx in range(3)] query = "test query" reranker = LLMReranking(llm=llm, concurrent=False) rerank_docs = reranker(documents, query=query) assert len(rerank_docs) == 2