Utilize llama.cpp for both completion and chat models (#141)
parent a86c727869
commit 767aaaa1ef
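The commit wires llama.cpp into both sides of the LLM interface: a native LlamaCppChat chat model built directly on llama-cpp-python, and a LlamaCpp completion model that wraps LangChain's LlamaCpp class. Both are re-exported from the top-level package, so downstream code can simply do (usage sketches follow the corresponding hunks below):

    from kotaemon.llms import LlamaCpp, LlamaCppChat
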
@@ -12,6 +12,7 @@ repos:
         args: ["--allow-missing-credentials"]
       - id: detect-private-key
       - id: check-added-large-files
+        args: ["--maxkb=750"]
       - id: debug-statements
   - repo: https://github.com/ambv/black
     rev: 22.3.0

@@ -2,8 +2,8 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
 from .base import BaseLLM
 from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
-from .chats import AzureChatOpenAI, ChatLLM
-from .completions import LLM, AzureOpenAI, OpenAI
+from .chats import AzureChatOpenAI, ChatLLM, LlamaCppChat
+from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
 from .cot import ManualSequentialChainOfThought, Thought
 from .linear import GatedLinearPipeline, SimpleLinearPipeline
 from .prompts import BasePromptComponent, PromptTemplate
@@ -17,10 +17,12 @@ __all__ = [
     "AIMessage",
     "SystemMessage",
     "AzureChatOpenAI",
+    "LlamaCppChat",
     # completion-specific components
     "LLM",
     "OpenAI",
     "AzureOpenAI",
+    "LlamaCpp",
     # prompt-specific components
     "BasePromptComponent",
     "PromptTemplate",

@@ -1,4 +1,5 @@
 from .base import ChatLLM
 from .langchain_based import AzureChatOpenAI, LCChatMixin
+from .llamacpp import LlamaCppChat
 
-__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin"]
+__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin", "LlamaCppChat"]

libs/kotaemon/kotaemon/llms/chats/llamacpp.py (new file, 93 additions)
@@ -0,0 +1,93 @@
+from typing import TYPE_CHECKING, Optional, cast
+
+from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param
+
+from .base import ChatLLM
+
+if TYPE_CHECKING:
+    from llama_cpp import CreateChatCompletionResponse as CCCR
+    from llama_cpp import Llama
+
+
+class LlamaCppChat(ChatLLM):
+    """Wrapper around the llama-cpp-python's Llama model"""
+
+    model_path: Optional[str] = None
+    chat_format: Optional[str] = None
+    lora_base: Optional[str] = None
+    n_ctx: int = 512
+    n_gpu_layers: int = 0
+    use_mmap: bool = True
+    vocab_only: bool = False
+
+    _role_mapper: dict[str, str] = {
+        "human": "user",
+        "system": "system",
+        "ai": "assistant",
+    }
+
+    @Param.auto()
+    def client_object(self) -> "Llama":
+        """Get the llama-cpp-python client object"""
+        try:
+            from llama_cpp import Llama
+        except ImportError:
+            raise ImportError(
+                "llama-cpp-python is not installed. "
+                "Please install it using `pip install llama-cpp-python`"
+            )
+
+        errors = []
+        if not self.model_path:
+            errors.append("- `model_path` is required to load the model")
+
+        if not self.chat_format:
+            errors.append(
+                "- `chat_format` is required to know how to format the chat messages. "
+                "Please refer to llama_cpp.llama_chat_format for a list of supported "
+                "formats."
+            )
+        if errors:
+            raise ValueError("\n".join(errors))
+
+        return Llama(
+            model_path=cast(str, self.model_path),
+            chat_format=self.chat_format,
+            lora_base=self.lora_base,
+            n_ctx=self.n_ctx,
+            n_gpu_layers=self.n_gpu_layers,
+            use_mmap=self.use_mmap,
+            vocab_only=self.vocab_only,
+        )
+
+    def run(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        input_: list[BaseMessage] = []
+
+        if isinstance(messages, str):
+            input_ = [HumanMessage(content=messages)]
+        elif isinstance(messages, BaseMessage):
+            input_ = [messages]
+        else:
+            input_ = messages
+
+        pred: "CCCR" = self.client_object.create_chat_completion(
+            messages=[
+                {"role": self._role_mapper[each.type], "content": each.content}
+                for each in input_
+            ],  # type: ignore
+            stream=False,
+        )
+
+        return LLMInterface(
+            content=pred["choices"][0]["message"]["content"] if pred["choices"] else "",
+            candidates=[
+                c["message"]["content"]
+                for c in pred["choices"]
+                if c["message"]["content"]
+            ],
+            completion_tokens=pred["usage"]["completion_tokens"],
+            total_tokens=pred["usage"]["total_tokens"],
+            prompt_tokens=pred["usage"]["prompt_tokens"],
+        )

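A minimal usage sketch for the new chat wrapper, assuming a chat-capable GGUF file is available locally; the model path and chat format below are placeholders, not part of this commit:

    from kotaemon.base.schema import HumanMessage, SystemMessage
    from kotaemon.llms import LlamaCppChat

    # hypothetical weights; any llama.cpp-compatible GGUF chat model works
    llm = LlamaCppChat(
        model_path="models/llama-2-7b-chat.Q4_K_M.gguf",  # placeholder path
        chat_format="llama-2",  # must match the model; see llama_cpp.llama_chat_format
        n_ctx=512,
        n_gpu_layers=0,
    )

    result = llm.run(
        [
            SystemMessage(content="You are a concise assistant."),
            HumanMessage(content="What is llama.cpp?"),
        ]
    )
    print(result.content)  # top choice; result.candidates holds all choices
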
@@ -1,4 +1,4 @@
 from .base import LLM
-from .langchain_based import AzureOpenAI, LCCompletionMixin, OpenAI
+from .langchain_based import AzureOpenAI, LCCompletionMixin, LlamaCpp, OpenAI
 
-__all__ = ["LLM", "OpenAI", "AzureOpenAI", "LCCompletionMixin"]
+__all__ = ["LLM", "OpenAI", "AzureOpenAI", "LCCompletionMixin", "LlamaCpp"]

@@ -195,3 +195,33 @@ class AzureOpenAI(LCCompletionMixin, LLM):
         from langchain.llms import AzureOpenAI
 
         return AzureOpenAI
+
+
+class LlamaCpp(LCCompletionMixin, LLM):
+    """Wrapper around Langchain's LlamaCpp class, focusing on key parameters"""
+
+    def __init__(
+        self,
+        model_path: str,
+        lora_base: Optional[str] = None,
+        n_ctx: int = 512,
+        n_gpu_layers: Optional[int] = None,
+        use_mmap: bool = True,
+        **params,
+    ):
+        super().__init__(
+            model_path=model_path,
+            lora_base=lora_base,
+            n_ctx=n_ctx,
+            n_gpu_layers=n_gpu_layers,
+            use_mmap=use_mmap,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        try:
+            from langchain_community.llms import LlamaCpp
+        except ImportError:
+            from langchain.llms import LlamaCpp
+
+        return LlamaCpp

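A corresponding sketch for the completion-side wrapper, which defers to LangChain's LlamaCpp class. The weight path is a placeholder, and the direct-call invocation assumes kotaemon's LC-based LLMs are callable components (otherwise use the component's run method):

    from kotaemon.llms import LlamaCpp

    # hypothetical local GGUF weights; raise n_gpu_layers to offload layers to GPU
    llm = LlamaCpp(
        model_path="models/llama-2-7b.Q4_K_M.gguf",
        n_ctx=2048,
        n_gpu_layers=0,
    )
    output = llm("List three advantages of running an LLM locally.")
    print(output.content)
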
@@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
 # metadata and dependencies
 [project]
 name = "kotaemon"
-version = "0.3.7"
+version = "0.3.8"
 requires-python = ">= 3.10"
 description = "Kotaemon core library for AI development."
 dependencies = [
@@ -64,6 +64,7 @@ dev = [
     "elasticsearch",
     "pypdf",
     "html2text",
+    "llama-cpp-python",
 ]
 
 [project.scripts]

libs/kotaemon/tests/resources/ggml-vocab-llama.gguf (new binary file)
Binary file not shown.

@@ -1,12 +1,15 @@
+from pathlib import Path
 from unittest.mock import patch
 
+import pytest
+
 from kotaemon.base.schema import (
     AIMessage,
     HumanMessage,
     LLMInterface,
     SystemMessage,
 )
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import AzureChatOpenAI, LlamaCppChat
 
 try:
     from langchain_openai import AzureChatOpenAI as AzureChatOpenAILC
@@ -76,3 +79,23 @@ def test_azureopenai_model(openai_completion):
         output, LLMInterface
     ), "Output for single text is not LLMInterface"
     openai_completion.assert_called()
+
+
+def test_llamacpp_chat():
+    from llama_cpp import Llama
+
+    dir_path = Path(__file__).parent / "resources" / "ggml-vocab-llama.gguf"
+
+    # test initialization
+    model = LlamaCppChat(model_path=str(dir_path), chat_format="llama", vocab_only=True)
+    assert isinstance(model.client_object, Llama), "Error initializing llama_cpp.Llama"
+
+    # test error if model_path is omitted
+    with pytest.raises(ValueError):
+        model = LlamaCppChat(chat_format="llama", vocab_only=True)
+        model.client_object
+
+    # test error if chat_format is omitted
+    with pytest.raises(ValueError):
+        model = LlamaCppChat(model_path=str(dir_path), vocab_only=True)
+        model.client_object

@@ -1,7 +1,8 @@
+from pathlib import Path
 from unittest.mock import patch
 
 from kotaemon.base.schema import LLMInterface
-from kotaemon.llms import AzureOpenAI, OpenAI
+from kotaemon.llms import AzureOpenAI, LlamaCpp, OpenAI
 
 try:
     from langchain_openai import AzureOpenAI as AzureOpenAILC
@@ -76,3 +77,11 @@ def test_openai_model(openai_completion):
     assert isinstance(
         output, LLMInterface
     ), "Output for single text is not LLMInterface"
+
+
+def test_llamacpp_model():
+    weight_path = Path(__file__).parent / "resources" / "ggml-vocab-llama.gguf"
+
+    # test initialization
+    model = LlamaCpp(model_path=str(weight_path), vocab_only=True)
+    assert isinstance(model._obj, model._get_lc_class())