[Feat] Add support for f-string syntax in PromptTemplate (#38)

* Add support for f-string syntax in PromptTemplate
This commit is contained in:
ian_Cin
2023-10-04 16:40:09 +07:00
committed by GitHub
parent 56bc41b673
commit 2638152054
5 changed files with 183 additions and 66 deletions

View File

@@ -1,5 +1,5 @@
import re
from typing import Set
import warnings
from string import Formatter
class PromptTemplate:
@@ -7,33 +7,74 @@ class PromptTemplate:
Base class for prompt templates.
"""
@staticmethod
def extract_placeholders(template: str) -> Set[str]:
"""
Extracts placeholders from a template string.
Args:
template (str): The template string to extract placeholders from.
Returns:
set[str]: A set of placeholder names found in the template string.
"""
placeholder_regex = r"{([a-zA-Z_][a-zA-Z0-9_]*)}"
def __init__(self, template: str, ignore_invalid=True):
template = template
formatter = Formatter()
parsed_template = list(formatter.parse(template))
placeholders = set()
for item in re.findall(placeholder_regex, template):
if item.isidentifier():
placeholders.add(item)
for _, key, _, _ in parsed_template:
if key is None:
continue
if not key.isidentifier():
if ignore_invalid:
warnings.warn(f"Ignore invalid placeholder: {key}.", UserWarning)
else:
raise ValueError(
"Placeholder name must be a valid Python identifier, found:"
f" {key}."
)
placeholders.add(key)
return placeholders
def __init__(self, template: str):
self.placeholders = self.extract_placeholders(template)
self.template = template
self.placeholders = placeholders
self.__formatter = formatter
self.__parsed_template = parsed_template
def check_missing_kwargs(self, **kwargs):
"""
Check if all the placeholders in the template are set.
This function checks if all the expected placeholders in the template are set as
attributes of the object. If any placeholders are missing, a `ValueError`
is raised with the names of the missing keys.
Parameters:
None
Returns:
None
"""
missing_keys = self.placeholders.difference(kwargs.keys())
if missing_keys:
raise ValueError(f"Missing keys in template: {','.join(missing_keys)}")
def check_redundant_kwargs(self, **kwargs):
"""
Check if all the placeholders in the template are set.
This function checks if all the expected placeholders in the template are set as
attributes of the object. If any placeholders are missing, a `ValueError`
is raised with the names of the missing keys.
Parameters:
None
Returns:
None
"""
provided_keys = set(kwargs.keys())
redundant_keys = provided_keys - self.placeholders
if redundant_keys:
warnings.warn(
f"Keys provided but not in template: {','.join(redundant_keys)}",
UserWarning,
)
def populate(self, **kwargs):
"""
Populate the template with the given keyword arguments.
Strictly populate the template with the given keyword arguments.
Args:
**kwargs: The keyword arguments to populate the template.
@@ -44,15 +85,46 @@ class PromptTemplate:
Raises:
ValueError: If an unknown placeholder is provided.
"""
prompt = self.template
for placeholder, value in kwargs.items():
if placeholder not in self.placeholders:
raise ValueError(f"Unknown placeholder: {placeholder}")
prompt = prompt.replace(f"{{{placeholder}}}", value)
self.check_missing_kwargs(**kwargs)
return prompt
return self.partial_populate(**kwargs)
def partial_populate(self, **kwargs):
"""
Partially populate the template with the given keyword arguments.
Args:
**kwargs: The keyword arguments to populate the template.
Each keyword corresponds to a placeholder in the template.
Returns:
str: The populated template.
"""
self.check_redundant_kwargs(**kwargs)
prompt = []
for literal_text, field_name, format_spec, conversion in self.__parsed_template:
prompt.append(literal_text)
if field_name is None:
continue
if field_name not in kwargs:
if conversion:
value = f"{{{field_name}}}!{conversion}:{format_spec}"
else:
value = f"{{{field_name}:{format_spec}}}"
else:
value = kwargs[field_name]
if conversion is not None:
value = self.__formatter.convert_field(value, conversion)
if format_spec is not None:
value = self.__formatter.format_field(value, format_spec)
prompt.append(value)
return "".join(prompt)
def __add__(self, other):
"""