diff --git a/README.md b/README.md index 3291e9a..a71616a 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ pip install kotaemon@git+ssh://git@github.com/Cinnamon/kotaemon.git This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which internally uses `gpg` to encrypt and decrypt secret files. -This repo uses `python-dotenv` to manage credentials stored as enviroment variable. +This repo uses `python-dotenv` to manage credentials stored as environment variable. Please note that the use of `python-dotenv` and credentials are for development purposes only. Thus, it should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`. diff --git a/knowledgehub/prompt/base.py b/knowledgehub/prompt/base.py index 6570cb8..caee3b0 100644 --- a/knowledgehub/prompt/base.py +++ b/knowledgehub/prompt/base.py @@ -1,5 +1,4 @@ -import warnings -from typing import Union +from typing import Callable, Union from kotaemon.base import BaseComponent from kotaemon.documents.base import Document @@ -39,14 +38,7 @@ class BasePromptComponent(BaseComponent): Returns: None """ - provided_keys = set(kwargs.keys()) - expected_keys = self.template.placeholders - - redundant_keys = provided_keys - expected_keys - if redundant_keys: - warnings.warn( - f"Keys provided but not in template: {redundant_keys}", UserWarning - ) + self.template.check_redundant_kwargs(**kwargs) def __check_unset_placeholders(self): """ @@ -62,15 +54,7 @@ class BasePromptComponent(BaseComponent): Returns: None """ - expected_keys = self.template.placeholders - - missing_keys = [] - for key in expected_keys: - if key not in self.__dict__: - missing_keys.append(key) - - if missing_keys: - raise ValueError(f"\nMissing keys in template: {missing_keys}") + self.template.check_missing_kwargs(**self.__dict__) def __validate_value_type(self, **kwargs): """ @@ -88,14 +72,12 @@ class BasePromptComponent(BaseComponent): """ type_error = [] for k, v in kwargs.items(): - if not isinstance(v, (str, int, Document, BaseComponent)): - if isinstance(v, int): - kwargs[k] = str(v) + if not isinstance(v, (str, int, Document, Callable)): # type: ignore type_error.append((k, type(v))) if type_error: raise ValueError( - "Type of values must be either int, str, Document, BaseComponent, " + "Type of values must be either int, str, Document, Callable, " f"found unsupported type for (key, type): {type_error}" ) @@ -138,15 +120,18 @@ class BasePromptComponent(BaseComponent): kwargs = {} for k in self.template.placeholders: v = getattr(self, k) - if isinstance(v, BaseComponent): + + # if get a callable, execute to get its output + if isinstance(v, Callable): # type: ignore[arg-type] v = v() + if isinstance(v, list): v = str([__prepare(k, each) for each in v]) elif isinstance(v, (str, int, Document)): v = __prepare(k, v) else: raise ValueError( - f"Unsupported type {type(v)} for template value of key {k}" + f"Unsupported type {type(v)} for template value of key `{k}`" ) kwargs[k] = v diff --git a/knowledgehub/prompt/template.py b/knowledgehub/prompt/template.py index 03b79bc..46692ec 100644 --- a/knowledgehub/prompt/template.py +++ b/knowledgehub/prompt/template.py @@ -1,5 +1,5 @@ -import re -from typing import Set +import warnings +from string import Formatter class PromptTemplate: @@ -7,33 +7,74 @@ class PromptTemplate: Base class for prompt templates. """ - @staticmethod - def extract_placeholders(template: str) -> Set[str]: - """ - Extracts placeholders from a template string. - - Args: - template (str): The template string to extract placeholders from. - - Returns: - set[str]: A set of placeholder names found in the template string. - """ - placeholder_regex = r"{([a-zA-Z_][a-zA-Z0-9_]*)}" + def __init__(self, template: str, ignore_invalid=True): + template = template + formatter = Formatter() + parsed_template = list(formatter.parse(template)) placeholders = set() - for item in re.findall(placeholder_regex, template): - if item.isidentifier(): - placeholders.add(item) + for _, key, _, _ in parsed_template: + if key is None: + continue + if not key.isidentifier(): + if ignore_invalid: + warnings.warn(f"Ignore invalid placeholder: {key}.", UserWarning) + else: + raise ValueError( + "Placeholder name must be a valid Python identifier, found:" + f" {key}." + ) + placeholders.add(key) - return placeholders - - def __init__(self, template: str): - self.placeholders = self.extract_placeholders(template) self.template = template + self.placeholders = placeholders + self.__formatter = formatter + self.__parsed_template = parsed_template + + def check_missing_kwargs(self, **kwargs): + """ + Check if all the placeholders in the template are set. + + This function checks if all the expected placeholders in the template are set as + attributes of the object. If any placeholders are missing, a `ValueError` + is raised with the names of the missing keys. + + Parameters: + None + + Returns: + None + """ + missing_keys = self.placeholders.difference(kwargs.keys()) + if missing_keys: + raise ValueError(f"Missing keys in template: {','.join(missing_keys)}") + + def check_redundant_kwargs(self, **kwargs): + """ + Check if all the placeholders in the template are set. + + This function checks if all the expected placeholders in the template are set as + attributes of the object. If any placeholders are missing, a `ValueError` + is raised with the names of the missing keys. + + Parameters: + None + + Returns: + None + """ + provided_keys = set(kwargs.keys()) + redundant_keys = provided_keys - self.placeholders + + if redundant_keys: + warnings.warn( + f"Keys provided but not in template: {','.join(redundant_keys)}", + UserWarning, + ) def populate(self, **kwargs): """ - Populate the template with the given keyword arguments. + Strictly populate the template with the given keyword arguments. Args: **kwargs: The keyword arguments to populate the template. @@ -44,15 +85,46 @@ class PromptTemplate: Raises: ValueError: If an unknown placeholder is provided. - """ - prompt = self.template - for placeholder, value in kwargs.items(): - if placeholder not in self.placeholders: - raise ValueError(f"Unknown placeholder: {placeholder}") - prompt = prompt.replace(f"{{{placeholder}}}", value) + self.check_missing_kwargs(**kwargs) - return prompt + return self.partial_populate(**kwargs) + + def partial_populate(self, **kwargs): + """ + Partially populate the template with the given keyword arguments. + + Args: + **kwargs: The keyword arguments to populate the template. + Each keyword corresponds to a placeholder in the template. + + Returns: + str: The populated template. + """ + self.check_redundant_kwargs(**kwargs) + + prompt = [] + for literal_text, field_name, format_spec, conversion in self.__parsed_template: + prompt.append(literal_text) + + if field_name is None: + continue + + if field_name not in kwargs: + if conversion: + value = f"{{{field_name}}}!{conversion}:{format_spec}" + else: + value = f"{{{field_name}:{format_spec}}}" + else: + value = kwargs[field_name] + if conversion is not None: + value = self.__formatter.convert_field(value, conversion) + if format_spec is not None: + value = self.__formatter.format_field(value, format_spec) + + prompt.append(value) + + return "".join(prompt) def __add__(self, other): """ diff --git a/tests/test_prompt.py b/tests/test_prompt.py index a9e0505..f71d5fe 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -24,7 +24,7 @@ def test_set_attributes(): def test_check_redundant_kwargs(): template = PromptTemplate("Hello, {name}!") prompt = BasePromptComponent(template, name="Alice") - with pytest.warns(UserWarning, match="Keys provided but not in template: {'age'}"): + with pytest.warns(UserWarning, match="Keys provided but not in template: age"): prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30) diff --git a/tests/test_template.py b/tests/test_template.py index 96a3d74..c0e586c 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -4,7 +4,7 @@ from kotaemon.prompt.template import PromptTemplate def test_prompt_template_creation(): - # Test case 1: Ensure the PromptTemplate object is created correctly + # Ensure the PromptTemplate object is created correctly template_string = "This is a template" template = PromptTemplate(template_string) assert template.template == template_string @@ -15,8 +15,22 @@ def test_prompt_template_creation(): assert template.placeholders == {"name", "day"} +def test_prompt_template_creation_invalid_placeholder(): + # Ensure the PromptTemplate object handle invalid placeholder correctly + template_string = "Hello, {name}! Today is {0day}." + + with pytest.raises(ValueError): + PromptTemplate(template_string, ignore_invalid=False) + + with pytest.warns( + UserWarning, + match="Ignore invalid placeholder: 0day.", + ): + PromptTemplate(template_string, ignore_invalid=True) + + def test_prompt_template_addition(): - # Test case 2: Ensure the __add__ method concatenates the templates correctly + # Ensure the __add__ method concatenates the templates correctly template1 = PromptTemplate("Hello, ") template2 = PromptTemplate("world!") result = template1 + template2 @@ -29,25 +43,71 @@ def test_prompt_template_addition(): def test_prompt_template_extract_placeholders(): - # Test case 3: Ensure the extract_placeholders method extracts placeholders - # correctly + # Ensure the PromptTemplate correctly extracts placeholders template_string = "Hello, {name}! Today is {day}." - result = PromptTemplate.extract_placeholders(template_string) + result = PromptTemplate(template_string).placeholders assert result == {"name", "day"} def test_prompt_template_populate(): - # Test case 4: Ensure the populate method populates the template correctly + # Ensure the populate method populates the template correctly template_string = "Hello, {name}! Today is {day}." template = PromptTemplate(template_string) result = template.populate(name="John", day="Monday") assert result == "Hello, John! Today is Monday." -def test_prompt_template_unknown_placeholder(): - # Test case 5: Ensure the populate method raises an exception for unknown - # placeholders +def test_prompt_template_check_missing_kwargs(): + # Ensure the check_missing_kwargs and populate methods raise an exception for + # missing placeholders template_string = "Hello, {name}! Today is {day}." template = PromptTemplate(template_string) + kwargs = dict(name="John") + with pytest.raises(ValueError): - template.populate(name="John", month="January") + template.check_missing_kwargs(**kwargs) + + with pytest.raises(ValueError): + template.populate(**kwargs) + + +def test_prompt_template_check_redundant_kwargs(): + # Ensure the check_redundant_kwargs, partial_populate and populate methods warn for + # redundant placeholders + template_string = "Hello, {name}! Today is {day}." + template = PromptTemplate(template_string) + kwargs = dict(name="John", day="Monday", age="30") + + with pytest.warns(UserWarning, match="Keys provided but not in template: age"): + template.check_redundant_kwargs(**kwargs) + + with pytest.warns(UserWarning, match="Keys provided but not in template: age"): + template.partial_populate(**kwargs) + + with pytest.warns(UserWarning, match="Keys provided but not in template: age"): + template.populate(**kwargs) + + +def test_prompt_template_populate_complex_template(): + # Ensure the populate method produces the same results as the built-in str.format + # function + template_string = ( + "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}" + ) + template = PromptTemplate(template_string) + kwargs = dict(a=1, b="two", c=3, d=4, e="á") + populated = template.populate(**kwargs) + expected = template_string.format(**kwargs) + assert populated == expected + + +def test_prompt_template_partial_populate(): + # Ensure the partial_populate method populates correctly + template_string = ( + "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}" + ) + template = PromptTemplate(template_string) + kwargs = dict(a=1, b="two", d=4, e="á") + populated = template.partial_populate(**kwargs) + expected = "a = 1.00, b = two, c = {c:.1%}, d = 4., ascii of á = '\\xe1'" + assert populated == expected