Feat/local endpoint llm (#148)

* serve local model in a different process from the app
---------

Co-authored-by: albert <albert@cinnamon.is>
Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin
2024-03-15 16:17:33 +07:00
committed by GitHub
parent 2950e6ed02
commit df12dec732
20 changed files with 675 additions and 79 deletions
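The diff below adds EndpointChatLLM and EndpointEmbeddings, which call an OpenAI-API-compatible server running as a separate process from the app. The serving script itself is not part of this excerpt; as a purely hypothetical illustration (assuming llama-cpp-python with its server extra is installed and a local GGUF model path is supplied), such a process could be started like this:

# Hypothetical illustration only -- not the serving script added by this commit.
# Launch an OpenAI-API-compatible server in a separate process, then point the
# KH_LLMS / KH_EMBEDDINGS entries below at http://localhost:31415/v1/...
import subprocess
import sys

server = subprocess.Popen(
    [
        sys.executable,
        "-m",
        "llama_cpp.server",  # assumes llama-cpp-python[server] is installed
        "--model", "path/to/model.gguf",  # placeholder model path
        "--port", "31415",
    ]
)
# ... run the kotaemon app in this process; call server.terminate() on shutdown.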

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
from pydantic import Extra
from pydantic import ConfigDict
from kotaemon.base import LLMInterface
@@ -238,7 +238,7 @@ class AgentFinish(NamedTuple):
log: str
class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
class AgentOutput(LLMInterface):
"""Output from an agent.
Args:
@@ -248,6 +248,8 @@ class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
error: The error message if any.
"""
model_config = ConfigDict(extra="allow")
text: str
type: str = "agent"
agent_type: AgentType

View File

@@ -1,4 +1,5 @@
from .base import BaseEmbeddings
from .endpoint_based import EndpointEmbeddings
from .langchain_based import (
AzureOpenAIEmbeddings,
CohereEmbdeddings,
@@ -8,6 +9,7 @@ from .langchain_based import (
__all__ = [
"BaseEmbeddings",
"EndpointEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"CohereEmbdeddings",

View File

@@ -0,0 +1,46 @@
import requests
from kotaemon.base import Document, DocumentWithEmbedding
from .base import BaseEmbeddings
class EndpointEmbeddings(BaseEmbeddings):
    """
    An Embeddings component that uses an OpenAI API compatible endpoint.
    Attributes:
        endpoint_url (str): The url of an OpenAI API compatible endpoint.
    """
    endpoint_url: str
    def run(
        self, text: str | list[str] | Document | list[Document]
    ) -> list[DocumentWithEmbedding]:
        """
        Generate embeddings from text
        Args:
            text (str | list[str] | Document | list[Document]): text to generate
                embeddings from
        Returns:
            list[DocumentWithEmbedding]: embeddings
        """
        if not isinstance(text, list):
            text = [text]
        outputs = []
        for item in text:
            response = requests.post(
                self.endpoint_url, json={"input": str(item)}
            ).json()
            outputs.append(
                DocumentWithEmbedding(
                    text=str(item),
                    embedding=response["data"][0]["embedding"],
                    total_tokens=response["usage"]["total_tokens"],
                    prompt_tokens=response["usage"]["prompt_tokens"],
                )
            )
        return outputs
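A minimal usage sketch for the new embeddings component, assuming an OpenAI-API-compatible embeddings server is already listening at the (placeholder) URL below:

# Illustrative only; endpoint URL and input text are placeholders.
from kotaemon.embeddings import EndpointEmbeddings
embeddings = EndpointEmbeddings(endpoint_url="http://localhost:31415/v1/embeddings")
docs = embeddings.run("an example passage to embed")  # -> list[DocumentWithEmbedding]
print(len(docs), len(docs[0].embedding))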

View File

@@ -108,6 +108,9 @@ class CitationPipeline(BaseComponent):
print(e)
return None
if not llm_output.messages:
return None
function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments"
]
@@ -126,6 +129,9 @@ class CitationPipeline(BaseComponent):
print(e)
return None
if not llm_output.messages:
return None
function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments"
]

View File

@@ -2,7 +2,7 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMes
from .base import BaseLLM
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .chats import AzureChatOpenAI, ChatLLM, LlamaCppChat
from .chats import AzureChatOpenAI, ChatLLM, EndpointChatLLM, LlamaCppChat
from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
from .cot import ManualSequentialChainOfThought, Thought
from .linear import GatedLinearPipeline, SimpleLinearPipeline
@@ -12,6 +12,7 @@ __all__ = [
"BaseLLM",
# chat-specific components
"ChatLLM",
"EndpointChatLLM",
"BaseMessage",
"HumanMessage",
"AIMessage",

View File

@@ -1,5 +1,12 @@
from .base import ChatLLM
from .endpoint_based import EndpointChatLLM
from .langchain_based import AzureChatOpenAI, LCChatMixin
from .llamacpp import LlamaCppChat
__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin", "LlamaCppChat"]
__all__ = [
"ChatLLM",
"EndpointChatLLM",
"AzureChatOpenAI",
"LCChatMixin",
"LlamaCppChat",
]

View File

@@ -0,0 +1,85 @@
import requests
from kotaemon.base import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    LLMInterface,
    SystemMessage,
)
from .base import ChatLLM
class EndpointChatLLM(ChatLLM):
    """
    A ChatLLM that uses an endpoint to generate responses. This expects an OpenAI API
    compatible endpoint.
    Attributes:
        endpoint_url (str): The url of an OpenAI API compatible endpoint.
    """
    endpoint_url: str
    def run(
        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
    ) -> LLMInterface:
        """
        Generate response from messages
        Args:
            messages (str | BaseMessage | list[BaseMessage]): history of messages to
                generate response from
            **kwargs: additional arguments to pass to the OpenAI API
        Returns:
            LLMInterface: generated response
        """
        if isinstance(messages, str):
            input_ = [HumanMessage(content=messages)]
        elif isinstance(messages, BaseMessage):
            input_ = [messages]
        else:
            input_ = messages
        def decide_role(message: BaseMessage):
            if isinstance(message, SystemMessage):
                return "system"
            elif isinstance(message, AIMessage):
                return "assistant"
            else:
                return "user"
        request_json = {
            "messages": [{"content": m.text, "role": decide_role(m)} for m in input_]
        }
        response = requests.post(self.endpoint_url, json=request_json).json()
        content = ""
        candidates = []
        if response["choices"]:
            candidates = [
                each["message"]["content"]
                for each in response["choices"]
                if each["message"]["content"]
            ]
            content = candidates[0]
        return LLMInterface(
            content=content,
            candidates=candidates,
            completion_tokens=response["usage"]["completion_tokens"],
            total_tokens=response["usage"]["total_tokens"],
            prompt_tokens=response["usage"]["prompt_tokens"],
        )
    def invoke(
        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
    ) -> LLMInterface:
        """Same as run"""
        return self.run(messages, **kwargs)
    async def ainvoke(
        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
    ) -> LLMInterface:
        return self.invoke(messages, **kwargs)
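A matching usage sketch for the chat component, again assuming a compatible /v1/chat/completions server is already running at the placeholder URL:

# Illustrative only; endpoint URL and messages are placeholders.
from kotaemon.base import HumanMessage, SystemMessage
from kotaemon.llms import EndpointChatLLM
llm = EndpointChatLLM(endpoint_url="http://localhost:31415/v1/chat/completions")
result = llm.run(
    [
        SystemMessage(content="You answer briefly."),
        HumanMessage(content="What is retrieval-augmented generation?"),
    ]
)
print(result.content)       # first non-empty candidate
print(result.total_tokens)  # token usage as reported by the endpoint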

View File

@@ -12,7 +12,7 @@ user_cache_dir.mkdir(parents=True, exist_ok=True)
COHERE_API_KEY = config("COHERE_API_KEY", default="")
KH_MODE = "dev"
KH_FEATURE_USER_MANAGEMENT = True
KH_FEATURE_USER_MANAGEMENT = False
KH_FEATURE_USER_MANAGEMENT_ADMIN = str(
config("KH_FEATURE_USER_MANAGEMENT_ADMIN", default="admin")
)
@@ -21,6 +21,8 @@ KH_FEATURE_USER_MANAGEMENT_PASSWORD = str(
)
KH_ENABLE_ALEMBIC = False
KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"
KH_FILESTORAGE_PATH = str(user_cache_dir / "files")
KH_DOCSTORE = {
"__type__": "kotaemon.storages.SimpleFileDocumentStore",
"path": str(user_cache_dir / "docstore"),
@@ -29,51 +31,68 @@ KH_VECTORSTORE = {
"__type__": "kotaemon.storages.ChromaVectorStore",
"path": str(user_cache_dir / "vectorstore"),
}
KH_FILESTORAGE_PATH = str(user_cache_dir / "files")
KH_LLMS = {
"gpt4": {
# example for using Azure OpenAI; the config variables can be set as environment
# variables or in the .env file
# "gpt4": {
# "def": {
# "__type__": "kotaemon.llms.AzureChatOpenAI",
# "temperature": 0,
# "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
# "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
# "openai_api_version": config("OPENAI_API_VERSION", default=""),
# "deployment_name": "<your deployment name>",
# "stream": True,
# },
# "accuracy": 10,
# "cost": 10,
# "default": False,
# },
# "gpt35": {
# "def": {
# "__type__": "kotaemon.llms.AzureChatOpenAI",
# "temperature": 0,
# "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
# "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
# "openai_api_version": config("OPENAI_API_VERSION", default=""),
# "deployment_name": "<your deployment name>",
# "request_timeout": 10,
# "stream": False,
# },
# "accuracy": 5,
# "cost": 5,
# "default": False,
# },
"local": {
"def": {
"__type__": "kotaemon.llms.AzureChatOpenAI",
"temperature": 0,
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
"openai_api_version": config("OPENAI_API_VERSION", default=""),
"deployment_name": "dummy-q2",
"stream": True,
"__type__": "kotaemon.llms.EndpointChatLLM",
"endpoint_url": "http://localhost:31415/v1/chat/completions",
},
"accuracy": 10,
"cost": 10,
"default": False,
},
"gpt35": {
"def": {
"__type__": "kotaemon.llms.AzureChatOpenAI",
"temperature": 0,
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
"openai_api_version": config("OPENAI_API_VERSION", default=""),
"deployment_name": "dummy-q2",
"request_timeout": 10,
"stream": False,
},
"accuracy": 5,
"cost": 5,
"default": True,
},
}
KH_EMBEDDINGS = {
"ada": {
# example for using Azure OpenAI; the config variables can be set as environment
# variables or in the .env file
# "ada": {
# "def": {
# "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
# "model": "text-embedding-ada-002",
# "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
# "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
# "deployment": "<your deployment name>",
# "chunk_size": 16,
# },
# "accuracy": 5,
# "cost": 5,
# "default": True,
# },
"local": {
"def": {
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
"model": "text-embedding-ada-002",
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
"deployment": "dummy-q2-text-embedding",
"chunk_size": 16,
"__type__": "kotaemon.embeddings.EndpointEmbeddings",
"endpoint_url": "http://localhost:31415/v1/embeddings",
},
"accuracy": 5,
"cost": 5,
"default": True,
"default": False,
},
}
KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
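The local endpoints configured above only need to return the subset of the OpenAI response schema that EndpointChatLLM and EndpointEmbeddings actually read; a minimal sketch of those shapes (all values are placeholders, extra OpenAI fields are simply ignored):

# Minimal JSON shapes parsed by the components above; values are placeholders.
chat_completion_response = {
    "choices": [{"message": {"content": "placeholder answer"}}],
    "usage": {"completion_tokens": 2, "prompt_tokens": 5, "total_tokens": 7},
}
embeddings_response = {
    "data": [{"embedding": [0.1, 0.2, 0.3]}],
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}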

View File

@@ -118,7 +118,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
# rerank
docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
if self.get_from_path("reranker"):
if docs and self.get_from_path("reranker"):
docs = self.reranker(docs, query=text)
if not self.get_extra_table:

View File

@@ -200,24 +200,37 @@ class AnswerWithContextPipeline(BaseComponent):
lang=self.lang,
)
citation_task = asyncio.create_task(
self.citation_pipeline.ainvoke(context=evidence, question=question)
)
print("Citation task created")
if evidence:
citation_task = asyncio.create_task(
self.citation_pipeline.ainvoke(context=evidence, question=question)
)
print("Citation task created")
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
messages.append(HumanMessage(content=prompt))
output = ""
for text in self.llm.stream(messages):
output += text.text
self.report_output({"output": text.text})
await asyncio.sleep(0)
try:
# try streaming first
print("Trying LLM streaming")
for text in self.llm.stream(messages):
output += text.text
self.report_output({"output": text.text})
await asyncio.sleep(0)
except NotImplementedError:
print("Streaming is not supported, falling back to normal processing")
output = self.llm(messages).text
self.report_output({"output": output})
# retrieve the citation
print("Waiting for citation task")
citation = await citation_task
if evidence:
citation = await citation_task
else:
citation = None
answer = Document(text=output, metadata={"citation": citation})
return answer

View File

@@ -2,4 +2,4 @@ from ktem.main import App
app = App()
demo = app.make()
demo.queue().launch(favicon_path=app._favicon)
demo.queue().launch(favicon_path=app._favicon, inbrowser=True)