Use Langchain's new dedicated Azure OpenAI embeddings class (#76)

* Use Langchain's new dedicated Azure OpenAI embeddings class

* Update test
This commit is contained in:
Nguyen Trung Duc (john) 2023-11-15 14:46:32 +07:00 committed by GitHub
parent b159897ac6
commit b52f312d8e
2 changed files with 6 additions and 16 deletions

View File

@@ -1,4 +1,4 @@
from langchain.embeddings import OpenAIEmbeddings as LCOpenAIEmbeddings from langchain import embeddings as lcembeddings
from .base import LangchainEmbeddings from .base import LangchainEmbeddings
@@ -9,23 +9,13 @@ class OpenAIEmbeddings(LangchainEmbeddings):
This method is wrapped around the Langchain OpenAIEmbeddings class. This method is wrapped around the Langchain OpenAIEmbeddings class.
""" """
_lc_class = LCOpenAIEmbeddings _lc_class = lcembeddings.OpenAIEmbeddings
class AzureOpenAIEmbeddings(LangchainEmbeddings): class AzureOpenAIEmbeddings(LangchainEmbeddings):
"""Azure OpenAI embeddings. """Azure OpenAI embeddings.
This method is wrapped around the Langchain OpenAIEmbeddings class. This method is wrapped around the Langchain AzureOpenAIEmbeddings class.
""" """
_lc_class = LCOpenAIEmbeddings _lc_class = lcembeddings.AzureOpenAIEmbeddings
def __init__(self, **params):
params["openai_api_type"] = "azure"
# openai.error.InvalidRequestError: Too many inputs. The max number of
# inputs is 16. We hope to increase the number of inputs per request
# soon. Please contact us through an Azure support request at:
# https://go.microsoft.com/fwlink/?linkid=2213926 for further questions.
params["chunk_size"] = 16
super().__init__(**params)

View File

@@ -21,7 +21,7 @@ def test_azureopenai_embeddings_raw(openai_embedding_call):
model = AzureOpenAIEmbeddings( model = AzureOpenAIEmbeddings(
model="text-embedding-ada-002", model="text-embedding-ada-002",
deployment="embedding-deployment", deployment="embedding-deployment",
openai_api_base="https://test.openai.azure.com/", azure_endpoint="https://test.openai.azure.com/",
openai_api_key="some-key", openai_api_key="some-key",
) )
output = model("Hello world") output = model("Hello world")
@@ -39,7 +39,7 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
model = AzureOpenAIEmbeddings( model = AzureOpenAIEmbeddings(
model="text-embedding-ada-002", model="text-embedding-ada-002",
deployment="embedding-deployment", deployment="embedding-deployment",
openai_api_base="https://test.openai.azure.com/", azure_endpoint="https://test.openai.azure.com/",
openai_api_key="some-key", openai_api_key="some-key",
) )
output = model(["Hello world", "Goodbye world"]) output = model(["Hello world", "Goodbye world"])