Source code for solidipes.scanners.scanner

import os
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Optional, Protocol, Set, Union

from tqdm import tqdm

from ..loaders.file import File, load_file
from ..loaders.group import Group, load_groups
from ..utils import default_ignore_patterns, get_ignore, logging, solidipes_dirname

################################################################


# Rebind the name ``print`` for this module to the project's ``logging.invalidPrint``
# helper (presumably to flag stray print calls in favour of the logger -- confirm in
# the project's logging utilities).
print = logging.invalidPrint
# Logger used throughout this module, obtained from the project's logging utilities.
logger = logging.getLogger()

################################################################


class DictTree(dict):
    """A nested dictionary that counts the number of leaves under each node.

    Nested ``dict`` values are converted to ``DictTree`` instances when the
    tree is constructed; every other value is a leaf. The number of leaves is
    stored in the ``count`` attribute. ``count`` is computed at construction
    time and for the trees returned by ``filter`` and ``apply``; it is not
    refreshed automatically when items are assigned afterwards.
    """

    def __init__(self, *args, **kwargs):
        """Build the tree from the same arguments as ``dict``."""
        super().__init__(*args, **kwargs)

        # Convert all nested dictionaries to DictTree (only values of existing
        # keys are replaced, so mutating while iterating is safe here).
        for key, value in self.items():
            if isinstance(value, dict):
                self[key] = DictTree(value)

        self.count = self.count_leaves()

    def count_leaves(self) -> int:
        """Count the number of leaves in the tree.

        Sub-trees contribute their stored ``count``; any other value counts
        as one leaf.
        """
        return sum(value.count if isinstance(value, DictTree) else 1 for value in self.values())

    def flatten(
        self,
        value_func: Callable = lambda value: value,
        keys_join_func: Callable[[list[str]], str] = lambda keys: os.path.join(*keys),
        add_dicts: bool = False,
        dict_func: Callable = lambda _: None,
        keys: Optional[list[str]] = None,
    ) -> dict:
        """Flatten the tree into a flat dictionary.

        Args:
            value_func: Applied to every leaf to produce the flattened value.
            keys_join_func: Joins the list of nested keys into one flat key.
            add_dicts: When true, also add one entry per (sub-)tree; the root
                tree is stored under the key ``"."``.
            dict_func: Applied to every (sub-)tree when ``add_dicts`` is true.
            keys: Keys leading to this node (used by the recursion).

        Returns:
            A flat dictionary; entries are visited in sorted key order.
        """
        keys = [] if keys is None else keys
        flattened = {}

        if add_dicts:
            joined_keys = keys_join_func(keys) if keys else "."
            flattened[joined_keys] = dict_func(self)

        for key, value in sorted(self.items()):
            new_keys = keys + [key]

            if isinstance(value, DictTree):
                flattened.update(
                    value.flatten(
                        value_func=value_func,
                        keys_join_func=keys_join_func,
                        add_dicts=add_dicts,
                        dict_func=dict_func,
                        keys=new_keys,
                    )
                )
            else:
                flattened[keys_join_func(new_keys)] = value_func(value)

        return flattened

    def filter(
        self,
        value_filter: Callable = lambda _: True,
        keys_join_func: Callable[[list[str]], str] = lambda keys: os.path.join(*keys),
        joined_keys_filter: Callable = lambda _: True,
        keep_empty_dicts: bool = False,
        keys: Optional[list[str]] = None,
    ) -> "DictTree":
        """Filter the tree based on the values and keys. Both filters must be satisfied.

        Args:
            value_filter: Predicate applied to each leaf value.
            keys_join_func: Joins the list of nested keys into one flat key.
            joined_keys_filter: Predicate applied to each leaf's joined key.
            keep_empty_dicts: Keep sub-trees that end up without any entry.
            keys: Keys leading to this node (used by the recursion).

        Returns:
            A new ``DictTree`` with an up-to-date ``count``.
        """
        keys = [] if keys is None else keys
        filtered = DictTree()

        for key, value in sorted(self.items()):
            new_keys = keys + [key]

            if isinstance(value, DictTree):
                sub_tree = value.filter(
                    value_filter=value_filter,
                    keys_join_func=keys_join_func,
                    joined_keys_filter=joined_keys_filter,
                    keep_empty_dicts=keep_empty_dicts,
                    keys=new_keys,
                )
                if sub_tree or keep_empty_dicts:
                    filtered[key] = sub_tree

            elif value_filter(value) and joined_keys_filter(keys_join_func(new_keys)):
                filtered[key] = value

        # Items were inserted after construction, so the count computed by
        # __init__ (0 for an empty tree) must be refreshed.
        filtered.count = filtered.count_leaves()
        return filtered

    def apply(
        self,
        func: Callable,
    ) -> "DictTree":
        """Apply a function to all values in the tree.

        Returns:
            A new ``DictTree`` of the same shape with an up-to-date ``count``.
        """
        applied = DictTree()

        for key, value in self.items():
            if isinstance(value, DictTree):
                applied[key] = value.apply(func)
            else:
                applied[key] = func(value)

        # Same reason as in filter(): the tree was filled after construction.
        applied.count = applied.count_leaves()
        return applied

    def reduce(
        self,
        func: Callable,
        initial: Any,
    ) -> Any:
        """Reduce the tree to a single value.

        Args:
            func: Called as ``func(accumulator, leaf_value)`` for every leaf.
            initial: Starting value of the accumulator.
        """
        acc = initial

        for value in self.values():
            if isinstance(value, DictTree):
                acc = value.reduce(func, acc)
            else:
                acc = func(acc, value)

        return acc
# A loader is either a single-file loader or a group loader.
Loader = Union[File, Group]

# Tree whose leaves are file paths: dict[str, "FilepathTree | str"]
FilepathTree = DictTree

# Tree whose leaves are loaders: dict[str, "LoaderTree | Loader"]
LoaderTree = DictTree
class ProgressBar(Protocol):
    """Structural interface expected from the progress bars used in this module."""

    # Value the bar reaches when complete.
    total: float

    def update(self, n: Optional[float]) -> Optional[bool]:
        """Advance the bar by ``n``."""

    def set_postfix_str(self, desc: str) -> None:
        """Set the text displayed after the bar."""

    def reset(self) -> None:
        """Restart the bar."""

    def close(self) -> None:
        """Dispose of the bar."""
class StreamlitProgressBar(ProgressBar):
    """Progress bar displayed through a Streamlit ``progress`` element."""

    def __init__(self, text, container=None):
        """Create the bar.

        Args:
            text: Label displayed with the bar.
            container: Object providing a ``progress`` method in which the bar
                is created; defaults to the ``streamlit`` module itself.
        """
        import streamlit as st

        if container is None:
            container = st

        self.st_bar = container.progress(0, text=text)
        self.text = text
        self.total = 0
        self.current = 0
        self.postfix = ""

    def update(self, value):
        """Advance the bar by ``value`` and refresh the display."""
        self.current += value
        self._update()

    def set_postfix_str(self, desc: str):
        """Set the text shown after the counter and refresh the display."""
        self.postfix = desc
        self._update()

    def _update(self):
        """Redraw the bar with the current counter and postfix."""
        text = f"{self.text} ({self.current}/{self.total}) {self.postfix}"

        if self.total > 0:
            # Streamlit expects an integer percentage in [0, 100]; clamp in
            # case the counter overshoots the announced total.
            percent = max(0, min(100, int(100 * self.current // self.total)))
        else:
            # No total announced yet (or nothing to process): avoid dividing
            # by zero and show an empty bar.
            percent = 0

        self.st_bar.progress(percent, text=text)

    def reset(self):
        """Reset the counter to zero (the display is not refreshed)."""
        self.current = 0

    def close(self):
        """Remove the bar from the page."""
        self.st_bar.empty()
def cached_scan(func: Callable) -> Callable:
    """Decorator to cache the result of the scan.

    Adds a "force_rescan" parameter to the decorated function.
    Assumes that the result of the scan only depends on root_path and excluded_patterns.

    The cache holds a single entry per decorated function, keyed on the
    instance, its ``root_path``, its ``excluded_patterns`` and any extra
    arguments. Passing ``force_rescan=True`` clears that entry before calling.
    """
    # Imported locally because only lru_cache is imported at module level.
    from functools import wraps

    @lru_cache(maxsize=1)
    def cached_func(self, root_path: str, excluded_patterns: frozenset[str], *args, **kwargs):
        # root_path and excluded_patterns only serve as cache keys.
        logger.debug(f"Scanning with {func.__name__}")
        return func(self, *args, **kwargs)

    @wraps(func)  # keep the decorated method's name and docstring
    def wrapper(self, *args, force_rescan: bool = False, **kwargs):
        if force_rescan:
            cached_func.cache_clear()

        # frozenset makes the patterns hashable so they can be part of the key.
        return cached_func(self, self.root_path, frozenset(self.excluded_patterns), *args, **kwargs)

    return wrapper
class Scanner:
    """A class to scan a directory to load files and groups.

    All paths are given relative to the scanner's root path.
    """

    def __init__(self, root_path: str = "."):
        """Create a scanner for ``root_path`` and read the exclusion patterns."""
        self.root_path = root_path

        try:
            # Get ignored patterns from .solidipes
            self.excluded_patterns = get_ignore()
        except FileNotFoundError:
            self.excluded_patterns = default_ignore_patterns.copy()

        # Optional externally supplied progress bar used by get_loader_tree.
        self.progress_bar: Optional[ProgressBar] = None

    @cached_scan
    def get_filepath_tree(self) -> FilepathTree:
        """Get a tree of all filepaths, organized by directory.

        Keys are directory and file names; leaves are file paths relative to
        the root path. Excluded directories are not descended into.
        """
        tree = {}

        for root, dirs, files in os.walk(self.root_path):
            dirpath = os.path.relpath(root, self.root_path)

            if self.is_excluded(dirpath):
                logger.debug(f"Exclude {dirpath}")
                dirs.clear()  # prevents os.walk from visiting sub-directories
                continue

            # Create the directory structure in the tree
            current_tree = tree
            for dirname in dirpath.split(os.sep):
                if dirname == ".":
                    continue
                if dirname not in current_tree:
                    current_tree[dirname] = {}
                current_tree = current_tree[dirname]

            # Add filepaths to the tree
            for file in files:
                filepath = os.path.relpath(os.path.join(dirpath, file))
                if self.is_excluded(filepath):
                    logger.debug(f"Exclude {filepath}")
                    continue
                current_tree[file] = filepath

        return DictTree(tree)

    @cached_scan
    def get_dirpath_tree(self) -> FilepathTree:
        """Get a tree of all directory paths."""
        return self.get_filepath_tree().filter(
            value_filter=lambda _: False,
            keep_empty_dicts=True,
        )

    @cached_scan
    def get_path_list(self) -> list[str]:
        """Get a list of all paths (files and directories)."""
        flattened = self.get_filepath_tree().flatten(
            value_func=lambda _: None,
            add_dicts=True,
        )
        return list(flattened.keys())

    @cached_scan
    def get_filepath_list(self) -> list[str]:
        """Get a list of all file paths."""
        flattened = self.get_filepath_tree().flatten(
            value_func=lambda _: None,
        )
        return list(flattened.keys())

    @cached_scan
    def get_loader_tree(
        self,
    ) -> LoaderTree:
        """Get a tree of loaders, with groups, organized by directory.

        Uses ``self.progress_bar`` when one is set, otherwise a temporary
        ``tqdm`` bar. The bar is closed even when loading raises.
        """
        using_self_progress_bar = self.progress_bar is None
        if using_self_progress_bar:
            self.progress_bar = tqdm(desc="Loading files")
        progress_bar = self.progress_bar

        try:
            tree = self.get_filepath_tree()
            progress_bar.total = tree.count
            progress_bar.reset()

            return DictTree(
                convert_filepath_tree_to_loader_tree(
                    tree=tree,
                    root_path=self.root_path,
                    progress_bar=progress_bar,
                )
            )
        finally:
            progress_bar.close()
            if using_self_progress_bar:
                # Do not keep the temporary bar around for later scans.
                self.progress_bar = None

    @staticmethod
    def _is_under_directory(path: str, directory: str) -> bool:
        """Tell whether ``path`` is ``directory`` itself or located below it."""
        if directory == "":
            # An empty directory name selects everything.
            return True
        prefix = directory if directory.endswith(os.sep) else directory + os.sep
        # Compare whole path components so that "data" does not match "data2/x".
        return path == directory or path.startswith(prefix)

    def get_filtered_loader_tree(
        self,
        dirs: Optional[list[str]] = None,
        recursive: bool = True,
    ) -> LoaderTree:
        """Get a tree of loaders for the given directories.

        Args:
            dirs: Directories (relative to the root path) to keep. With no
                directory, the returned tree is empty.
            recursive: When true, keep everything below the given directories;
                otherwise keep only their direct children.
        """
        selected = [] if dirs is None else dirs

        if recursive:

            def path_filter(path: str):
                return any(self._is_under_directory(path, d) for d in selected)

        else:

            def path_filter(path: str):
                return any(os.path.dirname(path) == d for d in selected)

        return self.get_loader_tree().filter(
            joined_keys_filter=path_filter,
        )

    @cached_scan
    def get_loader_dict(
        self,
    ) -> dict[str, Loader]:
        """Get a dictionary mapping paths (potentially grouped) to loaders."""
        return self.get_loader_tree().flatten()

    def get_filtered_loader_dict(
        self,
        dirs: Optional[list[str]] = None,
        recursive: bool = True,
    ) -> dict[str, Loader]:
        """Get a dictionary mapping paths (potentially grouped) to loaders,
        restricted to the given directories."""
        return self.get_filtered_loader_tree(dirs, recursive=recursive).flatten()

    @cached_scan
    def get_loader_path_list(
        self,
    ) -> list[str]:
        """Get a list of all loaded paths (potentially grouped)."""
        return list(self.get_loader_dict().keys())

    def scan(self):
        """Trigger the creation of loaders."""
        self.get_loader_tree()

    def is_excluded(self, path: str, excluded_patterns: Optional[Set[str]] = None) -> bool:
        """Check whether the provided path is excluded by any of the scanner's patterns.

        Args:
            path: Path relative to the scanner's root path.
            excluded_patterns: Patterns to test; defaults to the scanner's own.
        """
        if excluded_patterns is None:
            excluded_patterns = self.excluded_patterns

        p = Path(path)

        for pattern in excluded_patterns:
            if pattern == ".":
                return True

            if not p.match(pattern):
                continue

            # If the pattern ends with a trailing slash, it only applies to
            # directories. The path is relative to the root path, so resolve
            # it from there rather than from the current working directory.
            if pattern.endswith("/") and not (Path(self.root_path) / p).is_dir():
                continue

            return True

        return False

    @cached_scan
    def get_modified_time(
        self,
    ) -> float:
        """Get the most recent modified time of all files (0 when there is no file)."""
        root_path = self.root_path
        return self.get_filepath_tree().reduce(
            # Tree leaves are relative to the root path.
            func=lambda acc, value: max(acc, os.path.getmtime(os.path.join(root_path, value))),
            initial=0,
        )

    @cached_scan
    def get_total_size(
        self,
    ) -> int:
        """Get the total size of all files."""
        root_path = self.root_path
        return self.get_filepath_tree().reduce(
            # Tree leaves are relative to the root path.
            func=lambda acc, value: acc + os.path.getsize(os.path.join(root_path, value)),
            initial=0,
        )
class ExportScanner(Scanner):
    """A scanner that keeps the .solidipes directory. Individual paths inside .solidipes can still be excluded."""

    def __init__(self, root_path: str = "."):
        """Create the scanner and stop excluding the .solidipes directory itself."""
        super().__init__(root_path)

        if solidipes_dirname in self.excluded_patterns:
            self.excluded_patterns.remove(solidipes_dirname)

    def is_excluded(self, path: str, excluded_patterns: Optional[Set[str]] = None) -> bool:
        """Check whether the provided path is excluded by any of the scanner's patterns"""
        patterns = self.excluded_patterns if excluded_patterns is None else excluded_patterns

        # Paths unrelated to the .solidipes directory follow the regular rules.
        if solidipes_dirname not in path:
            return super().is_excluded(path, patterns)

        # Inside .solidipes, only patterns that explicitly mention it (other
        # than the bare directory name) apply. Typically this drops `.*`.
        solidipes_patterns = {
            pattern for pattern in patterns if solidipes_dirname in pattern and pattern != solidipes_dirname
        }
        return super().is_excluded(path, solidipes_patterns)
def convert_filepath_tree_to_loader_tree(
    tree: FilepathTree,
    root_path: str,
    progress_bar: Optional[ProgressBar] = None,
) -> LoaderTree:
    """Convert a tree of filepaths to a tree of loaders, while detecting file groups.

    Args:
        tree: Filepath tree of the directory located at ``root_path``.
        root_path: Path of the directory that ``tree`` describes.
        progress_bar: Optional bar advanced by one step per processed file.

    Returns:
        A dictionary mapping group names, file names and directory names to
        groups, file loaders and nested dictionaries respectively.
    """
    loaders = {}

    if progress_bar is not None:
        progress_bar.set_postfix_str(root_path)

    # Load groups
    is_dir_path_dict = {key: isinstance(value, dict) for key, value in tree.items()}
    loaded_groups, remaining_is_dir_path_dict = load_groups(is_dir_path_dict, root_path)
    loaders.update(loaded_groups)

    # Update progressbar for groups
    if progress_bar is not None:
        processed = set(tree.keys()) - set(remaining_is_dir_path_dict.keys())
        for key in processed:
            if isinstance(tree[key], DictTree):
                progress_bar.update(tree[key].count)
            else:
                progress_bar.update(1)

    # Only entries that were not consumed by a group still have to be loaded.
    # Iterating the original dict keeps the tree's order and guarantees that
    # every name is a key of `tree`, whether or not load_groups modified its
    # argument in place.
    remaining = [
        (name, is_dir) for name, is_dir in is_dir_path_dict.items() if name in remaining_is_dir_path_dict
    ]

    # Load files
    for name, is_dir in remaining:
        if is_dir:
            continue

        filepath = os.path.join(root_path, name)
        if progress_bar is not None:
            progress_bar.set_postfix_str(filepath)

        loaders[name] = load_file(filepath)

        if progress_bar is not None:
            progress_bar.update(1)

    # Load subdirectories
    for dirname, is_dir in remaining:
        if not is_dir:
            continue

        subdir_tree: FilepathTree = tree[dirname]  # type: ignore
        loaders[dirname] = convert_filepath_tree_to_loader_tree(
            tree=subdir_tree,
            root_path=os.path.join(root_path, dirname),
            progress_bar=progress_bar,
        )

    return loaders
def list_files(found, current_dir=""):
    """List every entry of a nested dictionary as ``(path, value)`` pairs.

    Each path is the entry's key joined to ``current_dir``. An entry whose
    value is a dictionary is listed first, immediately followed by the
    entries it contains.
    """

    def walk(mapping, parent):
        for name, content in mapping.items():
            path = os.path.join(parent, name)
            yield path, content
            if isinstance(content, dict):
                yield from walk(content, path)

    return list(walk(found, current_dir))