fix: add guidance parameters for LC wrapper models (#255)
* fix: add docstring to LC wrapper models
* fix: fix metadata passing with LC embedding wrapper
This commit is contained in:
parent
ce489725d8
commit
96d2086017
|
@ -208,6 +208,7 @@ KH_EMBEDDINGS["cohere"] = {
|
||||||
"__type__": "kotaemon.embeddings.LCCohereEmbeddings",
|
"__type__": "kotaemon.embeddings.LCCohereEmbeddings",
|
||||||
"model": "embed-multilingual-v2.0",
|
"model": "embed-multilingual-v2.0",
|
||||||
"cohere_api_key": "your-key",
|
"cohere_api_key": "your-key",
|
||||||
|
"user_agent": "default",
|
||||||
},
|
},
|
||||||
"default": False,
|
"default": False,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from kotaemon.base import Document, DocumentWithEmbedding
|
from kotaemon.base import DocumentWithEmbedding, Param
|
||||||
|
|
||||||
from .base import BaseEmbeddings
|
from .base import BaseEmbeddings
|
||||||
|
|
||||||
|
@ -19,25 +19,14 @@ class LCEmbeddingMixin:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def run(self, text):
|
def run(self, text):
|
||||||
input_: list[str] = []
|
input_docs = self.prepare_input(text)
|
||||||
if not isinstance(text, list):
|
input_ = [doc.text for doc in input_docs]
|
||||||
text = [text]
|
|
||||||
|
|
||||||
for item in text:
|
|
||||||
if isinstance(item, str):
|
|
||||||
input_.append(item)
|
|
||||||
elif isinstance(item, Document):
|
|
||||||
input_.append(item.text)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid input type {type(item)}, should be str or Document"
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings = self._obj.embed_documents(input_)
|
embeddings = self._obj.embed_documents(input_)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
|
DocumentWithEmbedding(content=doc, embedding=each_embedding)
|
||||||
for each_text, each_embedding in zip(input_, embeddings)
|
for doc, each_embedding in zip(input_docs, embeddings)
|
||||||
]
|
]
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
@ -162,6 +151,20 @@ class LCAzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||||
class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||||
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
|
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
|
||||||
|
|
||||||
|
cohere_api_key: str = Param(
|
||||||
|
help="API key (https://dashboard.cohere.com/api-keys)",
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
model: str = Param(
|
||||||
|
help="Model name to use (https://docs.cohere.com/docs/models)",
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
user_agent: str = Param(
|
||||||
|
help="User agent (leave default)", default="default", required=True
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str = "embed-english-v2.0",
|
model: str = "embed-english-v2.0",
|
||||||
|
@ -190,6 +193,15 @@ class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||||
class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||||
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
|
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
|
||||||
|
|
||||||
|
model_name: str = Param(
|
||||||
|
help=(
|
||||||
|
"Model name to use (https://huggingface.co/models?"
|
||||||
|
"pipeline_tag=sentence-similarity&sort=trending)"
|
||||||
|
),
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncGenerator, Iterator
|
from typing import AsyncGenerator, Iterator
|
||||||
|
|
||||||
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
|
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param
|
||||||
|
|
||||||
from .base import ChatLLM
|
from .base import ChatLLM
|
||||||
|
|
||||||
|
@ -224,6 +224,17 @@ class LCAzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
|
class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
|
||||||
|
api_key: str = Param(
|
||||||
|
help="API key (https://console.anthropic.com/settings/keys)", required=True
|
||||||
|
)
|
||||||
|
model_name: str = Param(
|
||||||
|
help=(
|
||||||
|
"Model name to use "
|
||||||
|
"(https://docs.anthropic.com/en/docs/about-claude/models)"
|
||||||
|
),
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
|
@ -248,6 +259,17 @@ class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class LCGeminiChat(LCChatMixin, ChatLLM): # type: ignore
|
class LCGeminiChat(LCChatMixin, ChatLLM): # type: ignore
|
||||||
|
api_key: str = Param(
|
||||||
|
help="API key (https://aistudio.google.com/app/apikey)", required=True
|
||||||
|
)
|
||||||
|
model_name: str = Param(
|
||||||
|
help=(
|
||||||
|
"Model name to use (https://cloud.google"
|
||||||
|
".com/vertex-ai/generative-ai/docs/learn/models)"
|
||||||
|
),
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
|
|
|
@ -50,6 +50,7 @@ class EmbeddingManager:
|
||||||
}
|
}
|
||||||
if item.default:
|
if item.default:
|
||||||
self._default = item.name
|
self._default = item.name
|
||||||
|
self._models["default"] = self._models[item.name]
|
||||||
|
|
||||||
def load_vendors(self):
|
def load_vendors(self):
|
||||||
from kotaemon.embeddings import (
|
from kotaemon.embeddings import (
|
||||||
|
|
|
@ -344,7 +344,7 @@ class FileIndex(BaseIndex):
|
||||||
def get_admin_settings(cls):
|
def get_admin_settings(cls):
|
||||||
from ktem.embeddings.manager import embedding_models_manager
|
from ktem.embeddings.manager import embedding_models_manager
|
||||||
|
|
||||||
embedding_default = embedding_models_manager.get_default_name()
|
embedding_default = "default"
|
||||||
embedding_choices = list(embedding_models_manager.options().keys())
|
embedding_choices = list(embedding_models_manager.options().keys())
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user