[AUR-390] Add prompt template and prompt component (#24)

* Export pipeline to config

* Export the input to config

* Preliminarily create the UI dynamically

* Add test for config export

* Try out prompt UI

* Add example projects

* Fix test errors

* Standardize interface for retrieval

* Finalize the UI demo

* Update README.md

* Update README

* Refactor according to main

* Fix typing issue

* Add openai key to git-secret

* Add prompt template and prompt component

* Update test

* update tests

* revert docstring

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
Co-authored-by: Nguyen Trung Duc (john) <john@cinnamon.is>
This commit is contained in:
ian_Cin 2023-09-25 14:38:22 +07:00 committed by GitHub
parent c6dd01e820
commit 08b6e5d3fb
8 changed files with 378 additions and 2 deletions

View File

@ -22,6 +22,9 @@ class Document(BaseDocument):
text = self.text
return HaystackDocument(content=text, meta=metadata)
def __str__(self):
return self.text
class RetrievedDocument(Document):
"""Subclass of Document with retrieval-related information

View File

@ -6,7 +6,8 @@ from kotaemon.documents.base import Document
class RegexExtractor(BaseComponent):
"""Simple class for extracting text from a document using a regex pattern.
"""
Simple class for extracting text from a document using a regex pattern.
Args:
pattern (str): The regex pattern to use.

View File

184
knowledgehub/prompt/base.py Normal file
View File

@ -0,0 +1,184 @@
from typing import Union
from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document
from kotaemon.prompt.template import PromptTemplate
class BasePrompt(BaseComponent):
    """
    Base class for prompt components.

    Values for the template's placeholders are stored as instance attributes
    (via ``__dict__``) and rendered into the template when the component runs.

    Args:
        template (PromptTemplate): The prompt template. A plain string is wrapped
            in a ``PromptTemplate``.
        **kwargs: Any additional keyword arguments that will be used to populate the
            given template.
    """

    def __init__(self, template: Union[str, PromptTemplate], **kwargs):
        super().__init__()
        self.template = (
            template
            if isinstance(template, PromptTemplate)
            else PromptTemplate(template)
        )

        self.__set(**kwargs)

    def __check_redundant_kwargs(self, **kwargs):
        """
        Check for redundant keyword arguments.

        Parameters:
            **kwargs (dict): A dictionary of keyword arguments.

        Raises:
            ValueError: If any keys provided are not in the template.

        Returns:
            None
        """
        redundant_keys = set(kwargs) - self.template.placeholders
        if redundant_keys:
            raise ValueError(f"\nKeys provided but not in template: {redundant_keys}")

    def __check_unset_placeholders(self):
        """
        Check if all the placeholders in the template are set.

        A placeholder counts as set when its name is a key of the instance
        ``__dict__``.

        Raises:
            ValueError: If any placeholder has no value yet; the message lists
                the missing names.

        Returns:
            None
        """
        missing_keys = [
            key for key in self.template.placeholders if key not in self.__dict__
        ]
        if missing_keys:
            raise ValueError(f"\nMissing keys in template: {missing_keys}")

    def __validate_value_type(self, **kwargs):
        """
        Validates the value types of the given keyword arguments.

        Values are only checked, never converted: conversion to text happens
        later, when the prompt is rendered.

        Parameters:
            **kwargs (dict): A dictionary of keyword arguments to be validated.

        Raises:
            ValueError: If any of the values in the kwargs dictionary have an
                unsupported type.

        Returns:
            None
        """
        supported_types = (str, int, Document, BaseComponent)
        type_error = [
            (key, type(value))
            for key, value in kwargs.items()
            if not isinstance(value, supported_types)
        ]
        if type_error:
            raise ValueError(
                "Type of values must be either int, str, Document, BaseComponent, "
                f"found unsupported type for (key, type): {type_error}"
            )

    def __set(self, **kwargs):
        """
        Set the values of the attributes in the object based on the provided keyword
        arguments.

        Args:
            kwargs (dict): A dictionary with the attribute names as keys and the new
                values as values.

        Raises:
            ValueError: If a key is not a template placeholder or a value has an
                unsupported type. Nothing is stored in that case.

        Returns:
            None
        """
        # Validate everything first so a failing call leaves the instance unchanged.
        self.__check_redundant_kwargs(**kwargs)
        self.__validate_value_type(**kwargs)

        self.__dict__.update(kwargs)

    def __prepare_value(self):
        """
        Generate a dictionary of keyword arguments based on the template's placeholders
        and the current instance's attributes.

        ``int`` and ``Document`` values are converted with ``str``; a
        ``BaseComponent`` value is called with no arguments and its result is
        converted with ``str``. Other values are passed through unchanged.

        Returns:
            dict: A dictionary of keyword arguments.
        """
        kwargs = {}
        for k in self.template.placeholders:
            v = getattr(self, k)
            if isinstance(v, (int, Document)):
                v = str(v)
            elif isinstance(v, BaseComponent):
                v = str(v())
            kwargs[k] = v

        return kwargs

    def set(self, **kwargs):
        """
        Similar to `__set` but for external use.

        Set the values of the attributes in the object based on the provided keyword
        arguments.

        Args:
            kwargs (dict): A dictionary with the attribute names as keys and the new
                values as values.

        Returns:
            None
        """
        self.__set(**kwargs)

    def run(self, **kwargs):
        """
        Store the given values, then render the template.

        Args:
            **kwargs: Placeholder values to set before rendering. They are stored
                on the instance exactly as `set` would store them.

        Raises:
            ValueError: If a key is not a template placeholder, a value has an
                unsupported type, or a placeholder is still unset.

        Returns:
            The result of calling the `populate` method of the `template` object
            with the prepared placeholder values.
        """
        self.__set(**kwargs)
        self.__check_unset_placeholders()
        prepared_kwargs = self.__prepare_value()
        return self.template.populate(**prepared_kwargs)

    # The methods below are no-op stubs that return None.
    # NOTE(review): presumably they exist to satisfy the BaseComponent
    # interface -- confirm against the base class.
    def run_raw(self, *args, **kwargs):
        pass

    def run_batch_raw(self, *args, **kwargs):
        pass

    def run_document(self, *args, **kwargs):
        pass

    def run_batch_document(self, *args, **kwargs):
        pass

    def is_document(self, *args, **kwargs):
        pass

    def is_batch(self, *args, **kwargs):
        pass

View File

@ -0,0 +1,68 @@
import re
from typing import Set
class PromptTemplate:
    """
    Base class for prompt templates.

    A template is a plain string in which every ``{name}`` span, with ``name``
    an ASCII identifier, is a placeholder that `populate` can fill in.
    """

    # A placeholder is an ASCII identifier wrapped in single curly braces.
    _PLACEHOLDER_PATTERN = re.compile(r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}")

    @staticmethod
    def extract_placeholders(template: str) -> Set[str]:
        """
        Extracts placeholders from a template string.

        Args:
            template (str): The template string to extract placeholders from.

        Returns:
            set[str]: A set of placeholder names found in the template string.
        """
        # The pattern only matches valid identifiers, so no further filtering
        # of the captured names is needed.
        return set(PromptTemplate._PLACEHOLDER_PATTERN.findall(template))

    def __init__(self, template: str):
        self.placeholders = self.extract_placeholders(template)
        self.template = template

    def populate(self, **kwargs):
        """
        Populate the template with the given keyword arguments.

        Placeholders without a matching keyword are left untouched in the
        result. Substitution is done in a single pass, so text inserted for one
        placeholder is never re-scanned and replaced by another keyword.

        Args:
            **kwargs: The keyword arguments to populate the template.
                Each keyword corresponds to a placeholder in the template.
                Values must be strings.

        Returns:
            str: The populated template.

        Raises:
            ValueError: If an unknown placeholder is provided.
        """
        for placeholder in kwargs:
            if placeholder not in self.placeholders:
                raise ValueError(f"Unknown placeholder: {placeholder}")

        def substitute(match):
            # Fall back to the original "{name}" text when no value was given.
            return kwargs.get(match.group(1), match.group(0))

        return self._PLACEHOLDER_PATTERN.sub(substitute, self.template)

    def __add__(self, other):
        """
        Create a new PromptTemplate object by concatenating the template of the current
        object with the template of another PromptTemplate object.

        The two template strings are joined with a newline.

        Parameters:
            other (PromptTemplate): Another PromptTemplate object.

        Returns:
            PromptTemplate: A new PromptTemplate object with the concatenated templates.
        """
        return PromptTemplate(self.template + "\n" + other.template)

67
tests/test_prompt.py Normal file
View File

@ -0,0 +1,67 @@
import pytest
from kotaemon.documents.base import Document
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePrompt
from kotaemon.prompt.template import PromptTemplate
def test_set_attributes():
    """Values passed to the constructor are exposed as attributes, unchanged."""
    document = Document(text="Helloo, Alice!")
    extractor = RegexExtractor(
        pattern=r"\d+", output_map={"1": "One", "2": "Two", "3": "Three"}
    )
    extractor.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)

    values = {"s": "Alice", "i": 30, "doc": document, "comp": extractor}
    prompt = BasePrompt(
        template=PromptTemplate("str = {s}, int = {i}, doc = {doc}, comp = {comp}"),
        **values,
    )

    assert prompt.s == "Alice"
    assert prompt.i == 30
    assert prompt.doc == document
    assert prompt.comp == extractor
def test_check_redundant_kwargs():
    """A key that is not a template placeholder is rejected with ValueError."""
    prompt = BasePrompt(PromptTemplate("Hello, {name}!"), name="Alice")
    # "age" does not appear in the template.
    supplied = {"name": "Alice", "age": 30}

    with pytest.raises(ValueError):
        prompt._BasePrompt__check_redundant_kwargs(**supplied)
def test_check_unset_placeholders():
    """A placeholder left without a value is reported with ValueError."""
    # Only "name" is provided; "age" stays unset.
    prompt = BasePrompt(
        PromptTemplate("Hello, {name}! I'm {age} years old."), name="Alice"
    )

    with pytest.raises(ValueError):
        prompt._BasePrompt__check_unset_placeholders()
def test_validate_value_type():
    """A value of an unsupported type (here a dict) is rejected with ValueError."""
    prompt = BasePrompt(PromptTemplate("Hello, {name}!"))
    unsupported_value = {}

    with pytest.raises(ValueError):
        prompt._BasePrompt__validate_value_type(name=unsupported_value)
def test_run():
    """Calling the prompt renders str, int, Document and component values."""
    document = Document(text="Helloo, Alice!")
    extractor = RegexExtractor(
        pattern=r"\d+", output_map={"1": "One", "2": "Two", "3": "Three"}
    )
    extractor.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)

    prompt = BasePrompt(
        template=PromptTemplate("str = {s}, int = {i}, doc = {doc}, comp = {comp}"),
        s="Alice",
        i=30,
        doc=document,
        comp=extractor,
    )

    expected = (
        "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
    )
    result = prompt()
    assert result == expected
def test_set_method():
    """`set` stores a placeholder value after construction."""
    prompt = BasePrompt(PromptTemplate("Hello, {name}!"))

    prompt.set(name="Alice")

    assert prompt.name == "Alice"

53
tests/test_template.py Normal file
View File

@ -0,0 +1,53 @@
import pytest
from kotaemon.prompt.template import PromptTemplate
def test_prompt_template_creation():
    """The constructor keeps the template text and collects its placeholders."""
    # Template without placeholders.
    plain_text = "This is a template"
    plain = PromptTemplate(plain_text)
    assert plain.template == plain_text

    # Template with two placeholders.
    greeting_text = "Hello, {name}! Today is {day}."
    greeting = PromptTemplate(greeting_text)
    assert greeting.template == greeting_text
    assert greeting.placeholders == {"name", "day"}
def test_prompt_template_addition():
    """Adding two templates joins their text with a newline."""
    # Without placeholders.
    combined = PromptTemplate("Hello, ") + PromptTemplate("world!")
    assert combined.template == "Hello, \nworld!"

    # With placeholders on both sides.
    left = PromptTemplate("Hello, {name}!")
    right = PromptTemplate("Today is {day}.")
    combined = left + right
    assert combined.template == "Hello, {name}!\nToday is {day}."
def test_prompt_template_extract_placeholders():
    """`extract_placeholders` returns the set of names found in the string."""
    found = PromptTemplate.extract_placeholders("Hello, {name}! Today is {day}.")
    assert found == {"name", "day"}
def test_prompt_template_populate():
    """`populate` replaces each placeholder with the supplied value."""
    template = PromptTemplate("Hello, {name}! Today is {day}.")
    values = {"name": "John", "day": "Monday"}

    rendered = template.populate(**values)

    assert rendered == "Hello, John! Today is Monday."
def test_prompt_template_unknown_placeholder():
    """`populate` raises ValueError for a keyword that is not a placeholder."""
    template = PromptTemplate("Hello, {name}! Today is {day}.")
    # "month" is not a placeholder of the template.
    values = {"name": "John", "month": "January"}

    with pytest.raises(ValueError):
        template.populate(**values)

View File

@ -56,7 +56,7 @@ class TestChromaVectorStore:
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
_, sim, out_ids = db.query(embedding=[0.1, 0.2, 0.3], top_k=1)
assert sim == [0.0]
assert sim == [1.0]
assert out_ids == ["a"]
_, _, out_ids = db.query(embedding=[0.42, 0.52, 0.53], top_k=1)