Use new Langchain's dedicated Azure OpenAI embedding class (#76)
* Use new Langchain's dedicated Azure OpenAI embedding class * Update test
This commit is contained in:
parent
b159897ac6
commit
b52f312d8e
|
@ -1,4 +1,4 @@
|
||||||
from langchain.embeddings import OpenAIEmbeddings as LCOpenAIEmbeddings
|
from langchain import embeddings as lcembeddings
|
||||||
|
|
||||||
from .base import LangchainEmbeddings
|
from .base import LangchainEmbeddings
|
||||||
|
|
||||||
|
@ -9,23 +9,13 @@ class OpenAIEmbeddings(LangchainEmbeddings):
|
||||||
This method is wrapped around the Langchain OpenAIEmbeddings class.
|
This method is wrapped around the Langchain OpenAIEmbeddings class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_lc_class = LCOpenAIEmbeddings
|
_lc_class = lcembeddings.OpenAIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIEmbeddings(LangchainEmbeddings):
|
class AzureOpenAIEmbeddings(LangchainEmbeddings):
|
||||||
"""Azure OpenAI embeddings.
|
"""Azure OpenAI embeddings.
|
||||||
|
|
||||||
This method is wrapped around the Langchain OpenAIEmbeddings class.
|
This method is wrapped around the Langchain AzureOpenAIEmbeddings class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_lc_class = LCOpenAIEmbeddings
|
_lc_class = lcembeddings.AzureOpenAIEmbeddings
|
||||||
|
|
||||||
def __init__(self, **params):
|
|
||||||
params["openai_api_type"] = "azure"
|
|
||||||
|
|
||||||
# openai.error.InvalidRequestError: Too many inputs. The max number of
|
|
||||||
# inputs is 16. We hope to increase the number of inputs per request
|
|
||||||
# soon. Please contact us through an Azure support request at:
|
|
||||||
# https://go.microsoft.com/fwlink/?linkid=2213926 for further questions.
|
|
||||||
params["chunk_size"] = 16
|
|
||||||
super().__init__(**params)
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ def test_azureopenai_embeddings_raw(openai_embedding_call):
|
||||||
model = AzureOpenAIEmbeddings(
|
model = AzureOpenAIEmbeddings(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
deployment="embedding-deployment",
|
deployment="embedding-deployment",
|
||||||
openai_api_base="https://test.openai.azure.com/",
|
azure_endpoint="https://test.openai.azure.com/",
|
||||||
openai_api_key="some-key",
|
openai_api_key="some-key",
|
||||||
)
|
)
|
||||||
output = model("Hello world")
|
output = model("Hello world")
|
||||||
|
@ -39,7 +39,7 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
|
||||||
model = AzureOpenAIEmbeddings(
|
model = AzureOpenAIEmbeddings(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
deployment="embedding-deployment",
|
deployment="embedding-deployment",
|
||||||
openai_api_base="https://test.openai.azure.com/",
|
azure_endpoint="https://test.openai.azure.com/",
|
||||||
openai_api_key="some-key",
|
openai_api_key="some-key",
|
||||||
)
|
)
|
||||||
output = model(["Hello world", "Goodbye world"])
|
output = model(["Hello world", "Goodbye world"])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user