From 9b05693e4f902023f12325767d16913565dda797 Mon Sep 17 00:00:00 2001 From: Ben Dykstra Date: Tue, 15 Apr 2025 01:54:23 -0600 Subject: [PATCH] feat: add structured output to openai (#603) #none * add structured output to openai * remove notebook, modify prepare output method * fix: comfort precommit --------- Co-authored-by: Tadashi --- libs/kotaemon/kotaemon/base/__init__.py | 2 + libs/kotaemon/kotaemon/base/schema.py | 5 + libs/kotaemon/kotaemon/llms/__init__.py | 2 + libs/kotaemon/kotaemon/llms/chats/__init__.py | 3 +- libs/kotaemon/kotaemon/llms/chats/openai.py | 94 ++++++++++++++++++- 5 files changed, 103 insertions(+), 3 deletions(-) diff --git a/libs/kotaemon/kotaemon/base/__init__.py b/libs/kotaemon/kotaemon/base/__init__.py index 52e036f..1f78710 100644 --- a/libs/kotaemon/kotaemon/base/__init__.py +++ b/libs/kotaemon/kotaemon/base/__init__.py @@ -8,6 +8,7 @@ from .schema import ( HumanMessage, LLMInterface, RetrievedDocument, + StructuredOutputLLMInterface, SystemMessage, ) @@ -21,6 +22,7 @@ __all__ = [ "HumanMessage", "RetrievedDocument", "LLMInterface", + "StructuredOutputLLMInterface", "ExtractorOutput", "Param", "Node", diff --git a/libs/kotaemon/kotaemon/base/schema.py b/libs/kotaemon/kotaemon/base/schema.py index ea26032..0a499f8 100644 --- a/libs/kotaemon/kotaemon/base/schema.py +++ b/libs/kotaemon/kotaemon/base/schema.py @@ -143,6 +143,11 @@ class LLMInterface(AIMessage): logprobs: list[float] = [] +class StructuredOutputLLMInterface(LLMInterface): + parsed: Any + refusal: str = "" + + class ExtractorOutput(Document): """ Represents the output of an extractor. 
diff --git a/libs/kotaemon/kotaemon/llms/__init__.py b/libs/kotaemon/kotaemon/llms/__init__.py index e7ddfbf..c70e2ff 100644 --- a/libs/kotaemon/kotaemon/llms/__init__.py +++ b/libs/kotaemon/kotaemon/llms/__init__.py @@ -14,6 +14,7 @@ from .chats import ( LCGeminiChat, LCOllamaChat, LlamaCppChat, + StructuredOutputChatOpenAI, ) from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI from .cot import ManualSequentialChainOfThought, Thought @@ -31,6 +32,7 @@ __all__ = [ "SystemMessage", "AzureChatOpenAI", "ChatOpenAI", + "StructuredOutputChatOpenAI", "LCAnthropicChat", "LCGeminiChat", "LCCohereChat", diff --git a/libs/kotaemon/kotaemon/llms/chats/__init__.py b/libs/kotaemon/kotaemon/llms/chats/__init__.py index 2581356..f3f8277 100644 --- a/libs/kotaemon/kotaemon/llms/chats/__init__.py +++ b/libs/kotaemon/kotaemon/llms/chats/__init__.py @@ -10,7 +10,7 @@ from .langchain_based import ( LCOllamaChat, ) from .llamacpp import LlamaCppChat -from .openai import AzureChatOpenAI, ChatOpenAI +from .openai import AzureChatOpenAI, ChatOpenAI, StructuredOutputChatOpenAI __all__ = [ "ChatOpenAI", @@ -18,6 +18,7 @@ __all__ = [ "ChatLLM", "EndpointChatLLM", "ChatOpenAI", + "StructuredOutputChatOpenAI", "LCAnthropicChat", "LCGeminiChat", "LCCohereChat", diff --git a/libs/kotaemon/kotaemon/llms/chats/openai.py b/libs/kotaemon/kotaemon/llms/chats/openai.py index df46da2..a1f8854 100644 --- a/libs/kotaemon/kotaemon/llms/chats/openai.py +++ b/libs/kotaemon/kotaemon/llms/chats/openai.py @@ -1,8 +1,16 @@ -from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional +from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional, Type +from pydantic import BaseModel from theflow.utils.modules import import_dotted_string -from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param +from kotaemon.base import ( + AIMessage, + BaseMessage, + HumanMessage, + LLMInterface, + Param, + StructuredOutputLLMInterface, +) from .base import ChatLLM @@ -330,6 
+338,88 @@ class ChatOpenAI(BaseChatOpenAI): return await client.chat.completions.create(**params) +class StructuredOutputChatOpenAI(ChatOpenAI): + """OpenAI chat model that returns structured output""" + + response_schema: Type[BaseModel] = Param( + help="class that subclasses pydantic's BaseModel", required=True + ) + + def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface: + """Convert the OpenAI response into StructuredOutputLLMInterface""" + additional_kwargs = {} + + if "tool_calls" in resp["choices"][0]["message"]: + additional_kwargs["tool_calls"] = resp["choices"][0]["message"][ + "tool_calls" + ] + + if resp["choices"][0].get("logprobs") is None: + logprobs = [] + else: + all_logprobs = resp["choices"][0]["logprobs"].get("content") + logprobs = ( + [logprob["logprob"] for logprob in all_logprobs] if all_logprobs else [] + ) + + output = StructuredOutputLLMInterface( + parsed=resp["choices"][0]["message"]["parsed"], + candidates=[(_["message"]["content"] or "") for _ in resp["choices"]], + content=resp["choices"][0]["message"]["content"] or "", + total_tokens=resp["usage"]["total_tokens"], + prompt_tokens=resp["usage"]["prompt_tokens"], + completion_tokens=resp["usage"]["completion_tokens"], + messages=[ + AIMessage(content=(_["message"]["content"]) or "") + for _ in resp["choices"] + ], + additional_kwargs=additional_kwargs, + logprobs=logprobs, + ) + + return output + + def prepare_params(self, **kwargs): + if "tools_pydantic" in kwargs: + kwargs.pop("tools_pydantic") + + params_ = { + "model": self.model, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "n": self.n, + "stop": self.stop, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "tool_choice": self.tool_choice, + "tools": self.tools, + "logprobs": self.logprobs, + "logit_bias": self.logit_bias, + "top_logprobs": self.top_logprobs, + "top_p": self.top_p, + "response_format": self.response_schema, + } + params = {k: v 
for k, v in params_.items() if v is not None} + params.update(kwargs) + + # doesn't do streaming + params.pop("stream") + + return params + + def openai_response(self, client, **kwargs): + """Get the openai response""" + params = self.prepare_params(**kwargs) + + return client.beta.chat.completions.parse(**params) + + async def aopenai_response(self, client, **kwargs): + """Get the openai response""" + params = self.prepare_params(**kwargs) + + return await client.beta.chat.completions.parse(**params) + + class AzureChatOpenAI(BaseChatOpenAI): """OpenAI chat model provided by Microsoft Azure"""