feat: add first setup screen for LLM & Embedding models (#314) (bump:minor)

* fix: utf-8 txt reader

* fix: revise vectorstore import and make it optional

* feat: add cohere chat model with tool call support

* fix: simplify citation pipeline

* fix: improve citation logic

* fix: improve decompose func call

* fix: revise question rewrite prompt

* fix: revise chat box default placeholder

* fix: add key from ktem to cohere rerank

* fix: conv name suggestion

* fix: ignore default key cohere rerank

* fix: improve test connection UI

* fix: reorder requirements

* feat: add first setup screen

* fix: update requirements

* fix: vectorstore tests

* fix: update cohere version

* fix: relax langchain core version

* fix: add demo mode

* fix: update flowsettings

* fix: typo

* fix: fix bool env passing
Tuan Anh Nguyen Dang (Tadashi_Cin)
2024-09-22 16:32:23 +07:00
committed by GitHub
parent 0bdb9a32f2
commit 88d577b0cc
27 changed files with 643 additions and 140 deletions

View File

@@ -183,7 +183,7 @@ class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
def _get_lc_class(self):
try:
from langchain_community.embeddings import CohereEmbeddings
from langchain_cohere import CohereEmbeddings
except ImportError:
from langchain.embeddings import CohereEmbeddings

View File

@@ -1,4 +1,4 @@
from typing import Iterator, List
from typing import List
from pydantic import BaseModel, Field
@@ -7,53 +7,14 @@ from kotaemon.base.schema import HumanMessage, SystemMessage
from kotaemon.llms import BaseLLM
class FactWithEvidence(BaseModel):
"""Class representing a single statement.
class CiteEvidence(BaseModel):
"""List of evidences (maximum 5) to support the answer."""
Each fact has a body and a list of sources.
If there are multiple facts make sure to break them apart
such that each one only uses a set of sources that are relevant to it.
"""
fact: str = Field(..., description="Body of the sentence, as part of a response")
substring_quote: List[str] = Field(
evidences: List[str] = Field(
...,
description=(
"Each source should be a direct quote from the context, "
"as a substring of the original content"
),
)
def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
import regex
minor = quote
major = context
errs_ = 0
s = regex.search(f"({minor}){{e<={errs_}}}", major)
while s is None and errs_ <= errs:
errs_ += 1
s = regex.search(f"({minor}){{e<={errs_}}}", major)
if s is not None:
yield from s.spans()
def get_spans(self, context: str) -> Iterator[str]:
for quote in self.substring_quote:
yield from self._get_span(quote, context)
class QuestionAnswer(BaseModel):
"""A question and its answer as a list of facts each one should have a source.
each sentence contains a body and a list of sources."""
question: str = Field(..., description="Question that was asked")
answer: List[FactWithEvidence] = Field(
...,
description=(
"Body of the answer, each fact should be "
"its separate object with a body and a list of sources"
"as a substring of the original content (max 15 words)."
),
)
@@ -68,7 +29,7 @@ class CitationPipeline(BaseComponent):
return self.invoke(context, question)
def prepare_llm(self, context: str, question: str):
schema = QuestionAnswer.schema()
schema = CiteEvidence.schema()
function = {
"name": schema["title"],
"description": schema["description"],
@@ -76,7 +37,8 @@ class CitationPipeline(BaseComponent):
}
llm_kwargs = {
"tools": [{"type": "function", "function": function}],
"tool_choice": "auto",
"tool_choice": "required",
"tools_pydantic": [CiteEvidence],
}
messages = [
SystemMessage(
@@ -85,7 +47,12 @@ class CitationPipeline(BaseComponent):
"questions with correct and exact citations."
)
),
HumanMessage(content="Answer question using the following context"),
HumanMessage(
content=(
"Answer question using the following context. "
"Use the provided function CiteEvidence() to cite your sources."
)
),
HumanMessage(content=context),
HumanMessage(content=f"Question: {question}"),
HumanMessage(
@@ -103,14 +70,24 @@ class CitationPipeline(BaseComponent):
print("CitationPipeline: invoking LLM")
llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
print("CitationPipeline: finish invoking LLM")
if not llm_output.messages or not llm_output.additional_kwargs.get(
"tool_calls"
):
if not llm_output.additional_kwargs.get("tool_calls"):
return None
function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
"arguments"
]
output = QuestionAnswer.parse_raw(function_output)
first_func = llm_output.additional_kwargs["tool_calls"][0]
if "function" in first_func:
# openai and cohere format
function_output = first_func["function"]["arguments"]
else:
# anthropic format
function_output = first_func["args"]
print("CitationPipeline:", function_output)
if isinstance(function_output, str):
output = CiteEvidence.parse_raw(function_output)
else:
output = CiteEvidence.parse_obj(function_output)
except Exception as e:
print(e)
return None
@@ -118,18 +95,4 @@ class CitationPipeline(BaseComponent):
return output
async def ainvoke(self, context: str, question: str):
messages, llm_kwargs = self.prepare_llm(context, question)
try:
print("CitationPipeline: async invoking LLM")
llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
print("CitationPipeline: finish async invoking LLM")
function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
"arguments"
]
output = QuestionAnswer.parse_raw(function_output)
except Exception as e:
print(e)
return None
return output
raise NotImplementedError()
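The pipeline now replaces the span-matching QuestionAnswer schema with a single CiteEvidence tool call. A minimal sketch of the provider-agnostic parsing step it performs; the standalone helper and its name are illustrative, not part of the commit:

from typing import List

from pydantic import BaseModel, Field


class CiteEvidence(BaseModel):
    """List of evidences (maximum 5) to support the answer."""

    evidences: List[str] = Field(
        ...,
        description=(
            "Each source should be a direct quote from the context, "
            "as a substring of the original content (max 15 words)."
        ),
    )


def parse_cite_evidence(tool_call: dict) -> CiteEvidence:
    # OpenAI and Cohere return JSON-encoded arguments under "function",
    # while Anthropic returns an already-parsed dict under "args".
    if "function" in tool_call:
        payload = tool_call["function"]["arguments"]
    else:
        payload = tool_call["args"]
    if isinstance(payload, str):
        return CiteEvidence.parse_raw(payload)
    return CiteEvidence.parse_obj(payload)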

View File

@@ -10,6 +10,7 @@ from .base import BaseReranking
class CohereReranking(BaseReranking):
model_name: str = "rerank-multilingual-v2.0"
cohere_api_key: str = config("COHERE_API_KEY", "")
use_key_from_ktem: bool = False
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Use Cohere Reranker model to re-order documents
@@ -18,9 +19,25 @@ class CohereReranking(BaseReranking):
import cohere
except ImportError:
raise ImportError(
"Please install Cohere " "`pip install cohere` to use Cohere Reranking"
"Please install Cohere `pip install cohere` to use Cohere Reranking"
)
# try to get COHERE_API_KEY from embeddings
if not self.cohere_api_key and self.use_key_from_ktem:
try:
from ktem.embeddings.manager import (
embedding_models_manager as embeddings,
)
cohere_model = embeddings.get("cohere")
ktem_cohere_api_key = cohere_model._kwargs.get( # type: ignore
"cohere_api_key"
)
if ktem_cohere_api_key != "your-key":
self.cohere_api_key = ktem_cohere_api_key
except Exception as e:
print("Cannot get Cohere API key from `ktem`", e)
if not self.cohere_api_key:
print("Cohere API key not found. Skipping reranking.")
return documents
@@ -35,7 +52,7 @@ class CohereReranking(BaseReranking):
response = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs
)
print("Cohere score", [r.relevance_score for r in response.results])
# print("Cohere score", [r.relevance_score for r in response.results])
for r in response.results:
doc = documents[r.index]
doc.metadata["cohere_reranking_score"] = r.relevance_score
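A brief usage sketch of the new `use_key_from_ktem` fallback; the import path and the document plumbing are assumptions:

from kotaemon.indices.rankings import CohereReranking  # import path assumed

reranker = CohereReranking(
    model_name="rerank-multilingual-v2.0",
    use_key_from_ktem=True,  # fall back to the Cohere key registered in ktem
)
# `retrieved_docs` stands in for whatever list[Document] the retriever returned.
retrieved_docs: list = []
reranked = reranker.run(documents=retrieved_docs, query="example query")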

View File

@@ -10,6 +10,7 @@ from .chats import (
LCAnthropicChat,
LCAzureChatOpenAI,
LCChatOpenAI,
LCCohereChat,
LCGeminiChat,
LlamaCppChat,
)
@@ -31,6 +32,7 @@ __all__ = [
"ChatOpenAI",
"LCAnthropicChat",
"LCGeminiChat",
"LCCohereChat",
"LCAzureChatOpenAI",
"LCChatOpenAI",
"LlamaCppChat",

View File

@@ -5,6 +5,7 @@ from .langchain_based import (
LCAzureChatOpenAI,
LCChatMixin,
LCChatOpenAI,
LCCohereChat,
LCGeminiChat,
)
from .llamacpp import LlamaCppChat
@@ -18,6 +19,7 @@ __all__ = [
"ChatOpenAI",
"LCAnthropicChat",
"LCGeminiChat",
"LCCohereChat",
"LCChatOpenAI",
"LCAzureChatOpenAI",
"LCChatMixin",

View File

@@ -18,6 +18,9 @@ class LCChatMixin:
"Please return the relevant Langchain class in in _get_lc_class"
)
def _get_tool_call_kwargs(self):
return {}
def __init__(self, stream: bool = False, **params):
self._lc_class = self._get_lc_class()
self._obj = self._lc_class(**params)
@@ -56,9 +59,7 @@ class LCChatMixin:
total_tokens = pred.llm_output["token_usage"]["total_tokens"]
prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
except Exception:
logger.warning(
f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
)
pass
return LLMInterface(
text=all_text[0] if len(all_text) > 0 else "",
@@ -83,8 +84,30 @@ class LCChatMixin:
LLMInterface: generated response
"""
input_ = self.prepare_message(messages)
pred = self._obj.generate(messages=[input_], **kwargs)
return self.prepare_response(pred)
if "tools_pydantic" in kwargs:
tools = kwargs.pop(
"tools_pydantic",
)
lc_tool_call = self._obj.bind_tools(tools)
pred = lc_tool_call.invoke(
input_,
**self._get_tool_call_kwargs(),
)
if pred.tool_calls:
tool_calls = pred.tool_calls
else:
tool_calls = pred.additional_kwargs.get("tool_calls", [])
output = LLMInterface(
content="",
additional_kwargs={"tool_calls": tool_calls},
)
else:
pred = self._obj.generate(messages=[input_], **kwargs)
output = self.prepare_response(pred)
return output
async def ainvoke(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
@@ -235,6 +258,9 @@ class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
required=True,
)
def _get_tool_call_kwargs(self):
return {"tool_choice": {"type": "any"}}
def __init__(
self,
api_key: str | None = None,
@@ -291,3 +317,35 @@ class LCGeminiChat(LCChatMixin, ChatLLM): # type: ignore
raise ImportError("Please install langchain-google-genai")
return ChatGoogleGenerativeAI
class LCCohereChat(LCChatMixin, ChatLLM): # type: ignore
api_key: str = Param(
help="API key (https://dashboard.cohere.com/api-keys)", required=True
)
model_name: str = Param(
help=("Model name to use (https://dashboard.cohere.com/playground/chat)"),
required=True,
)
def __init__(
self,
api_key: str | None = None,
model_name: str | None = None,
temperature: float = 0.7,
**params,
):
super().__init__(
cohere_api_key=api_key,
model_name=model_name,
temperature=temperature,
**params,
)
def _get_lc_class(self):
try:
from langchain_cohere import ChatCohere
except ImportError:
raise ImportError("Please install langchain-cohere")
return ChatCohere
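A minimal usage sketch of the new Cohere chat wrapper and the `tools_pydantic` path; the model name, the invocation style, and the Evidence schema are illustrative:

from typing import List

from pydantic import BaseModel, Field

from kotaemon.llms import LCCohereChat


class Evidence(BaseModel):
    """Hypothetical schema, used here only to demonstrate `tools_pydantic`."""

    quotes: List[str] = Field(..., description="Supporting quotes")


llm = LCCohereChat(api_key="...", model_name="command-r-plus")

# Plain generation still goes through LangChain's `generate`:
print(llm.invoke("Say hello in one short sentence.").text)

# Passing `tools_pydantic` routes the call through `bind_tools` instead,
# and the tool calls come back in `additional_kwargs["tool_calls"]`:
out = llm.invoke("Quote two facts about Paris.", tools_pydantic=[Evidence])
print(out.additional_kwargs.get("tool_calls", []))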

View File

@@ -292,6 +292,9 @@ class ChatOpenAI(BaseChatOpenAI):
def openai_response(self, client, **kwargs):
"""Get the openai response"""
if "tools_pydantic" in kwargs:
kwargs.pop("tools_pydantic")
params_ = {
"model": self.model,
"temperature": self.temperature,
@@ -360,6 +363,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
def openai_response(self, client, **kwargs):
"""Get the openai response"""
if "tools_pydantic" in kwargs:
kwargs.pop("tools_pydantic")
params_ = {
"model": self.azure_deployment,
"temperature": self.temperature,

View File

@@ -15,7 +15,7 @@ class TxtReader(BaseReader):
def load_data(
self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
) -> list[Document]:
with open(file_path, "r") as f:
with open(file_path, "r", encoding="utf-8") as f:
text = f.read()
metadata = extra_info or {}
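With the file now opened as UTF-8 explicitly, non-ASCII text loads consistently across platforms; a one-line sketch (import path assumed):

from pathlib import Path

from kotaemon.loaders import TxtReader  # import path assumed

docs = TxtReader().load_data(Path("notes_utf8.txt"))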

View File

@@ -73,17 +73,25 @@ class BaseVectorStore(ABC):
class LlamaIndexVectorStore(BaseVectorStore):
_li_class: type[LIVectorStore | BasePydanticVectorStore]
"""Mixin for LlamaIndex based vectorstores"""
_li_class: type[LIVectorStore | BasePydanticVectorStore] | None
def _get_li_class(self):
raise NotImplementedError(
"Please return the relevant LlamaIndex class in in _get_li_class"
)
def __init__(self, *args, **kwargs):
if self._li_class is None:
raise AttributeError(
"Require `_li_class` to set a VectorStore class from LlamarIndex"
)
# get li_class from the method if not set
if not self._li_class:
LIClass = self._get_li_class()
else:
LIClass = self._li_class
from dataclasses import fields
self._client = self._li_class(*args, **kwargs)
self._client = LIClass(*args, **kwargs)
self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
for key in ["query_embedding", "similarity_top_k", "node_ids"]:
@@ -97,6 +105,9 @@ class LlamaIndexVectorStore(BaseVectorStore):
return setattr(self._client, name, value)
def __getattr__(self, name: str) -> Any:
if name == "_li_class":
return super().__getattribute__(name)
return getattr(self._client, name)
def add(
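The base class now accepts either an eagerly bound LlamaIndex class or a lazy factory. A sketch of the two subclass styles this enables; the class names are illustrative:

from llama_index.vector_stores.chroma import ChromaVectorStore as LIChromaVectorStore

from .base import LlamaIndexVectorStore


class EagerStore(LlamaIndexVectorStore):
    # Core backend: bind the LlamaIndex class at import time.
    _li_class = LIChromaVectorStore


class LazyStore(LlamaIndexVectorStore):
    # Optional backend: defer the heavy import until first instantiation.
    _li_class = None

    def _get_li_class(self):
        from llama_index.vector_stores.milvus import MilvusVectorStore
        return MilvusVectorStore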

View File

@@ -1,7 +1,5 @@
import os
from typing import Any, Optional, Type, cast
from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore
from typing import Any, Optional, cast
from kotaemon.base import DocumentWithEmbedding
@@ -9,7 +7,20 @@ from .base import LlamaIndexVectorStore
class MilvusVectorStore(LlamaIndexVectorStore):
_li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore
_li_class = None
def _get_li_class(self):
try:
from llama_index.vector_stores.milvus import (
MilvusVectorStore as LIMilvusVectorStore,
)
except ImportError:
raise ImportError(
"Please install missing package: "
"'pip install llama-index-vector-stores-milvus'"
)
return LIMilvusVectorStore
def __init__(
self,
@@ -46,6 +57,10 @@ class MilvusVectorStore(LlamaIndexVectorStore):
dim=dim,
**self._kwargs,
)
from llama_index.vector_stores.milvus import (
MilvusVectorStore as LIMilvusVectorStore,
)
self._client = cast(LIMilvusVectorStore, self._client)
self._inited = True

View File

@@ -1,12 +1,23 @@
from typing import Any, List, Optional, Type, cast
from llama_index.vector_stores.qdrant import QdrantVectorStore as LIQdrantVectorStore
from typing import Any, List, Optional, cast
from .base import LlamaIndexVectorStore
class QdrantVectorStore(LlamaIndexVectorStore):
_li_class: Type[LIQdrantVectorStore] = LIQdrantVectorStore
_li_class = None
def _get_li_class(self):
try:
from llama_index.vector_stores.qdrant import (
QdrantVectorStore as LIQdrantVectorStore,
)
except ImportError:
raise ImportError(
"Please install missing package: "
"'pip install llama-index-vector-stores-qdrant'"
)
return LIQdrantVectorStore
def __init__(
self,
@@ -29,6 +40,10 @@ class QdrantVectorStore(LlamaIndexVectorStore):
client_kwargs=client_kwargs,
**kwargs,
)
from llama_index.vector_stores.qdrant import (
QdrantVectorStore as LIQdrantVectorStore,
)
self._client = cast(LIQdrantVectorStore, self._client)
def delete(self, ids: List[str], **kwargs):

View File

@@ -30,16 +30,15 @@ dependencies = [
"fastapi<=0.112.1",
"gradio>=4.31.0,<4.40",
"html2text==2024.2.26",
"langchain>=0.1.16,<0.2.0",
"langchain-anthropic",
"langchain-community>=0.0.34,<0.1.0",
"langchain>=0.1.16,<0.2.16",
"langchain-community>=0.0.34,<=0.2.11",
"langchain-openai>=0.1.4,<0.2.0",
"langchain-anthropic",
"langchain-cohere>=0.2.4,<0.3.0",
"llama-hub>=0.0.79,<0.1.0",
"llama-index>=0.10.40,<0.11.0",
"llama-index-vector-stores-chroma>=0.1.9",
"llama-index-vector-stores-lancedb",
"llama-index-vector-stores-milvus",
"llama-index-vector-stores-qdrant",
"openai>=1.23.6,<2",
"openpyxl>=3.1.2,<3.2",
"opentelemetry-exporter-otlp-proto-grpc>=1.25.0", # https://github.com/chroma-core/chroma/issues/2571
@@ -75,6 +74,9 @@ adv = [
"llama-cpp-python<0.2.8",
"sentence-transformers",
"wikipedia>=1.4.0,<1.5",
"llama-index>=0.10.40,<0.11.0",
"llama-index-vector-stores-milvus",
"llama-index-vector-stores-qdrant",
]
dev = [
"black",

View File

@@ -135,7 +135,7 @@ def test_lchuggingface_embeddings(
@skip_when_cohere_not_installed
@patch(
"langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
"langchain_cohere.CohereEmbeddings.embed_documents",
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
)
def test_lccohere_embeddings(langchain_cohere_embedding_call):