Update the Citation pipeline according to the new OpenAI function call interface (#40)
parent 1b2082a140
commit c6045bcb9f
@@ -75,8 +75,8 @@ class CitationPipeline(BaseComponent):
             "parameters": schema,
         }
         llm_kwargs = {
-            "functions": [function],
-            "function_call": {"name": function["name"]},
+            "tools": [{"type": "function", "function": function}],
+            "tool_choice": "auto",
         }
         messages = [
             SystemMessage(
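Note on the hunk above: the legacy "functions" / "function_call" request parameters are replaced by the "tools" / "tool_choice" interface of the openai-python v1 client. A minimal sketch of how these kwargs reach the API, outside the pipeline (model name and schema are placeholders, not values from this repo):

    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

    # Placeholder function spec; the pipeline builds this from QuestionAnswer.schema()
    function = {
        "name": "QuestionAnswer",
        "description": "Answer a question with cited evidence",
        "parameters": {"type": "object", "properties": {}},
    }

    resp = client.chat.completions.create(
        model="gpt-3.5-turbo",  # placeholder model name
        messages=[{"role": "user", "content": "..."}],
        # old style (deprecated): functions=[function], function_call={"name": function["name"]}
        tools=[{"type": "function", "function": function}],
        tool_choice="auto",
    )

One behavioural difference worth noting: the old function_call={"name": ...} forced the named function, while tool_choice="auto" lets the model decide; forcing it under the new interface would take the form tool_choice={"type": "function", "function": {"name": function["name"]}}.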
@@ -99,14 +99,13 @@ class CitationPipeline(BaseComponent):
     def invoke(self, context: str, question: str):
         messages, llm_kwargs = self.prepare_llm(context, question)

         try:
             print("CitationPipeline: invoking LLM")
             llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
             print("CitationPipeline: finish invoking LLM")
             if not llm_output.messages:
                 return None
-            function_output = llm_output.messages[0].additional_kwargs["function_call"][
+            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
                 "arguments"
             ]
             output = QuestionAnswer.parse_raw(function_output)
@@ -123,16 +122,12 @@ class CitationPipeline(BaseComponent):
             print("CitationPipeline: async invoking LLM")
             llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
             print("CitationPipeline: finish async invoking LLM")
+            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
+                "arguments"
+            ]
+            output = QuestionAnswer.parse_raw(function_output)
         except Exception as e:
             print(e)
             return None

-        if not llm_output.messages:
-            return None
-
-        function_output = llm_output.messages[0].additional_kwargs["function_call"][
-            "arguments"
-        ]
-        output = QuestionAnswer.parse_raw(function_output)
-
         return output
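For context on the two hunks above: with the tools interface, the function arguments come back as a JSON string under additional_kwargs["tool_calls"], which QuestionAnswer.parse_raw then validates. A minimal sketch of that parsing step, using a simplified stand-in model (the real QuestionAnswer schema in this repo has more fields):

    import json
    from pydantic import BaseModel


    class QuestionAnswer(BaseModel):
        # simplified stand-in for the pipeline's schema
        question: str
        answer: str


    # Shape mirroring what ends up under llm_output.additional_kwargs["tool_calls"]
    additional_kwargs = {
        "tool_calls": [
            {
                "id": "call_0",
                "type": "function",
                "function": {
                    "name": "QuestionAnswer",
                    "arguments": json.dumps({"question": "What is X?", "answer": "X is ..."}),
                },
            }
        ]
    }

    raw_args = additional_kwargs["tool_calls"][0]["function"]["arguments"]
    output = QuestionAnswer.parse_raw(raw_args)  # pydantic v1-style JSON parsing
    print(output.answer)

Since tool_choice is now "auto", the model is not guaranteed to emit a tool call; if "tool_calls" is absent the indexing raises a KeyError, which the surrounding try/except turns into a None return.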
@@ -152,6 +152,28 @@ class BaseChatOpenAI(ChatLLM):

         return output_

+    def prepare_output(self, resp: dict) -> LLMInterface:
+        """Convert the OpenAI response into LLMInterface"""
+        additional_kwargs = {}
+        if "tool_calls" in resp["choices"][0]["message"]:
+            additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
+                "tool_calls"
+            ]
+        output = LLMInterface(
+            candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
+            content=resp["choices"][0]["message"]["content"] or "",
+            total_tokens=resp["usage"]["total_tokens"],
+            prompt_tokens=resp["usage"]["prompt_tokens"],
+            completion_tokens=resp["usage"]["completion_tokens"],
+            additional_kwargs=additional_kwargs,
+            messages=[
+                AIMessage(content=(_["message"]["content"]) or "")
+                for _ in resp["choices"]
+            ],
+        )
+
+        return output
+
     def prepare_client(self, async_version: bool = False):
         """Get the OpenAI client
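The new prepare_output helper above centralises response handling for the invoke and ainvoke methods further down. A rough sketch of the resp dict it consumes, i.e. the .dict() form of an openai-python v1 chat completion; the field values here are illustrative only:

    resp = {
        "choices": [
            {
                "message": {
                    "content": None,  # None when the model answers via a tool call
                    "tool_calls": [
                        {
                            "id": "call_0",
                            "type": "function",
                            "function": {"name": "QuestionAnswer", "arguments": "{}"},
                        }
                    ],
                }
            }
        ],
        "usage": {"total_tokens": 42, "prompt_tokens": 30, "completion_tokens": 12},
    }

With an input like this, prepare_output keeps the raw tool_calls under additional_kwargs and falls back to an empty string for content and candidates (the 'or ""' guards), since message.content is None for pure tool-call responses, a case the inline construction it replaces below presumably did not handle.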
@@ -172,19 +194,7 @@ class BaseChatOpenAI(ChatLLM):
         resp = self.openai_response(
             client, messages=input_messages, stream=False, **kwargs
         ).dict()
-
-        output = LLMInterface(
-            candidates=[_["message"]["content"] for _ in resp["choices"]],
-            content=resp["choices"][0]["message"]["content"],
-            total_tokens=resp["usage"]["total_tokens"],
-            prompt_tokens=resp["usage"]["prompt_tokens"],
-            completion_tokens=resp["usage"]["completion_tokens"],
-            messages=[
-                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
-            ],
-        )
-
-        return output
+        return self.prepare_output(resp)

     async def ainvoke(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
@@ -195,18 +205,7 @@ class BaseChatOpenAI(ChatLLM):
             client, messages=input_messages, stream=False, **kwargs
         ).dict()

-        output = LLMInterface(
-            candidates=[_["message"]["content"] for _ in resp["choices"]],
-            content=resp["choices"][0]["message"]["content"],
-            total_tokens=resp["usage"]["total_tokens"],
-            prompt_tokens=resp["usage"]["prompt_tokens"],
-            completion_tokens=resp["usage"]["completion_tokens"],
-            messages=[
-                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
-            ],
-        )
-
-        return output
+        return self.prepare_output(resp)

     def stream(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
@@ -338,7 +337,7 @@ class AzureChatOpenAI(BaseChatOpenAI):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
-        params = {
+        params_ = {
             "model": self.azure_deployment,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
@@ -353,6 +352,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "top_logprobs": self.top_logprobs,
             "top_p": self.top_p,
         }
+        params = {k: v for k, v in params_.items() if v is not None}
         params.update(kwargs)

         return client.chat.completions.create(**params)
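The last hunk drops optional parameters that were left as None before calling the API, presumably to avoid sending explicit nulls for unset options such as logprobs or max_tokens, while call-time kwargs still override the instance defaults. A small sketch of the pattern in isolation (field values are placeholders):

    def build_params(**kwargs):
        params_ = {
            "model": "my-azure-deployment",  # placeholder deployment name
            "temperature": 0.0,
            "max_tokens": None,  # unset -> dropped below
            "top_p": None,       # unset -> dropped below
        }
        params = {k: v for k, v in params_.items() if v is not None}
        params.update(kwargs)  # call-time kwargs (e.g. tools, tool_choice) still win
        return params


    print(build_params(tool_choice="auto"))
    # {'model': 'my-azure-deployment', 'temperature': 0.0, 'tool_choice': 'auto'}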