fix: add guidance parameters for LC wrapper models (#255)

* fix: add docstring to LC wrapper models

* fix: fix metadata passing with LC embedding wrapper
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-09-09 14:15:34 +07:00 committed by GitHub
parent ce489725d8
commit 96d2086017
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 54 additions and 18 deletions

View File

@@ -208,6 +208,7 @@ KH_EMBEDDINGS["cohere"] = {
"__type__": "kotaemon.embeddings.LCCohereEmbeddings", "__type__": "kotaemon.embeddings.LCCohereEmbeddings",
"model": "embed-multilingual-v2.0", "model": "embed-multilingual-v2.0",
"cohere_api_key": "your-key", "cohere_api_key": "your-key",
"user_agent": "default",
}, },
"default": False, "default": False,
} }

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from kotaemon.base import Document, DocumentWithEmbedding from kotaemon.base import DocumentWithEmbedding, Param
from .base import BaseEmbeddings from .base import BaseEmbeddings
@@ -19,25 +19,14 @@ class LCEmbeddingMixin:
super().__init__() super().__init__()
def run(self, text): def run(self, text):
input_: list[str] = [] input_docs = self.prepare_input(text)
if not isinstance(text, list): input_ = [doc.text for doc in input_docs]
text = [text]
for item in text:
if isinstance(item, str):
input_.append(item)
elif isinstance(item, Document):
input_.append(item.text)
else:
raise ValueError(
f"Invalid input type {type(item)}, should be str or Document"
)
embeddings = self._obj.embed_documents(input_) embeddings = self._obj.embed_documents(input_)
return [ return [
DocumentWithEmbedding(text=each_text, embedding=each_embedding) DocumentWithEmbedding(content=doc, embedding=each_embedding)
for each_text, each_embedding in zip(input_, embeddings) for doc, each_embedding in zip(input_docs, embeddings)
] ]
def __repr__(self): def __repr__(self):
@@ -162,6 +151,20 @@ class LCAzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings): class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters""" """Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
cohere_api_key: str = Param(
help="API key (https://dashboard.cohere.com/api-keys)",
default=None,
required=True,
)
model: str = Param(
help="Model name to use (https://docs.cohere.com/docs/models)",
default=None,
required=True,
)
user_agent: str = Param(
help="User agent (leave default)", default="default", required=True
)
def __init__( def __init__(
self, self,
model: str = "embed-english-v2.0", model: str = "embed-english-v2.0",
@@ -190,6 +193,15 @@ class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings): class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters""" """Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
model_name: str = Param(
help=(
"Model name to use (https://huggingface.co/models?"
"pipeline_tag=sentence-similarity&sort=trending)"
),
default=None,
required=True,
)
def __init__( def __init__(
self, self,
model_name: str = "sentence-transformers/all-mpnet-base-v2", model_name: str = "sentence-transformers/all-mpnet-base-v2",

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
from typing import AsyncGenerator, Iterator from typing import AsyncGenerator, Iterator
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param
from .base import ChatLLM from .base import ChatLLM
@@ -224,6 +224,17 @@ class LCAzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
api_key: str = Param(
help="API key (https://console.anthropic.com/settings/keys)", required=True
)
model_name: str = Param(
help=(
"Model name to use "
"(https://docs.anthropic.com/en/docs/about-claude/models)"
),
required=True,
)
def __init__( def __init__(
self, self,
api_key: str | None = None, api_key: str | None = None,
@@ -248,6 +259,17 @@ class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
class LCGeminiChat(LCChatMixin, ChatLLM): # type: ignore class LCGeminiChat(LCChatMixin, ChatLLM): # type: ignore
api_key: str = Param(
help="API key (https://aistudio.google.com/app/apikey)", required=True
)
model_name: str = Param(
help=(
"Model name to use (https://cloud.google"
".com/vertex-ai/generative-ai/docs/learn/models)"
),
required=True,
)
def __init__( def __init__(
self, self,
api_key: str | None = None, api_key: str | None = None,

View File

@@ -50,6 +50,7 @@ class EmbeddingManager:
} }
if item.default: if item.default:
self._default = item.name self._default = item.name
self._models["default"] = self._models[item.name]
def load_vendors(self): def load_vendors(self):
from kotaemon.embeddings import ( from kotaemon.embeddings import (

View File

@@ -344,7 +344,7 @@ class FileIndex(BaseIndex):
def get_admin_settings(cls): def get_admin_settings(cls):
from ktem.embeddings.manager import embedding_models_manager from ktem.embeddings.manager import embedding_models_manager
embedding_default = embedding_models_manager.get_default_name() embedding_default = "default"
embedding_choices = list(embedding_models_manager.options().keys()) embedding_choices = list(embedding_models_manager.options().keys())
return { return {