feat: add structured output to openai (#603)

* add structured output to openai
* remove notebook, modify prepare output method
* fix: satisfy pre-commit hooks

Co-authored-by: Tadashi <tadashi@cinnamon.is>
This commit is contained in:
parent 6f4acc979c
commit 9b05693e4f
@@ -8,6 +8,7 @@ from .schema import (
     HumanMessage,
     LLMInterface,
     RetrievedDocument,
+    StructuredOutputLLMInterface,
     SystemMessage,
 )
@@ -21,6 +22,7 @@ __all__ = [
     "HumanMessage",
     "RetrievedDocument",
     "LLMInterface",
+    "StructuredOutputLLMInterface",
     "ExtractorOutput",
     "Param",
     "Node",
@@ -143,6 +143,11 @@ class LLMInterface(AIMessage):
     logprobs: list[float] = []
 
 
+class StructuredOutputLLMInterface(LLMInterface):
+    parsed: Any
+    refusal: str = ""
+
+
 class ExtractorOutput(Document):
     """
     Represents the output of an extractor.
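Note: `StructuredOutputLLMInterface` extends `LLMInterface` with the parsed Pydantic object and a refusal string. A minimal sketch of a populated instance, assuming the remaining `LLMInterface` fields keep their defaults (the `Profile` schema is hypothetical, for illustration only):

    from pydantic import BaseModel

    from kotaemon.base import StructuredOutputLLMInterface


    class Profile(BaseModel):
        # hypothetical schema, used only for illustration
        name: str
        age: int


    out = StructuredOutputLLMInterface(
        content='{"name": "Ada", "age": 36}',    # raw text is still available
        parsed=Profile(name="Ada", age=36),      # typed object parsed from the response
    )
    print(out.parsed.name)  # -> "Ada"
    print(out.refusal)      # -> "" unless the model refused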
@@ -14,6 +14,7 @@ from .chats import (
     LCGeminiChat,
     LCOllamaChat,
     LlamaCppChat,
+    StructuredOutputChatOpenAI,
 )
 from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
 from .cot import ManualSequentialChainOfThought, Thought
@@ -31,6 +32,7 @@ __all__ = [
     "SystemMessage",
     "AzureChatOpenAI",
     "ChatOpenAI",
+    "StructuredOutputChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
     "LCCohereChat",
@@ -10,7 +10,7 @@ from .langchain_based import (
     LCOllamaChat,
 )
 from .llamacpp import LlamaCppChat
-from .openai import AzureChatOpenAI, ChatOpenAI
+from .openai import AzureChatOpenAI, ChatOpenAI, StructuredOutputChatOpenAI
 
 __all__ = [
     "ChatOpenAI",
@@ -18,6 +18,7 @@ __all__ = [
     "ChatLLM",
     "EndpointChatLLM",
     "ChatOpenAI",
+    "StructuredOutputChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
     "LCCohereChat",
@@ -1,8 +1,16 @@
-from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional
+from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional, Type
 
+from pydantic import BaseModel
 from theflow.utils.modules import import_dotted_string
 
-from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param
+from kotaemon.base import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    LLMInterface,
+    Param,
+    StructuredOutputLLMInterface,
+)
 
 from .base import ChatLLM
@@ -330,6 +338,88 @@ class ChatOpenAI(BaseChatOpenAI):
         return await client.chat.completions.create(**params)
 
 
+class StructuredOutputChatOpenAI(ChatOpenAI):
+    """OpenAI chat model that returns structured output"""
+
+    response_schema: Type[BaseModel] = Param(
+        help="class that subclasses pydantic's BaseModel", required=True
+    )
+
+    def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface:
+        """Convert the OpenAI response into StructuredOutputLLMInterface"""
+        additional_kwargs = {}
+
+        # forward tool calls when the model produced any
+        if "tool_calls" in resp["choices"][0]["message"]:
+            additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
+                "tool_calls"
+            ]
+
+        # flatten per-token logprobs when the API returned them
+        if resp["choices"][0].get("logprobs") is None:
+            logprobs = []
+        else:
+            all_logprobs = resp["choices"][0]["logprobs"].get("content")
+            logprobs = (
+                [logprob["logprob"] for logprob in all_logprobs] if all_logprobs else []
+            )
+
+        output = StructuredOutputLLMInterface(
+            parsed=resp["choices"][0]["message"]["parsed"],
+            candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
+            content=resp["choices"][0]["message"]["content"] or "",
+            total_tokens=resp["usage"]["total_tokens"],
+            prompt_tokens=resp["usage"]["prompt_tokens"],
+            completion_tokens=resp["usage"]["completion_tokens"],
+            messages=[
+                AIMessage(content=(_["message"]["content"]) or "")
+                for _ in resp["choices"]
+            ],
+            additional_kwargs=additional_kwargs,
+            logprobs=logprobs,
+        )
+
+        return output
+
+    def prepare_params(self, **kwargs):
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
+        params_ = {
+            "model": self.model,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "n": self.n,
+            "stop": self.stop,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "tool_choice": self.tool_choice,
+            "tools": self.tools,
+            "logprobs": self.logprobs,
+            "logit_bias": self.logit_bias,
+            "top_logprobs": self.top_logprobs,
+            "top_p": self.top_p,
+            "response_format": self.response_schema,
+        }
+        params = {k: v for k, v in params_.items() if v is not None}
+        params.update(kwargs)
+
+        # the structured-output endpoint does not stream
+        params.pop("stream", None)
+
+        return params
+
+    def openai_response(self, client, **kwargs):
+        """Get the openai response"""
+        params = self.prepare_params(**kwargs)
+
+        return client.beta.chat.completions.parse(**params)
+
+    async def aopenai_response(self, client, **kwargs):
+        """Get the openai response"""
+        params = self.prepare_params(**kwargs)
+
+        return await client.beta.chat.completions.parse(**params)
+
+
 class AzureChatOpenAI(BaseChatOpenAI):
     """OpenAI chat model provided by Microsoft Azure"""
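For context: the class routes requests through `client.beta.chat.completions.parse(...)`, the OpenAI SDK's structured-output endpoint, passing the Pydantic class as `response_format`. A minimal usage sketch, assuming a configured OpenAI key; the `Movie` schema, key, and model name below are placeholders, not part of this commit:

    from pydantic import BaseModel

    from kotaemon.llms import StructuredOutputChatOpenAI


    class Movie(BaseModel):
        # hypothetical schema, used only for illustration
        title: str
        year: int


    llm = StructuredOutputChatOpenAI(
        api_key="<OPENAI_API_KEY>",  # placeholder
        model="gpt-4o-mini",         # any model that supports structured outputs
        response_schema=Movie,
    )

    resp = llm("Name one science-fiction film from 1982.")
    print(resp.parsed)  # Movie(title=..., year=...), validated against the schema

Since `prepare_params` drops the `stream` flag, this class always returns a single parsed response rather than a token stream.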