Enable fastembed as a local embedding vendor (#12)
* Prepend all Langchain-based embeddings with LC * Provide vanilla OpenAI embeddings * Add test for AzureOpenAIEmbeddings and OpenAIEmbeddings * Incorporate fastembed --------- Co-authored-by: ian_Cin <ian@cinnamon.is>
This commit is contained in:
parent
8001c86b16
commit
e75354d410
5
.github/workflows/unit-test.yaml
vendored
5
.github/workflows/unit-test.yaml
vendored
|
@ -100,6 +100,11 @@ jobs:
|
|||
path: ${{ env.pythonLocation }}
|
||||
key: ${{ steps.restore-dependencies.outputs.cache-primary-key }}
|
||||
|
||||
- name: Install OS-based packages
|
||||
run: |
|
||||
sudo apt update -qqy
|
||||
sudo apt install -y poppler-utils libpoppler-dev tesseract-ocr
|
||||
|
||||
- name: Test kotaemon with pytest
|
||||
run: |
|
||||
pip show pytest
|
||||
|
|
|
@ -48,7 +48,10 @@ class Document(BaseDocument):
|
|||
# default text indicating this document only contains embedding
|
||||
kwargs["text"] = "<EMBEDDING>"
|
||||
elif isinstance(content, Document):
|
||||
kwargs = content.dict()
|
||||
# TODO: simplify the Document class
|
||||
temp_ = content.dict()
|
||||
temp_.update(kwargs)
|
||||
kwargs = temp_
|
||||
else:
|
||||
kwargs["content"] = content
|
||||
if content:
|
||||
|
|
|
@ -1,17 +1,22 @@
|
|||
from .base import BaseEmbeddings
|
||||
from .endpoint_based import EndpointEmbeddings
|
||||
from .fastembed import FastEmbedEmbeddings
|
||||
from .langchain_based import (
|
||||
AzureOpenAIEmbeddings,
|
||||
CohereEmbdeddings,
|
||||
HuggingFaceEmbeddings,
|
||||
OpenAIEmbeddings,
|
||||
LCAzureOpenAIEmbeddings,
|
||||
LCCohereEmbdeddings,
|
||||
LCHuggingFaceEmbeddings,
|
||||
LCOpenAIEmbeddings,
|
||||
)
|
||||
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbeddings",
|
||||
"EndpointEmbeddings",
|
||||
"LCOpenAIEmbeddings",
|
||||
"LCAzureOpenAIEmbeddings",
|
||||
"LCCohereEmbdeddings",
|
||||
"LCHuggingFaceEmbeddings",
|
||||
"OpenAIEmbeddings",
|
||||
"AzureOpenAIEmbeddings",
|
||||
"CohereEmbdeddings",
|
||||
"HuggingFaceEmbeddings",
|
||||
"FastEmbedEmbeddings",
|
||||
]
|
||||
|
|
|
@ -1,13 +1,29 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding
|
||||
|
||||
|
||||
class BaseEmbeddings(BaseComponent):
|
||||
@abstractmethod
|
||||
def run(
|
||||
self, text: str | list[str] | Document | list[Document]
|
||||
self, text: str | list[str] | Document | list[Document], *args, **kwargs
|
||||
) -> list[DocumentWithEmbedding]:
|
||||
...
|
||||
return self.invoke(text, *args, **kwargs)
|
||||
|
||||
def invoke(
|
||||
self, text: str | list[str] | Document | list[Document], *args, **kwargs
|
||||
) -> list[DocumentWithEmbedding]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def ainvoke(
|
||||
self, text: str | list[str] | Document | list[Document], *args, **kwargs
|
||||
) -> list[DocumentWithEmbedding]:
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare_input(
|
||||
self, text: str | list[str] | Document | list[Document]
|
||||
) -> list[Document]:
|
||||
if isinstance(text, (str, Document)):
|
||||
return [Document(content=text)]
|
||||
elif isinstance(text, list):
|
||||
return [Document(content=_) for _ in text]
|
||||
return text
|
||||
|
|
68
libs/kotaemon/kotaemon/embeddings/fastembed.py
Normal file
68
libs/kotaemon/kotaemon/embeddings/fastembed.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from kotaemon.base import Document, DocumentWithEmbedding, Param
|
||||
|
||||
from .base import BaseEmbeddings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
|
||||
class FastEmbedEmbeddings(BaseEmbeddings):
|
||||
"""Utilize fastembed library for embeddings locally without GPU.
|
||||
|
||||
Supported model: https://qdrant.github.io/fastembed/examples/Supported_Models/
|
||||
Code: https://github.com/qdrant/fastembed
|
||||
"""
|
||||
|
||||
model_name: str = Param(
|
||||
"BAAI/bge-small-en-v1.5",
|
||||
help=(
|
||||
"Model name for fastembed. "
|
||||
"Supported model: "
|
||||
"https://qdrant.github.io/fastembed/examples/Supported_Models/"
|
||||
),
|
||||
)
|
||||
batch_size: int = Param(
|
||||
256,
|
||||
help="Batch size for embeddings. Higher values use more memory, but are faster",
|
||||
)
|
||||
parallel: Optional[int] = Param(
|
||||
None,
|
||||
help=(
|
||||
"Number of threads to use for embeddings. "
|
||||
"If > 1, data-parallel encoding will be used. "
|
||||
"If 0, use all available CPUs. "
|
||||
"If None, use default onnxruntime threading. "
|
||||
"Defaults to None"
|
||||
),
|
||||
)
|
||||
|
||||
@Param.auto()
|
||||
def client_(self) -> "TextEmbedding":
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
return TextEmbedding(model_name=self.model_name)
|
||||
|
||||
def invoke(
|
||||
self, text: str | list[str] | Document | list[Document], *args, **kwargs
|
||||
) -> list[DocumentWithEmbedding]:
|
||||
input_ = self.prepare_input(text)
|
||||
embeddings = self.client_.embed(
|
||||
[_.content for _ in input_],
|
||||
batch_size=self.batch_size,
|
||||
parallel=self.parallel,
|
||||
)
|
||||
return [
|
||||
DocumentWithEmbedding(
|
||||
content=doc,
|
||||
embedding=list(embedding),
|
||||
)
|
||||
for doc, embedding in zip(input_, embeddings)
|
||||
]
|
||||
|
||||
async def ainvoke(
|
||||
self, text: str | list[str] | Document | list[Document], *args, **kwargs
|
||||
) -> list[DocumentWithEmbedding]:
|
||||
"""Fastembed does not support async API."""
|
||||
return self.invoke(text, *args, **kwargs)
|
|
@ -97,7 +97,7 @@ class LCEmbeddingMixin:
|
|||
raise ValueError(f"Invalid param {path}")
|
||||
|
||||
|
||||
class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
class LCOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
"""Wrapper around Langchain's OpenAI embedding, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
|
@ -129,7 +129,7 @@ class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
|||
return OpenAIEmbeddings
|
||||
|
||||
|
||||
class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
class LCAzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
"""Wrapper around Langchain's AzureOpenAI embedding, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
|
@ -159,7 +159,7 @@ class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
|||
return AzureOpenAIEmbeddings
|
||||
|
||||
|
||||
class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
class LCCohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
|
@ -187,7 +187,7 @@ class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
|
|||
return CohereEmbeddings
|
||||
|
||||
|
||||
class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
|
|
183
libs/kotaemon/kotaemon/embeddings/openai.py
Normal file
183
libs/kotaemon/kotaemon/embeddings/openai.py
Normal file
|
@ -0,0 +1,183 @@
|
|||
from typing import Optional
|
||||
|
||||
from theflow.utils.modules import import_dotted_string
|
||||
|
||||
from kotaemon.base import Param
|
||||
|
||||
from .base import BaseEmbeddings, Document, DocumentWithEmbedding
|
||||
|
||||
|
||||
class BaseOpenAIEmbeddings(BaseEmbeddings):
|
||||
"""Base interface for OpenAI embedding model, using the openai library.
|
||||
|
||||
This class exposes the parameters in resources.Chat. To subclass this class:
|
||||
|
||||
- Implement the `prepare_client` method to return the OpenAI client
|
||||
- Implement the `openai_response` method to return the OpenAI response
|
||||
- Implement the params relate to the OpenAI client
|
||||
"""
|
||||
|
||||
_dependencies = ["openai"]
|
||||
|
||||
api_key: str = Param(help="API key", required=True)
|
||||
timeout: Optional[float] = Param(None, help="Timeout for the API request.")
|
||||
max_retries: Optional[int] = Param(
|
||||
None, help="Maximum number of retries for the API request."
|
||||
)
|
||||
|
||||
dimensions: Optional[int] = Param(
|
||||
None,
|
||||
help=(
|
||||
"The number of dimensions the resulting output embeddings should have. "
|
||||
"Only supported in `text-embedding-3` and later models."
|
||||
),
|
||||
)
|
||||
|
||||
@Param.auto(depends_on=["max_retries"])
|
||||
def max_retries_(self):
|
||||
if self.max_retries is None:
|
||||
from openai._constants import DEFAULT_MAX_RETRIES
|
||||
|
||||
return DEFAULT_MAX_RETRIES
|
||||
return self.max_retries
|
||||
|
||||
def prepare_client(self, async_version: bool = False):
|
||||
"""Get the OpenAI client
|
||||
|
||||
Args:
|
||||
async_version (bool): Whether to get the async version of the client
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def openai_response(self, client, **kwargs):
|
||||
"""Get the openai response"""
|
||||
raise NotImplementedError
|
||||
|
||||
def invoke(
|
||||
self, text: str | list[str] | Document | list[Document], *args, **kwargs
|
||||
) -> list[DocumentWithEmbedding]:
|
||||
input_ = self.prepare_input(text)
|
||||
client = self.prepare_client(async_version=False)
|
||||
resp = self.openai_response(
|
||||
client, input=[_.text for _ in input_], **kwargs
|
||||
).dict()
|
||||
output_ = sorted(resp["data"], key=lambda x: x["index"])
|
||||
return [
|
||||
DocumentWithEmbedding(embedding=o["embedding"], content=i)
|
||||
for i, o in zip(input_, output_)
|
||||
]
|
||||
|
||||
async def ainvoke(
|
||||
self, text: str | list[str] | Document | list[Document], *args, **kwargs
|
||||
) -> list[DocumentWithEmbedding]:
|
||||
input_ = self.prepare_input(text)
|
||||
client = self.prepare_client(async_version=True)
|
||||
resp = await self.openai_response(
|
||||
client, input=[_.text for _ in input_], **kwargs
|
||||
).dict()
|
||||
output_ = sorted(resp["data"], key=lambda x: x["index"])
|
||||
return [
|
||||
DocumentWithEmbedding(embedding=o["embedding"], content=i)
|
||||
for i, o in zip(input_, output_)
|
||||
]
|
||||
|
||||
|
||||
class OpenAIEmbeddings(BaseOpenAIEmbeddings):
|
||||
"""OpenAI chat model"""
|
||||
|
||||
base_url: Optional[str] = Param(None, help="OpenAI base URL")
|
||||
organization: Optional[str] = Param(None, help="OpenAI organization")
|
||||
model: str = Param(
|
||||
help=(
|
||||
"ID of the model to use. You can go to [Model overview](https://platform."
|
||||
"openai.com/docs/models/overview) to see the available models."
|
||||
),
|
||||
required=True,
|
||||
)
|
||||
|
||||
def prepare_client(self, async_version: bool = False):
|
||||
"""Get the OpenAI client
|
||||
|
||||
Args:
|
||||
async_version (bool): Whether to get the async version of the client
|
||||
"""
|
||||
params = {
|
||||
"api_key": self.api_key,
|
||||
"organization": self.organization,
|
||||
"base_url": self.base_url,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries_,
|
||||
}
|
||||
if async_version:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
return AsyncOpenAI(**params)
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
return OpenAI(**params)
|
||||
|
||||
def openai_response(self, client, **kwargs):
|
||||
"""Get the openai response"""
|
||||
params: dict = {
|
||||
"model": self.model,
|
||||
}
|
||||
if self.dimensions:
|
||||
params["dimensions"] = self.dimensions
|
||||
params.update(kwargs)
|
||||
|
||||
return client.embeddings.create(**params)
|
||||
|
||||
|
||||
class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):
|
||||
azure_endpoint: str = Param(
|
||||
help=(
|
||||
"HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, "
|
||||
"azure_deployment, and api_version parameters are used to construct "
|
||||
"the full URL for the Azure OpenAI model."
|
||||
)
|
||||
)
|
||||
azure_deployment: str = Param(help="Azure deployment name", required=True)
|
||||
api_version: str = Param(help="Azure model version", required=True)
|
||||
azure_ad_token: Optional[str] = Param(None, help="Azure AD token")
|
||||
azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider")
|
||||
|
||||
@Param.auto(depends_on=["azure_ad_token_provider"])
|
||||
def azure_ad_token_provider_(self):
|
||||
if isinstance(self.azure_ad_token_provider, str):
|
||||
return import_dotted_string(self.azure_ad_token_provider, safe=False)
|
||||
|
||||
def prepare_client(self, async_version: bool = False):
|
||||
"""Get the OpenAI client
|
||||
|
||||
Args:
|
||||
async_version (bool): Whether to get the async version of the client
|
||||
"""
|
||||
params = {
|
||||
"azure_endpoint": self.azure_endpoint,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"azure_ad_token": self.azure_ad_token,
|
||||
"azure_ad_token_provider": self.azure_ad_token_provider_,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries_,
|
||||
}
|
||||
if async_version:
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
return AsyncAzureOpenAI(**params)
|
||||
|
||||
from openai import AzureOpenAI
|
||||
|
||||
return AzureOpenAI(**params)
|
||||
|
||||
def openai_response(self, client, **kwargs):
|
||||
"""Get the openai response"""
|
||||
params: dict = {
|
||||
"model": self.azure_deployment,
|
||||
}
|
||||
if self.dimensions:
|
||||
params["dimensions"] = self.dimensions
|
||||
params.update(kwargs)
|
||||
|
||||
return client.embeddings.create(**params)
|
|
@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
|
|||
# metadata and dependencies
|
||||
[project]
|
||||
name = "kotaemon"
|
||||
version = "0.3.9"
|
||||
version = "0.3.10"
|
||||
requires-python = ">= 3.10"
|
||||
description = "Kotaemon core library for AI development."
|
||||
dependencies = [
|
||||
|
@ -61,6 +61,7 @@ adv = [
|
|||
"elasticsearch",
|
||||
"llama-cpp-python",
|
||||
"pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
|
||||
"fastembed",
|
||||
]
|
||||
dev = [
|
||||
"ipython",
|
||||
|
|
|
@ -2,7 +2,7 @@ import tempfile
|
|||
from typing import List
|
||||
|
||||
from kotaemon.base import BaseComponent, LLMInterface, lazy
|
||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||
from kotaemon.embeddings import LCAzureOpenAIEmbeddings
|
||||
from kotaemon.indices import VectorRetrieval
|
||||
from kotaemon.llms import AzureOpenAI
|
||||
from kotaemon.storages import ChromaVectorStore
|
||||
|
@ -20,7 +20,7 @@ class Pipeline(BaseComponent):
|
|||
|
||||
retrieving_pipeline: VectorRetrieval = VectorRetrieval.withx(
|
||||
vector_store=lazy(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())),
|
||||
embedding=AzureOpenAIEmbeddings.withx(
|
||||
embedding=LCAzureOpenAIEmbeddings.withx(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="embedding-deployment",
|
||||
azure_endpoint="https://test.openai.azure.com/",
|
||||
|
|
|
@ -2,18 +2,62 @@ import json
|
|||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from openai.types.create_embedding_response import CreateEmbeddingResponse
|
||||
|
||||
from kotaemon.base import Document
|
||||
from kotaemon.embeddings import (
|
||||
AzureOpenAIEmbeddings,
|
||||
CohereEmbdeddings,
|
||||
HuggingFaceEmbeddings,
|
||||
FastEmbedEmbeddings,
|
||||
LCAzureOpenAIEmbeddings,
|
||||
LCCohereEmbdeddings,
|
||||
LCHuggingFaceEmbeddings,
|
||||
OpenAIEmbeddings,
|
||||
)
|
||||
|
||||
with open(Path(__file__).parent / "resources" / "embedding_openai_batch.json") as f:
|
||||
openai_embedding_batch = json.load(f)
|
||||
openai_embedding_batch = CreateEmbeddingResponse.model_validate(json.load(f))
|
||||
|
||||
with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
|
||||
openai_embedding = json.load(f)
|
||||
openai_embedding = CreateEmbeddingResponse.model_validate(json.load(f))
|
||||
|
||||
|
||||
def assert_embedding_result(output):
|
||||
assert isinstance(output, list)
|
||||
assert isinstance(output[0], Document)
|
||||
assert isinstance(output[0].embedding, list)
|
||||
assert isinstance(output[0].embedding[0], float)
|
||||
|
||||
|
||||
@patch(
|
||||
"openai.resources.embeddings.Embeddings.create",
|
||||
side_effect=lambda *args, **kwargs: openai_embedding,
|
||||
)
|
||||
def test_lcazureopenai_embeddings_raw(openai_embedding_call):
|
||||
model = LCAzureOpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="embedding-deployment",
|
||||
azure_endpoint="https://test.openai.azure.com/",
|
||||
openai_api_key="some-key",
|
||||
)
|
||||
output = model("Hello world")
|
||||
assert_embedding_result(output)
|
||||
openai_embedding_call.assert_called()
|
||||
|
||||
|
||||
@patch(
|
||||
"openai.resources.embeddings.Embeddings.create",
|
||||
side_effect=lambda *args, **kwargs: openai_embedding_batch,
|
||||
)
|
||||
def test_lcazureopenai_embeddings_batch_raw(openai_embedding_call):
|
||||
model = LCAzureOpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="embedding-deployment",
|
||||
azure_endpoint="https://test.openai.azure.com/",
|
||||
openai_api_key="some-key",
|
||||
)
|
||||
output = model(["Hello world", "Goodbye world"])
|
||||
assert_embedding_result(output)
|
||||
openai_embedding_call.assert_called()
|
||||
|
||||
|
||||
@patch(
|
||||
|
@ -22,16 +66,13 @@ with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
|
|||
)
|
||||
def test_azureopenai_embeddings_raw(openai_embedding_call):
|
||||
model = AzureOpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="embedding-deployment",
|
||||
azure_endpoint="https://test.openai.azure.com/",
|
||||
openai_api_key="some-key",
|
||||
api_key="some-key",
|
||||
api_version="version",
|
||||
azure_deployment="text-embedding-ada-002",
|
||||
)
|
||||
output = model("Hello world")
|
||||
assert isinstance(output, list)
|
||||
assert isinstance(output[0], Document)
|
||||
assert isinstance(output[0].embedding, list)
|
||||
assert isinstance(output[0].embedding[0], float)
|
||||
assert_embedding_result(output)
|
||||
openai_embedding_call.assert_called()
|
||||
|
||||
|
||||
|
@ -41,16 +82,41 @@ def test_azureopenai_embeddings_raw(openai_embedding_call):
|
|||
)
|
||||
def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
|
||||
model = AzureOpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="embedding-deployment",
|
||||
azure_endpoint="https://test.openai.azure.com/",
|
||||
openai_api_key="some-key",
|
||||
api_key="some-key",
|
||||
api_version="version",
|
||||
azure_deployment="text-embedding-ada-002",
|
||||
)
|
||||
output = model(["Hello world", "Goodbye world"])
|
||||
assert isinstance(output, list)
|
||||
assert isinstance(output[0], Document)
|
||||
assert isinstance(output[0].embedding, list)
|
||||
assert isinstance(output[0].embedding[0], float)
|
||||
assert_embedding_result(output)
|
||||
openai_embedding_call.assert_called()
|
||||
|
||||
|
||||
@patch(
|
||||
"openai.resources.embeddings.Embeddings.create",
|
||||
side_effect=lambda *args, **kwargs: openai_embedding,
|
||||
)
|
||||
def test_openai_embeddings_raw(openai_embedding_call):
|
||||
model = OpenAIEmbeddings(
|
||||
api_key="some-key",
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
output = model("Hello world")
|
||||
assert_embedding_result(output)
|
||||
openai_embedding_call.assert_called()
|
||||
|
||||
|
||||
@patch(
|
||||
"openai.resources.embeddings.Embeddings.create",
|
||||
side_effect=lambda *args, **kwargs: openai_embedding_batch,
|
||||
)
|
||||
def test_openai_embeddings_batch_raw(openai_embedding_call):
|
||||
model = OpenAIEmbeddings(
|
||||
api_key="some-key",
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
output = model(["Hello world", "Goodbye world"])
|
||||
assert_embedding_result(output)
|
||||
openai_embedding_call.assert_called()
|
||||
|
||||
|
||||
|
@ -62,20 +128,17 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
|
|||
"langchain.embeddings.huggingface.HuggingFaceBgeEmbeddings.embed_documents",
|
||||
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
|
||||
)
|
||||
def test_huggingface_embeddings(
|
||||
def test_lchuggingface_embeddings(
|
||||
langchain_huggingface_embedding_call, sentence_transformers_init
|
||||
):
|
||||
model = HuggingFaceEmbeddings(
|
||||
model = LCHuggingFaceEmbeddings(
|
||||
model_name="intfloat/multilingual-e5-large",
|
||||
model_kwargs={"device": "cpu"},
|
||||
encode_kwargs={"normalize_embeddings": False},
|
||||
)
|
||||
|
||||
output = model("Hello World")
|
||||
assert isinstance(output, list)
|
||||
assert isinstance(output[0], Document)
|
||||
assert isinstance(output[0].embedding, list)
|
||||
assert isinstance(output[0].embedding[0], float)
|
||||
assert_embedding_result(output)
|
||||
sentence_transformers_init.assert_called()
|
||||
langchain_huggingface_embedding_call.assert_called()
|
||||
|
||||
|
@ -84,14 +147,17 @@ def test_huggingface_embeddings(
|
|||
"langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
|
||||
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
|
||||
)
|
||||
def test_cohere_embeddings(langchain_cohere_embedding_call):
|
||||
model = CohereEmbdeddings(
|
||||
def test_lccohere_embeddings(langchain_cohere_embedding_call):
|
||||
model = LCCohereEmbdeddings(
|
||||
model="embed-english-light-v2.0", cohere_api_key="my-api-key"
|
||||
)
|
||||
|
||||
output = model("Hello World")
|
||||
assert isinstance(output, list)
|
||||
assert isinstance(output[0], Document)
|
||||
assert isinstance(output[0].embedding, list)
|
||||
assert isinstance(output[0].embedding[0], float)
|
||||
assert_embedding_result(output)
|
||||
langchain_cohere_embedding_call.assert_called()
|
||||
|
||||
|
||||
def test_fastembed_embeddings():
|
||||
model = FastEmbedEmbeddings()
|
||||
output = model("Hello World")
|
||||
assert_embedding_result(output)
|
||||
|
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from openai.resources.embeddings import Embeddings
|
||||
|
||||
from kotaemon.base import Document
|
||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||
from kotaemon.embeddings import LCAzureOpenAIEmbeddings
|
||||
from kotaemon.indices import VectorIndexing, VectorRetrieval
|
||||
from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
|
||||
|
||||
|
@ -22,7 +22,7 @@ def mock_openai_embedding(monkeypatch):
|
|||
def test_indexing(mock_openai_embedding, tmp_path):
|
||||
db = ChromaVectorStore(path=str(tmp_path))
|
||||
doc_store = InMemoryDocumentStore()
|
||||
embedding = AzureOpenAIEmbeddings(
|
||||
embedding = LCAzureOpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="embedding-deployment",
|
||||
azure_endpoint="https://test.openai.azure.com/",
|
||||
|
@ -42,7 +42,7 @@ def test_indexing(mock_openai_embedding, tmp_path):
|
|||
def test_retrieving(mock_openai_embedding, tmp_path):
|
||||
db = ChromaVectorStore(path=str(tmp_path))
|
||||
doc_store = InMemoryDocumentStore()
|
||||
embedding = AzureOpenAIEmbeddings(
|
||||
embedding = LCAzureOpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="embedding-deployment",
|
||||
azure_endpoint="https://test.openai.azure.com/",
|
||||
|
|
|
@ -6,7 +6,7 @@ from openai.resources.embeddings import Embeddings
|
|||
|
||||
from kotaemon.agents.tools import ComponentTool, GoogleSearchTool, WikipediaTool
|
||||
from kotaemon.base import Document
|
||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||
from kotaemon.embeddings import LCAzureOpenAIEmbeddings
|
||||
from kotaemon.indices.vectorindex import VectorIndexing, VectorRetrieval
|
||||
from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
|
||||
|
||||
|
@ -38,7 +38,7 @@ def test_wikipedia_tool():
|
|||
def test_pipeline_tool(mock_openai_embedding, tmp_path):
|
||||
db = ChromaVectorStore(path=str(tmp_path))
|
||||
doc_store = InMemoryDocumentStore()
|
||||
embedding = AzureOpenAIEmbeddings(
|
||||
embedding = LCAzureOpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="embedding-deployment",
|
||||
azure_endpoint="https://test.openai.azure.com/",
|
||||
|
|
|
@ -57,7 +57,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
|
|||
if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
|
||||
KH_EMBEDDINGS["azure"] = {
|
||||
"spec": {
|
||||
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
|
||||
"__type__": "kotaemon.embeddings.LCAzureOpenAIEmbeddings",
|
||||
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
|
||||
"openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
|
||||
"api_version": config("OPENAI_API_VERSION", default="")
|
||||
|
@ -87,7 +87,7 @@ if config("OPENAI_API_KEY", default=""):
|
|||
if len(KH_EMBEDDINGS) < 1:
|
||||
KH_EMBEDDINGS["openai"] = {
|
||||
"spec": {
|
||||
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
|
||||
"__type__": "kotaemon.embeddings.LCOpenAIEmbeddings",
|
||||
"base_url": config("OPENAI_API_BASE", default="")
|
||||
or "https://api.openai.com/v1",
|
||||
"api_key": config("OPENAI_API_KEY", default=""),
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import List
|
|||
|
||||
from kotaemon.base import BaseComponent, Document, LLMInterface, Node, Param, lazy
|
||||
from kotaemon.contribs.promptui.logs import ResultLog
|
||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||
from kotaemon.embeddings import LCAzureOpenAIEmbeddings
|
||||
from kotaemon.indices import VectorIndexing, VectorRetrieval
|
||||
from kotaemon.llms import LCAzureChatOpenAI
|
||||
from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore
|
||||
|
@ -47,7 +47,7 @@ class QuestionAnsweringPipeline(BaseComponent):
|
|||
VectorRetrieval.withx(
|
||||
vector_store=lazy(ChromaVectorStore).withx(path="./tmp"),
|
||||
doc_store=lazy(SimpleFileDocumentStore).withx(path="docstore.json"),
|
||||
embedding=AzureOpenAIEmbeddings.withx(
|
||||
embedding=LCAzureOpenAIEmbeddings.withx(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="dummy-q2-text-embedding",
|
||||
azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
|
||||
|
@ -82,7 +82,7 @@ class IndexingPipeline(VectorIndexing):
|
|||
lazy(SimpleFileDocumentStore).withx(path="docstore.json"),
|
||||
ignore_ui=True,
|
||||
)
|
||||
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
|
||||
embedding: LCAzureOpenAIEmbeddings = LCAzureOpenAIEmbeddings.withx(
|
||||
model="text-embedding-ada-002",
|
||||
deployment="dummy-q2-text-embedding",
|
||||
azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
|
||||
|
|
Loading…
Reference in New Issue
Block a user