diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py index 410e315..deb1ac3 100644 --- a/libs/ktem/ktem/index/file/ui.py +++ b/libs/ktem/ktem/index/file/ui.py @@ -1059,15 +1059,18 @@ class FileIndexPage(BasePage): """Handle zip files""" zip_files = [file for file in files if file.endswith(".zip")] remaining_files = [file for file in files if not file.endswith("zip")] + errors: list[str] = [] # Clean-up before unzip to remove old files shutil.rmtree(zip_dir, ignore_errors=True) + # Unzip for zip_file in zip_files: # Prepare new zip output dir, separated for each files basename = os.path.splitext(os.path.basename(zip_file))[0] zip_out_dir = os.path.join(zip_dir, basename) os.makedirs(zip_out_dir, exist_ok=True) + with zipfile.ZipFile(zip_file, "r") as zip_ref: zip_ref.extractall(zip_out_dir) @@ -1084,7 +1087,7 @@ class FileIndexPage(BasePage): if n_zip_file > 0: print(f"Update zip files: {n_zip_file}") - return remaining_files + return remaining_files, errors def index_fn( self, files, urls, reindex: bool, settings, user_id @@ -1100,20 +1103,22 @@ class FileIndexPage(BasePage): """ if urls: files = [it.strip() for it in urls.split("\n")] - errors = [] + errors = self.validate_urls(files) else: if not files: gr.Info("No uploaded file") yield "", "" return + files, unzip_errors = self._may_extract_zip( + files, flowsettings.KH_ZIP_INPUT_DIR + ) + errors = self.validate_files(files) + errors.extend(unzip_errors) - files = self._may_extract_zip(files, flowsettings.KH_ZIP_INPUT_DIR) - - errors = self.validate(files) - if errors: - gr.Warning(", ".join(errors)) - yield "", "" - return + if errors: + gr.Warning(", ".join(errors)) + yield "", "" + return gr.Info(f"Start indexing {len(files)} files...") @@ -1569,7 +1574,7 @@ class FileIndexPage(BasePage): selected_item["files"], ) - def validate(self, files: list[str]): + def validate_files(self, files: list[str]): """Validate if the files are valid""" paths = [Path(file) for file in files] errors = 
def validate_urls(self, urls: list[str]) -> list[str]:
    """Check that every entry in ``urls`` is an http(s) URL.

    An entry is accepted only when it begins with ``http://`` or
    ``https://``; the comparison is case-sensitive and no further parsing
    of the URL is attempted.

    Args:
        urls: Candidate URL strings, one URL per entry.

    Returns:
        One "Invalid url" message per rejected entry, in input order.
        The list is empty when every entry is accepted.
    """
    # The scheme separator is part of the prefix on purpose: a bare "http"
    # prefix also matches strings such as "httpfoo", and it makes a separate
    # "https" test redundant because every "https" string starts with "http".
    allowed_prefixes = ("http://", "https://")
    return [
        f"Invalid url `{url}`"
        for url in urls
        if not url.startswith(allowed_prefixes)
    ]