"""solidipes.utils.progress: spinners and progress bars for the CLI/Jupyter (rich) and Streamlit backends."""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

from datasize import DataSize
from rich.live import Live
from rich.progress import BarColumn, FileSizeColumn
from rich.progress import Progress as RichProgress
from rich.progress import TextColumn, TimeRemainingColumn
from rich.spinner import Spinner as RichSpinner

if TYPE_CHECKING:
    from streamlit.delta_generator import DeltaGenerator


# Default Streamlit container used by Streamlit spinners/progress bars when no
# explicit layout is passed. ``None`` means no container has been chosen yet.
_streamlit_layout: Optional["DeltaGenerator"] = None


def set_streamlit_layout(layout: "DeltaGenerator"):
    """Register ``layout`` as the default Streamlit container for progress displays.

    Args:
        layout: Streamlit container that later spinners/progress bars draw into.
    """
    global _streamlit_layout
    _streamlit_layout = layout


def get_streamlit_layout() -> "DeltaGenerator":
    """Return the default Streamlit container for progress displays.

    When nothing was registered via ``set_streamlit_layout``, the ``streamlit``
    module itself is imported, cached as the default, and returned.
    """
    global _streamlit_layout
    if _streamlit_layout is not None:
        return _streamlit_layout

    import streamlit as st

    _streamlit_layout = st
    return _streamlit_layout
class ProgressDisplay(ABC):
    """Common base for spinners and progress bars.

    The lifecycle hooks :meth:`reset`, :meth:`start` and :meth:`close` do
    nothing here and are meant to be overridden. When used in a ``with``
    statement the display is reset and started on entry and closed on exit.
    """

    def __init__(self, description: str):
        """Store the label shown by the display.

        Args:
            description: Human-readable label for the task being tracked.
        """
        self.description = description

    def __enter__(self) -> "ProgressDisplay":
        # Start every context from a clean state before opening the display.
        self.reset()
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        # Always close; returning None lets any exception propagate.
        self.close()

    def reset(self) -> None:
        """Bring the display back to its initial state (no-op by default)."""

    def start(self) -> None:
        """Begin showing the display (no-op by default)."""

    def close(self) -> None:
        """Stop showing the display (no-op by default)."""
class Spinner(ProgressDisplay):
    """Base class for indeterminate progress displays (spinners).

    Adds no behavior of its own on top of ``ProgressDisplay``; concrete
    spinners subclass it. Usable as a context manager.
    """
class TextSpinner(Spinner):
    """Spinner rendered with ``rich``, for the CLI and Jupyter notebooks."""

    def __init__(self, description: str):
        """Spinner using the rich library, for CLI and Jupyter notebooks.

        Args:
            description: Text displayed next to the spinner.
        """
        super().__init__(description)
        self.spinner = RichSpinner("dots", text=description)
        # Live display driving the spinner; only non-None between start() and close().
        self.live = None

    def start(self) -> None:
        """Start rendering the spinner in a ``rich`` live display."""
        # Stop any display left over from a previous start() instead of leaking it.
        self.close()
        self.live = Live(self.spinner)
        self.live.__enter__()

    def close(self) -> None:
        """Stop the live display if one is running; safe to call repeatedly."""
        if self.live is not None:
            live, self.live = self.live, None
            live.stop()

    def __exit__(self, type, value, traceback):
        # Route context-manager exit through close() so manual start()/close()
        # and ``with`` usage share the same teardown path.
        self.close()
class StreamlitSpinner(Spinner):
    """Spinner rendered with Streamlit's ``st.spinner``."""

    def __init__(self, description: str, layout: Optional["DeltaGenerator"] = None):
        """Streamlit spinner.

        Args:
            description: Text displayed next to the spinner.
            layout: Container to draw into; defaults to ``get_streamlit_layout()``.
        """
        super().__init__(description)
        if layout is None:
            layout = get_streamlit_layout()
        self.layout = layout
        # Entered ``st.spinner`` context; only non-None between start() and close().
        self.spinner = None

    def start(self) -> None:
        """Enter an ``st.spinner`` context inside ``self.layout``."""
        from contextlib import nullcontext

        import streamlit as st

        # The default layout can be the ``streamlit`` module itself, which does
        # not support ``with``; only enter layouts that are context managers.
        container = self.layout if hasattr(self.layout, "__enter__") else nullcontext()
        with container:
            self.spinner = st.spinner(text=self.description)
            self.spinner.__enter__()

    def close(self) -> None:
        """Exit the spinner context if it is active; safe to call repeatedly."""
        self._stop(None, None, None)

    def __exit__(self, type, value, traceback):
        # Forward exception info so the spinner context sees how the block ended.
        self._stop(type, value, traceback)

    def _stop(self, exc_type, exc_value, exc_traceback) -> None:
        """Exit the active spinner context with the given exception info (no-op if inactive)."""
        if self.spinner is None:
            return
        spinner, self.spinner = self.spinner, None
        spinner.__exit__(exc_type, exc_value, exc_traceback)
class ProgressBar(ProgressDisplay):
    """Base class for determinate progress bars; subclasses implement :meth:`update`."""

    def __init__(self, description: str, total: float = 100, show_datasize: bool = False):
        """Generic progress bar interface. Can be used as a context manager.

        Args:
            description: Label shown with the bar.
            total: Amount of work corresponding to a full bar.
            show_datasize: Whether subclasses should present progress as data sizes.
        """
        super().__init__(description)
        self.show_datasize = show_datasize
        self.total = total

    @abstractmethod
    def update(self, advance: float = 0, text: Optional[str] = None) -> None:
        """Advance the bar by ``advance`` and optionally show extra ``text``."""
class TextProgressBar(ProgressBar):
    """Progress bar rendered with ``rich``, for the CLI and Jupyter notebooks."""

    def __init__(self, *args, **kwargs):
        """Progress bar using the rich library, for CLI and Jupyter notebooks.

        Accepts the same arguments as ``ProgressBar``.
        """
        super().__init__(*args, **kwargs)
        # rich markup tags prefixed to each part of the description line.
        self.description_style = "[bold blue]"
        self.progress_style = "[blue]"
        self.text_style = "[white]"
        columns = [
            BarColumn(),
            TimeRemainingColumn(),
            TextColumn("{task.description}"),
        ]
        if self.show_datasize:
            # Show the transferred size right after the bar.
            columns.insert(1, FileSizeColumn())
        self.bar = RichProgress(*columns)
        self.task = self.bar.add_task(self.description, total=self.total)

    def reset(self) -> None:
        """Reset this bar's task back to zero completed."""
        self.bar.reset(self.task)

    def start(self) -> None:
        """Start rendering the ``rich`` progress display."""
        self.bar.start()

    def _completed(self) -> float:
        """Return the completed amount of this bar's task."""
        # Match by task id instead of assuming the id equals the list position.
        return next(task.completed for task in self.bar.tasks if task.id == self.task)

    def _set_display_info(self, text: Optional[str] = None) -> None:
        """Set the display info (description, progress, text)."""
        full_text = f"{self.description_style}{self.description}"
        if not self.show_datasize:
            # Without a FileSizeColumn the count lives in the description; the
            # separating space matches the Streamlit bar's "desc (x/y)" layout.
            full_text += f" {self.progress_style}({self._completed()}/{self.total})"
        if text is not None:
            full_text += f" {self.text_style}{text}"
        self.bar.update(self.task, description=full_text)

    def update(self, advance: float = 0, text: Optional[str] = None) -> None:
        """Advance the task by ``advance`` and refresh the description line."""
        self.bar.update(self.task, advance=advance)
        self._set_display_info(text)

    def close(self) -> None:
        """Render the final description (without extra text) and stop the display."""
        self._set_display_info()
        self.bar.stop()
class StreamlitProgressBar(ProgressBar):
    """Progress bar rendered with Streamlit's ``st.progress``."""

    def __init__(self, *args, layout: Optional["DeltaGenerator"] = None, **kwargs):
        """Progress bar for Streamlit.

        Args:
            *args, **kwargs: Forwarded to ``ProgressBar``.
            layout: Container to draw into; defaults to ``get_streamlit_layout()``.
        """
        super().__init__(*args, **kwargs)
        if layout is None:
            layout = get_streamlit_layout()
        self.bar = layout.progress(0, text=self.description)
        self.current = 0

    def reset(self) -> None:
        """Reset the accumulated progress to zero (the drawn bar updates on the next update())."""
        self.current = 0

    def _percent(self) -> int:
        """Return progress as an ``int`` percentage clamped to [0, 100].

        ``st.progress`` only accepts ints in 0..100 (floats must be 0.0..1.0),
        so float totals/advances and overshoot must not leak through.
        """
        if self.total <= 0:
            # Nothing to do counts as done; also avoids dividing by zero.
            return 100
        percent = int(100 * self.current // self.total)
        return max(0, min(100, percent))

    def update(self, advance: float = 0, text: Optional[str] = None) -> None:
        """Advance by ``advance`` and redraw the bar with its label and optional ``text``."""
        self.current += advance
        full_text = f"**{self.description}**"
        if self.show_datasize:
            full_text += f" ({DataSize(self.current):.1a}/{DataSize(self.total):.1a})"
        else:
            full_text += f" ({self.current}/{self.total})"
        if text is not None:
            full_text += f" {text}"
        self.bar.progress(self._percent(), text=full_text)

    def close(self) -> None:
        """Remove the progress bar element from the page."""
        self.bar.empty()
def get_spinner(
    description: str,
    backend: Optional[str] = None,
    streamlit_layout: Optional["DeltaGenerator"] = None,
) -> Spinner:
    """Build the spinner matching ``backend``.

    Args:
        description: Text displayed next to the spinner.
        backend: ``"streamlit"``, ``"python"`` or ``"jupyter notebook"``;
            defaults to the currently active viewer backend.
        streamlit_layout: Container for the Streamlit spinner (ignored otherwise).

    Returns:
        A ``StreamlitSpinner`` or a ``TextSpinner``.

    Raises:
        ValueError: If ``backend`` is not one of the supported names.
    """
    from ..viewers import backends

    backend = backends.current_backend if backend is None else backend

    if backend in ("python", "jupyter notebook"):
        return TextSpinner(description)
    if backend == "streamlit":
        return StreamlitSpinner(description, layout=streamlit_layout)
    raise ValueError(f"Backend {backend} not supported")
def get_progress_bar(
    description: str,
    total: float = 100,
    show_datasize: bool = False,
    backend: Optional[str] = None,
    streamlit_layout: Optional["DeltaGenerator"] = None,
) -> ProgressBar:
    """Build the progress bar matching ``backend``.

    Args:
        description: Label shown with the bar.
        total: Amount of work corresponding to a full bar.
        show_datasize: Present progress as data sizes.
        backend: ``"streamlit"``, ``"python"`` or ``"jupyter notebook"``;
            defaults to the currently active viewer backend.
        streamlit_layout: Container for the Streamlit bar (ignored otherwise).

    Returns:
        A ``StreamlitProgressBar`` or a ``TextProgressBar``.

    Raises:
        ValueError: If ``backend`` is not one of the supported names.
    """
    from ..viewers import backends

    backend = backends.current_backend if backend is None else backend

    if backend in ("python", "jupyter notebook"):
        return TextProgressBar(description, total=total, show_datasize=show_datasize)
    if backend == "streamlit":
        return StreamlitProgressBar(
            description,
            total=total,
            show_datasize=show_datasize,
            layout=streamlit_layout,
        )
    raise ValueError(f"Backend {backend} not supported")