feat: add structured output to openai (#603)

* add structured output to openai

* remove notebook, modify prepare_output method

* fix: appease pre-commit

---------

Co-authored-by: Tadashi <tadashi@cinnamon.is>
Ben Dykstra 2025-04-15 01:54:23 -06:00 committed by GitHub
parent 6f4acc979c
commit 9b05693e4f
5 changed files with 103 additions and 3 deletions

kotaemon/base/__init__.py

@@ -8,6 +8,7 @@ from .schema import (
    HumanMessage,
    LLMInterface,
    RetrievedDocument,
    StructuredOutputLLMInterface,
    SystemMessage,
)
@@ -21,6 +22,7 @@ __all__ = [
    "HumanMessage",
    "RetrievedDocument",
    "LLMInterface",
    "StructuredOutputLLMInterface",
    "ExtractorOutput",
    "Param",
    "Node",

kotaemon/base/schema.py

@@ -143,6 +143,11 @@ class LLMInterface(AIMessage):
    logprobs: list[float] = []


class StructuredOutputLLMInterface(LLMInterface):
    parsed: Any
    refusal: str = ""


class ExtractorOutput(Document):
    """
    Represents the output of an extractor.

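For reference, a short sketch of how the new interface is meant to be read. The `Recipe` schema and all field values below are invented for illustration; keyword construction follows the pydantic-style models already used in kotaemon.base:

```python
from pydantic import BaseModel

from kotaemon.base import StructuredOutputLLMInterface


class Recipe(BaseModel):
    name: str
    steps: list[str]


# StructuredOutputLLMInterface extends LLMInterface with two fields:
# `parsed` (the structured payload) and `refusal` (non-empty when the
# model declines to produce the requested schema).
out = StructuredOutputLLMInterface(
    content='{"name": "pho", "steps": ["simmer the broth"]}',
    parsed=Recipe(name="pho", steps=["simmer the broth"]),
)
assert out.parsed.name == "pho"
assert out.refusal == ""
```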
kotaemon/llms/__init__.py

@@ -14,6 +14,7 @@ from .chats import (
    LCGeminiChat,
    LCOllamaChat,
    LlamaCppChat,
    StructuredOutputChatOpenAI,
)
from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
from .cot import ManualSequentialChainOfThought, Thought
@@ -31,6 +32,7 @@ __all__ = [
    "SystemMessage",
    "AzureChatOpenAI",
    "ChatOpenAI",
    "StructuredOutputChatOpenAI",
    "LCAnthropicChat",
    "LCGeminiChat",
    "LCCohereChat",

kotaemon/llms/chats/__init__.py

@@ -10,7 +10,7 @@ from .langchain_based import (
    LCOllamaChat,
)
from .llamacpp import LlamaCppChat
-from .openai import AzureChatOpenAI, ChatOpenAI
+from .openai import AzureChatOpenAI, ChatOpenAI, StructuredOutputChatOpenAI

__all__ = [
    "ChatOpenAI",
@@ -18,6 +18,7 @@ __all__ = [
    "ChatLLM",
    "EndpointChatLLM",
    "ChatOpenAI",
    "StructuredOutputChatOpenAI",
    "LCAnthropicChat",
    "LCGeminiChat",
    "LCCohereChat",

kotaemon/llms/chats/openai.py

@@ -1,8 +1,16 @@
-from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional
+from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional, Type

+from pydantic import BaseModel
from theflow.utils.modules import import_dotted_string

-from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param
+from kotaemon.base import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    LLMInterface,
+    Param,
+    StructuredOutputLLMInterface,
+)

from .base import ChatLLM
@@ -330,6 +338,88 @@ class ChatOpenAI(BaseChatOpenAI):
        return await client.chat.completions.create(**params)


class StructuredOutputChatOpenAI(ChatOpenAI):
    """OpenAI chat model that returns structured output"""

    response_schema: Type[BaseModel] = Param(
        help="class that subclasses pydantic's BaseModel", required=True
    )
    def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface:
        """Convert the OpenAI response into StructuredOutputLLMInterface"""
        additional_kwargs = {}
        if "tool_calls" in resp["choices"][0]["message"]:
            additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
                "tool_calls"
            ]

        if resp["choices"][0].get("logprobs") is None:
            logprobs = []
        else:
            all_logprobs = resp["choices"][0]["logprobs"].get("content")
            logprobs = (
                [logprob["logprob"] for logprob in all_logprobs] if all_logprobs else []
            )

        output = StructuredOutputLLMInterface(
            parsed=resp["choices"][0]["message"]["parsed"],
            candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
            content=resp["choices"][0]["message"]["content"] or "",
            total_tokens=resp["usage"]["total_tokens"],
            prompt_tokens=resp["usage"]["prompt_tokens"],
            completion_tokens=resp["usage"]["completion_tokens"],
            messages=[
                AIMessage(content=(_["message"]["content"]) or "")
                for _ in resp["choices"]
            ],
            additional_kwargs=additional_kwargs,
            logprobs=logprobs,
        )

        return output
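
    # For reference, a hypothetical sketch of the `resp` dict that
    # prepare_output() expects; the key names mirror the accesses above,
    # the values are invented for illustration:
    #
    #   {
    #       "choices": [
    #           {
    #               "message": {
    #                   "content": '{"city": "Hanoi"}',
    #                   "parsed": {"city": "Hanoi"},
    #               },
    #               "logprobs": None,
    #           }
    #       ],
    #       "usage": {
    #           "total_tokens": 42,
    #           "prompt_tokens": 30,
    #           "completion_tokens": 12,
    #       },
    #   }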
    def prepare_params(self, **kwargs):
        if "tools_pydantic" in kwargs:
            kwargs.pop("tools_pydantic")

        params_ = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "n": self.n,
            "stop": self.stop,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "tool_choice": self.tool_choice,
            "tools": self.tools,
            "logprobs": self.logprobs,
            "logit_bias": self.logit_bias,
            "top_logprobs": self.top_logprobs,
            "top_p": self.top_p,
            "response_format": self.response_schema,
        }
        params = {k: v for k, v in params_.items() if v is not None}
        params.update(kwargs)

        # structured output does not support streaming; drop the flag if present
        params.pop("stream", None)

        return params
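
    # Note: client.beta.chat.completions.parse() below is the OpenAI SDK's
    # structured-output helper. It serializes `response_format` (the pydantic
    # class) into a JSON schema and returns a completion whose message carries
    # a `parsed` attribute, which prepare_output() reads once the caller has
    # converted the response to a dict.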
    def openai_response(self, client, **kwargs):
        """Get the OpenAI response"""
        params = self.prepare_params(**kwargs)
        return client.beta.chat.completions.parse(**params)

    async def aopenai_response(self, client, **kwargs):
        """Get the OpenAI response asynchronously"""
        params = self.prepare_params(**kwargs)
        return await client.beta.chat.completions.parse(**params)


class AzureChatOpenAI(BaseChatOpenAI):
    """OpenAI chat model provided by Microsoft Azure"""