kotaemon/knowledgehub/post_processing/extractor.py
ian_Cin 84f1fa8cbd [AUR-395] Adopt Example1 disclaimer pipeline (#42)
* Adopt Example1 disclaimer pipeline
* Update Document class
* Add composite components
* Modify Extractor behaviours
2023-10-10 15:42:48 +07:00

209 lines
6.2 KiB
Python

import re
from typing import Callable, Dict, List, Union
from theflow import Param
from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document
class ExtractorOutput(Document):
"""
Represents the output of an extractor.
"""
matches: List[str]
class RegexExtractor(BaseComponent):
"""
Simple class for extracting text from a document using a regex pattern.
Args:
pattern (List[str]): The regex pattern(s) to use.
output_map (dict, optional): A mapping from extracted text to the
desired output. Defaults to None.
"""
class Config:
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
pattern: List[str]
output_map: Union[Dict[str, str], Callable[[str], str]] = Param(
default_callback=lambda *_: {}
)
def __init__(self, pattern: Union[str, List[str]], **kwargs):
if isinstance(pattern, str):
pattern = [pattern]
super().__init__(pattern=pattern, **kwargs)
@staticmethod
def run_raw_static(pattern: str, text: str) -> List[str]:
"""
Finds all non-overlapping occurrences of a pattern in a string.
Parameters:
pattern (str): The regular expression pattern to search for.
text (str): The input string to search in.
Returns:
List[str]: A list of all non-overlapping occurrences of the pattern in the
string.
"""
return re.findall(pattern, text)
@staticmethod
def map_output(text, output_map) -> str:
"""
Maps the given `text` to its corresponding value in the `output_map` dictionary.
Parameters:
text (str): The input text to be mapped.
output_map (dict): A dictionary containing mapping of input text to output
values.
Returns:
str: The corresponding value from the `output_map` if `text` is found in the
dictionary, otherwise returns the original `text`.
"""
if not output_map:
return text
if isinstance(output_map, dict):
return output_map.get(text, text)
return output_map(text)
def run_raw(self, text: str) -> ExtractorOutput:
"""
Matches the raw text against the pattern and rans the output mapping, returning
an instance of ExtractorOutput.
Args:
text (str): The raw text to be processed.
Returns:
ExtractorOutput: The processed output as a list of ExtractorOutput.
"""
output = sum(
[self.run_raw_static(p, text) for p in self.pattern], []
) # type: List[str]
output = [self.map_output(text, self.output_map) for text in output]
return ExtractorOutput(
text=output[0] if output else "",
matches=output,
metadata={"origin": "RegexExtractor"},
)
def run_batch_raw(self, text_batch: List[str]) -> List[ExtractorOutput]:
"""
Runs a batch of raw text inputs through the `run_raw()` method and returns the
output for each input.
Parameters:
text_batch (List[str]): A list of raw text inputs to process.
Returns:
List[ExtractorOutput]: A list containing the output for each input in the
batch.
"""
batch_output = [self.run_raw(each_text) for each_text in text_batch]
return batch_output
def run_document(self, document: Document) -> ExtractorOutput:
"""
Run the document through the regex extractor and return an extracted document.
Args:
document (Document): The input document.
Returns:
ExtractorOutput: The extracted content.
"""
return self.run_raw(document.text)
def run_batch_document(
self, document_batch: List[Document]
) -> List[ExtractorOutput]:
"""
Runs a batch of documents through the `run_document` function and returns the
output for each document.
Parameters:
document_batch (List[Document]): A list of Document objects representing the
batch of documents to process.
Returns:
List[ExtractorOutput]: A list contains the output ExtractorOutput for each
input Document in the batch.
Example:
document1 = Document(...)
document2 = Document(...)
document_batch = [document1, document2]
batch_output = self.run_batch_document(document_batch)
# batch_output will be [output1_document1, output1_document2]
"""
batch_output = [
self.run_document(each_document) for each_document in document_batch
]
return batch_output
def is_document(self, text) -> bool:
"""
Check if the given text is an instance of the Document class.
Args:
text: The text to check.
Returns:
bool: True if the text is an instance of Document, False otherwise.
"""
if isinstance(text, Document):
return True
return False
def is_batch(self, text) -> bool:
"""
Check if the given text is a batch of documents.
Parameters:
text (List): The text to be checked.
Returns:
bool: True if the text is a batch of documents, False otherwise.
"""
if not isinstance(text, List):
return False
if len(set(self.is_document(each_text) for each_text in text)) <= 1:
return True
return False
class FirstMatchRegexExtractor(RegexExtractor):
pattern: List[str]
def run_raw(self, text: str) -> ExtractorOutput:
for p in self.pattern:
output = self.run_raw_static(p, text)
if output:
output = [self.map_output(text, self.output_map) for text in output]
return ExtractorOutput(
text=output[0],
matches=output,
metadata={"origin": "FirstMatchRegexExtractor"},
)
return ExtractorOutput(
text=None, matches=[], metadata={"origin": "FirstMatchRegexExtractor"}
)