Feat/local endpoint llm (#148)
* serve local model in a different process from the app

---------

Co-authored-by: albert <albert@cinnamon.is>
Co-authored-by: trducng <trungduc1992@gmail.com>
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, Literal, NamedTuple, Optional, Union

-from pydantic import Extra
+from pydantic import ConfigDict

 from kotaemon.base import LLMInterface

@@ -238,7 +238,7 @@ class AgentFinish(NamedTuple):
     log: str


-class AgentOutput(LLMInterface, extra=Extra.allow):  # type: ignore [call-arg]
+class AgentOutput(LLMInterface):
     """Output from an agent.

     Args:
@@ -248,6 +248,8 @@ class AgentOutput(LLMInterface, extra=Extra.allow):  # type: ignore [call-arg]
         error: The error message if any.
     """

+    model_config = ConfigDict(extra="allow")
+
     text: str
     type: str = "agent"
     agent_type: AgentType
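For context on the change above: the Pydantic v1 `extra=Extra.allow` class keyword is replaced with the Pydantic v2 `ConfigDict` style. A minimal sketch of the behaviour this preserves (the `Profile` model below is hypothetical, not part of kotaemon):

```python
from pydantic import BaseModel, ConfigDict


class Profile(BaseModel):
    # v2 replacement for the v1 form `class Profile(BaseModel, extra=Extra.allow)`
    model_config = ConfigDict(extra="allow")

    name: str


# unknown fields are kept on the instance instead of raising a validation error
p = Profile(name="agent", intermediate_steps=["step 1"])
print(p.intermediate_steps)  # ['step 1']
```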
@@ -1,4 +1,5 @@
 from .base import BaseEmbeddings
+from .endpoint_based import EndpointEmbeddings
 from .langchain_based import (
     AzureOpenAIEmbeddings,
     CohereEmbdeddings,
@@ -8,6 +9,7 @@ from .langchain_based import (

 __all__ = [
     "BaseEmbeddings",
+    "EndpointEmbeddings",
     "OpenAIEmbeddings",
     "AzureOpenAIEmbeddings",
     "CohereEmbdeddings",
libs/kotaemon/kotaemon/embeddings/endpoint_based.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import requests
+
+from kotaemon.base import Document, DocumentWithEmbedding
+
+from .base import BaseEmbeddings
+
+
+class EndpointEmbeddings(BaseEmbeddings):
+    """
+    An Embeddings component that uses an OpenAI API compatible endpoint.
+
+    Attributes:
+        endpoint_url (str): The url of an OpenAI API compatible endpoint.
+    """
+
+    endpoint_url: str
+
+    def run(
+        self, text: str | list[str] | Document | list[Document]
+    ) -> list[DocumentWithEmbedding]:
+        """
+        Generate embeddings from text Args:
+            text (str | list[str] | Document | list[Document]): text to generate
+                embeddings from
+        Returns:
+            list[DocumentWithEmbedding]: embeddings
+        """
+        if not isinstance(text, list):
+            text = [text]
+
+        outputs = []
+
+        for item in text:
+            response = requests.post(
+                self.endpoint_url, json={"input": str(item)}
+            ).json()
+            outputs.append(
+                DocumentWithEmbedding(
+                    text=str(item),
+                    embedding=response["data"][0]["embedding"],
+                    total_tokens=response["usage"]["total_tokens"],
+                    prompt_tokens=response["usage"]["prompt_tokens"],
+                )
+            )
+
+        return outputs
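A minimal usage sketch for the new `EndpointEmbeddings` component, assuming an OpenAI-compatible embeddings server is already running; the URL matches the default used in `flowsettings.py` further below, and the input text is illustrative only:

```python
from kotaemon.embeddings import EndpointEmbeddings

# assumes an OpenAI-compatible /v1/embeddings server is listening at this URL
embeddings = EndpointEmbeddings(endpoint_url="http://localhost:31415/v1/embeddings")

# returns a list[DocumentWithEmbedding], one per input text
docs = embeddings.run("hello from a local embedding model")
print(len(docs[0].embedding))
```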
@@ -108,6 +108,9 @@ class CitationPipeline(BaseComponent):
             print(e)
             return None

+        if not llm_output.messages:
+            return None
+
         function_output = llm_output.messages[0].additional_kwargs["function_call"][
             "arguments"
         ]
@@ -126,6 +129,9 @@ class CitationPipeline(BaseComponent):
             print(e)
             return None

+        if not llm_output.messages:
+            return None
+
         function_output = llm_output.messages[0].additional_kwargs["function_call"][
             "arguments"
         ]
@@ -2,7 +2,7 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMes

 from .base import BaseLLM
 from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
-from .chats import AzureChatOpenAI, ChatLLM, LlamaCppChat
+from .chats import AzureChatOpenAI, ChatLLM, EndpointChatLLM, LlamaCppChat
 from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
 from .cot import ManualSequentialChainOfThought, Thought
 from .linear import GatedLinearPipeline, SimpleLinearPipeline
@@ -12,6 +12,7 @@ __all__ = [
     "BaseLLM",
     # chat-specific components
     "ChatLLM",
+    "EndpointChatLLM",
     "BaseMessage",
     "HumanMessage",
     "AIMessage",
@@ -1,5 +1,12 @@
 from .base import ChatLLM
+from .endpoint_based import EndpointChatLLM
 from .langchain_based import AzureChatOpenAI, LCChatMixin
 from .llamacpp import LlamaCppChat

-__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin", "LlamaCppChat"]
+__all__ = [
+    "ChatLLM",
+    "EndpointChatLLM",
+    "AzureChatOpenAI",
+    "LCChatMixin",
+    "LlamaCppChat",
+]
libs/kotaemon/kotaemon/llms/chats/endpoint_based.py (new file, 85 lines)
@@ -0,0 +1,85 @@
+import requests
+
+from kotaemon.base import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    LLMInterface,
+    SystemMessage,
+)
+
+from .base import ChatLLM
+
+
+class EndpointChatLLM(ChatLLM):
+    """
+    A ChatLLM that uses an endpoint to generate responses. This expects an OpenAI API
+    compatible endpoint.
+
+    Attributes:
+        endpoint_url (str): The url of a OpenAI API compatible endpoint.
+    """
+
+    endpoint_url: str
+
+    def run(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """
+        Generate response from messages
+        Args:
+            messages (str | BaseMessage | list[BaseMessage]): history of messages to
+                generate response from
+            **kwargs: additional arguments to pass to the OpenAI API
+        Returns:
+            LLMInterface: generated response
+        """
+        if isinstance(messages, str):
+            input_ = [HumanMessage(content=messages)]
+        elif isinstance(messages, BaseMessage):
+            input_ = [messages]
+        else:
+            input_ = messages
+
+        def decide_role(message: BaseMessage):
+            if isinstance(message, SystemMessage):
+                return "system"
+            elif isinstance(message, AIMessage):
+                return "assistant"
+            else:
+                return "user"
+
+        request_json = {
+            "messages": [{"content": m.text, "role": decide_role(m)} for m in input_]
+        }
+
+        response = requests.post(self.endpoint_url, json=request_json).json()
+
+        content = ""
+        candidates = []
+        if response["choices"]:
+            candidates = [
+                each["message"]["content"]
+                for each in response["choices"]
+                if each["message"]["content"]
+            ]
+            content = candidates[0]
+
+        return LLMInterface(
+            content=content,
+            candidates=candidates,
+            completion_tokens=response["usage"]["completion_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+            prompt_tokens=response["usage"]["prompt_tokens"],
+        )
+
+    def invoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """Same as run"""
+        return self.run(messages, **kwargs)
+
+    async def ainvoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        return self.invoke(messages, **kwargs)
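Likewise, a small sketch of calling the new `EndpointChatLLM` directly, assuming a local OpenAI-compatible chat server is running at the URL configured below; the example messages are illustrative:

```python
from kotaemon.base import HumanMessage, SystemMessage
from kotaemon.llms import EndpointChatLLM

# assumes an OpenAI-compatible /v1/chat/completions server is listening at this URL
llm = EndpointChatLLM(endpoint_url="http://localhost:31415/v1/chat/completions")

result = llm.run(
    [
        SystemMessage(content="You are a concise assistant."),
        HumanMessage(content="Say hello in one sentence."),
    ]
)
print(result.content)  # LLMInterface carries content, candidates and token usage
```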
@@ -12,7 +12,7 @@ user_cache_dir.mkdir(parents=True, exist_ok=True)

 COHERE_API_KEY = config("COHERE_API_KEY", default="")
 KH_MODE = "dev"
-KH_FEATURE_USER_MANAGEMENT = True
+KH_FEATURE_USER_MANAGEMENT = False
 KH_FEATURE_USER_MANAGEMENT_ADMIN = str(
     config("KH_FEATURE_USER_MANAGEMENT_ADMIN", default="admin")
 )
@@ -21,6 +21,8 @@ KH_FEATURE_USER_MANAGEMENT_PASSWORD = str(
 )
 KH_ENABLE_ALEMBIC = False
 KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"
+KH_FILESTORAGE_PATH = str(user_cache_dir / "files")
+
 KH_DOCSTORE = {
     "__type__": "kotaemon.storages.SimpleFileDocumentStore",
     "path": str(user_cache_dir / "docstore"),
@@ -29,51 +31,68 @@ KH_VECTORSTORE = {
     "__type__": "kotaemon.storages.ChromaVectorStore",
     "path": str(user_cache_dir / "vectorstore"),
 }
-KH_FILESTORAGE_PATH = str(user_cache_dir / "files")
 KH_LLMS = {
-    "gpt4": {
+    # example for using Azure OpenAI, the config variables can set as environment
+    # variables or in the .env file
+    # "gpt4": {
+    #     "def": {
+    #         "__type__": "kotaemon.llms.AzureChatOpenAI",
+    #         "temperature": 0,
+    #         "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
+    #         "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
+    #         "openai_api_version": config("OPENAI_API_VERSION", default=""),
+    #         "deployment_name": "<your deployment name>",
+    #         "stream": True,
+    #     },
+    #     "accuracy": 10,
+    #     "cost": 10,
+    #     "default": False,
+    # },
+    # "gpt35": {
+    #     "def": {
+    #         "__type__": "kotaemon.llms.AzureChatOpenAI",
+    #         "temperature": 0,
+    #         "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
+    #         "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
+    #         "openai_api_version": config("OPENAI_API_VERSION", default=""),
+    #         "deployment_name": "<your deployment name>",
+    #         "request_timeout": 10,
+    #         "stream": False,
+    #     },
+    #     "accuracy": 5,
+    #     "cost": 5,
+    #     "default": False,
+    # },
+    "local": {
         "def": {
-            "__type__": "kotaemon.llms.AzureChatOpenAI",
-            "temperature": 0,
-            "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
-            "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
-            "openai_api_version": config("OPENAI_API_VERSION", default=""),
-            "deployment_name": "dummy-q2",
-            "stream": True,
+            "__type__": "kotaemon.llms.EndpointChatLLM",
+            "endpoint_url": "http://localhost:31415/v1/chat/completions",
         },
         "accuracy": 10,
         "cost": 10,
         "default": False,
     },
     "gpt35": {
         "def": {
             "__type__": "kotaemon.llms.AzureChatOpenAI",
             "temperature": 0,
             "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
             "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
             "openai_api_version": config("OPENAI_API_VERSION", default=""),
             "deployment_name": "dummy-q2",
             "request_timeout": 10,
             "stream": False,
         },
         "accuracy": 5,
         "cost": 5,
         "default": True,
     },
 }
 KH_EMBEDDINGS = {
-    "ada": {
+    # example for using Azure OpenAI, the config variables can set as environment
+    # variables or in the .env file
+    # "ada": {
+    #     "def": {
+    #         "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
+    #         "model": "text-embedding-ada-002",
+    #         "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
+    #         "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
+    #         "deployment": "<your deployment name>",
+    #         "chunk_size": 16,
+    #     },
+    #     "accuracy": 5,
+    #     "cost": 5,
+    #     "default": True,
+    # },
+    "local": {
         "def": {
-            "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
-            "model": "text-embedding-ada-002",
-            "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
-            "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
-            "deployment": "dummy-q2-text-embedding",
-            "chunk_size": 16,
+            "__type__": "kotaemon.embeddings.EndpointEmbeddings",
+            "endpoint_url": "http://localhost:31415/v1/embeddings",
         },
         "accuracy": 5,
         "cost": 5,
-        "default": True,
+        "default": False,
     },
 }
 KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
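The two `local` entries above point the chat LLM and the embeddings at a model served in a separate process on port 31415; any server that speaks the OpenAI API at those paths should work (for example, llama-cpp-python's `python -m llama_cpp.server`). A quick smoke test of the expected contract, mirroring the request and response fields that `EndpointChatLLM` and `EndpointEmbeddings` read; the payloads here are illustrative:

```python
import requests

BASE = "http://localhost:31415/v1"

# EndpointChatLLM posts {"messages": [...]} and reads choices[*].message.content and usage
chat = requests.post(
    f"{BASE}/chat/completions",
    json={"messages": [{"role": "user", "content": "ping"}]},
).json()
print(chat["choices"][0]["message"]["content"], chat["usage"]["total_tokens"])

# EndpointEmbeddings posts {"input": "..."} and reads data[0].embedding and usage
emb = requests.post(f"{BASE}/embeddings", json={"input": "ping"}).json()
print(len(emb["data"][0]["embedding"]), emb["usage"]["prompt_tokens"])
```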
@@ -118,7 +118,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):

         # rerank
         docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
-        if self.get_from_path("reranker"):
+        if docs and self.get_from_path("reranker"):
             docs = self.reranker(docs, query=text)

         if not self.get_extra_table:
@@ -200,24 +200,37 @@ class AnswerWithContextPipeline(BaseComponent):
             lang=self.lang,
         )

-        citation_task = asyncio.create_task(
-            self.citation_pipeline.ainvoke(context=evidence, question=question)
-        )
-        print("Citation task created")
+        if evidence:
+            citation_task = asyncio.create_task(
+                self.citation_pipeline.ainvoke(context=evidence, question=question)
+            )
+            print("Citation task created")

         messages = []
         if self.system_prompt:
             messages.append(SystemMessage(content=self.system_prompt))
         messages.append(HumanMessage(content=prompt))

         output = ""
-        for text in self.llm.stream(messages):
-            output += text.text
-            self.report_output({"output": text.text})
-            await asyncio.sleep(0)
+        try:
+            # try streaming first
+            print("Trying LLM streaming")
+            for text in self.llm.stream(messages):
+                output += text.text
+                self.report_output({"output": text.text})
+                await asyncio.sleep(0)
+        except NotImplementedError:
+            print("Streaming is not supported, falling back to normal processing")
+            output = self.llm(messages).text
+            self.report_output({"output": output})

         # retrieve the citation
         print("Waiting for citation task")
-        citation = await citation_task
+        if evidence:
+            citation = await citation_task
+        else:
+            citation = None

         answer = Document(text=output, metadata={"citation": citation})

         return answer
@@ -2,4 +2,4 @@ from ktem.main import App

 app = App()
 demo = app.make()
-demo.queue().launch(favicon_path=app._favicon)
+demo.queue().launch(favicon_path=app._favicon, inbrowser=True)