kotaemon/knowledgehub/post_processing/extractor.py
Nguyen Trung Duc (john) d79b3744cb Simplify the BaseComponent interface (#64)
This change removes the following from `BaseComponent`:

- run_raw
- run_batch_raw
- run_document
- run_batch_document
- is_document
- is_batch

Each component is expected to support multiple types of inputs and a single type of output. Since we want the component to work out-of-the-box with both standardized and customized use cases, supporting multiple types of inputs is expected. At the same time, to reduce the complexity of understanding how to use a component, we restrict a component to only have a single output type.

To accommodate these changes, we also refactor some components to remove their run_raw, run_batch_raw... methods, and to decide the common output interface for those components.

Tests are updated accordingly.

Commit changes:

* Add kwargs to vector store's query
* Simplify the BaseComponent
* Update tests
* Remove support for Python 3.8 and 3.9
* Bump version 0.3.0
* Fix GitHub PR caching still using the old environment after bumping version

---------

Co-authored-by: ian <ian@cinnamon.is>
2023-11-13 15:10:18 +07:00

159 lines
4.8 KiB
Python

from __future__ import annotations
import re
from typing import Callable
from theflow import Param
from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document
class ExtractorOutput(Document):
    """Document returned by an extractor, carrying the extracted matches."""

    # Extracted strings reported by the extractor that built this output.
    matches: list[str]
class RegexExtractor(BaseComponent):
    """Extract text from strings or documents using regex pattern(s).

    Every pattern is applied to the input with `re.findall`; the matches of all
    patterns are concatenated in pattern order, and each match is then passed
    through `output_map`.

    Args:
        pattern (str | list[str]): The regex pattern(s) to use. A single string
            is wrapped into a one-element list.
        output_map (dict | Callable, optional): Either a mapping from extracted
            text to the desired output, or a callable applied to each extracted
            text. The default callback produces an empty dict, in which case
            matches are returned unchanged.
    """

    class Config:
        # Switches `theflow.middleware.CachingMiddleware` off for this component.
        middleware_switches = {"theflow.middleware.CachingMiddleware": False}

    pattern: list[str]
    output_map: dict[str, str] | Callable[[str], str] = Param(
        default_callback=lambda *_: {}
    )

    def __init__(self, pattern: str | list[str], **kwargs):
        # Normalize a single pattern so the rest of the class can always iterate.
        if isinstance(pattern, str):
            pattern = [pattern]
        super().__init__(pattern=pattern, **kwargs)

    @staticmethod
    def run_raw_static(pattern: str, text: str) -> list[str]:
        """Find all non-overlapping occurrences of a pattern in a string.

        Parameters:
            pattern (str): The regular expression pattern to search for.
            text (str): The input string to search in.

        Returns:
            list[str]: All non-overlapping occurrences of the pattern in the
                string, as returned by `re.findall`. Note that `re.findall`
                returns the captured group(s) instead of the whole match when
                the pattern contains capture groups.
        """
        return re.findall(pattern, text)

    @staticmethod
    def map_output(text, output_map) -> str:
        """Map `text` to its corresponding value according to `output_map`.

        Parameters:
            text (str): The input text to be mapped.
            output_map (dict | Callable): A dictionary mapping input text to
                output values, or a callable that computes the output value.

        Returns:
            str: `text` itself when `output_map` is falsy; the value stored under
                `text` (or `text` if the key is absent) when `output_map` is a
                dict; otherwise the result of calling `output_map(text)`.
        """
        if not output_map:
            return text

        if isinstance(output_map, dict):
            return output_map.get(text, text)

        return output_map(text)

    def run_raw(self, text: str) -> ExtractorOutput:
        """Match the raw text against every pattern and run the output mapping.

        Args:
            text (str): The raw text to be processed.

        Returns:
            ExtractorOutput: A single output whose `matches` holds the mapped
                matches of all patterns (in pattern order) and whose `text` is
                the first mapped match, or an empty string when nothing matched.
        """
        # Flatten with a comprehension; `sum(lists, [])` re-copies the growing
        # accumulator for every pattern.
        found: list[str] = [
            match for p in self.pattern for match in self.run_raw_static(p, text)
        ]
        mapped = [self.map_output(match, self.output_map) for match in found]

        return ExtractorOutput(
            text=mapped[0] if mapped else "",
            matches=mapped,
            metadata={"origin": "RegexExtractor"},
        )

    def run(
        self, text: str | list[str] | Document | list[Document]
    ) -> list[ExtractorOutput]:
        """Match the input against the pattern(s) and return one output per input.

        Parameters:
            text: The input to be processed: a string, a Document, or a list of
                either. A non-list input is treated as a one-element list.

        Returns:
            A list containing one ExtractorOutput for each input, in input order.

        Raises:
            ValueError: If an input item is neither a str nor a Document.

        Example:
            document1 = Document(...)
            document2 = Document(...)
            document_batch = [document1, document2]
            batch_output = self(document_batch)
            # batch_output will be [output1_document1, output1_document2]
        """
        # TODO: this conversion seems common
        items = text if isinstance(text, list) else [text]

        input_: list[str] = []
        for item in items:
            if isinstance(item, str):
                input_.append(item)
            elif isinstance(item, Document):
                input_.append(item.text)
            else:
                raise ValueError(
                    f"Invalid input type {type(item)}, should be str or Document"
                )

        return [self.run_raw(each_input) for each_input in input_]
class FirstMatchRegexExtractor(RegexExtractor):
    """Regex extractor that reports only the first pattern yielding matches.

    Patterns are tried in order; the matches of the first pattern that finds
    anything are mapped through `output_map` and returned, and the remaining
    patterns are not evaluated.
    """

    pattern: list[str]

    def run_raw(self, text: str) -> ExtractorOutput:
        """Return the mapped matches of the first pattern that matches `text`.

        Args:
            text (str): The raw text to be processed.

        Returns:
            ExtractorOutput: Output built from the first pattern with at least
                one match; when no pattern matches, an output with `text=None`
                and an empty `matches` list.
        """
        for candidate in self.pattern:
            hits = self.run_raw_static(candidate, text)
            if not hits:
                # Nothing found for this pattern; fall through to the next one.
                continue

            mapped = [self.map_output(hit, self.output_map) for hit in hits]
            return ExtractorOutput(
                text=mapped[0],
                matches=mapped,
                metadata={"origin": "FirstMatchRegexExtractor"},
            )

        # No pattern matched at all.
        return ExtractorOutput(
            text=None, matches=[], metadata={"origin": "FirstMatchRegexExtractor"}
        )