Source code for solidipes.utils.progress

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

from datasize import DataSize
from rich.live import Live
from rich.progress import BarColumn, FileSizeColumn, TextColumn, TimeRemainingColumn
from rich.progress import Progress as RichProgress
from rich.spinner import Spinner as RichSpinner

if TYPE_CHECKING:
    from streamlit.delta_generator import DeltaGenerator


# Default Streamlit container used by the Streamlit progress displays below.
# Stays None until set_streamlit_layout() stores a layout, or until
# get_streamlit_layout() falls back to the streamlit module itself.
_streamlit_layout: Optional["DeltaGenerator"] = None


def set_streamlit_layout(layout: "DeltaGenerator") -> None:
    """Register the default Streamlit container for progress displays.

    Args:
        layout: Object stored in the module-level ``_streamlit_layout``
            variable; it replaces any previously stored value.
    """
    global _streamlit_layout
    _streamlit_layout = layout
def get_streamlit_layout() -> "DeltaGenerator":
    """Return the default Streamlit container for progress displays.

    Returns:
        The value stored in the module-level ``_streamlit_layout`` variable.
        When nothing has been stored yet, the ``streamlit`` module is imported,
        stored as the default, and returned.
    """
    global _streamlit_layout

    if _streamlit_layout is not None:
        return _streamlit_layout

    # streamlit is only imported here, the first time no layout is registered
    import streamlit

    _streamlit_layout = streamlit
    return _streamlit_layout
class ProgressDisplay(ABC):
    """Base interface for progress displays.

    Instances work as context managers: entering calls ``reset()`` and then
    ``start()``; leaving calls ``close()``. The three hooks do nothing in this
    base class and are meant to be overridden.
    """

    def __init__(self, description: str) -> None:
        """Store the text describing the tracked operation.

        Args:
            description: Label kept in ``self.description``.
        """
        self.description = description

    # --- hooks (no-ops here) -------------------------------------------------

    def reset(self) -> None:
        """Bring the display back to its initial state."""

    def start(self) -> None:
        """Open the display."""

    def close(self) -> None:
        """Close the display."""

    # --- context manager protocol --------------------------------------------

    def __enter__(self) -> "ProgressDisplay":
        # Always start from a clean state, then open the display
        self.reset()
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        # Exception details are ignored; nothing is returned, so an exception
        # raised inside the ``with`` body propagates
        self.close()
class Spinner(ProgressDisplay):
    """Base class for spinner displays.

    Adds nothing to ``ProgressDisplay``; like its parent, it can be used as a
    context manager.
    """
class TextSpinner(Spinner):
    """Spinner using the rich library, for CLI and Jupyter notebooks."""

    def __init__(self, description: str) -> None:
        """Create the rich spinner; no live display is opened yet.

        Args:
            description: Text shown next to the spinner.
        """
        super().__init__(description)
        self.spinner = RichSpinner("dots", text=description)
        self.live = None  # rich Live display, set while the spinner is shown

    def start(self) -> None:
        """Open a rich ``Live`` display rendering the spinner."""
        # Close a display left over from an earlier start() so that its
        # reference is not silently dropped while it is still running
        self.close()
        self.live = Live(self.spinner)
        self.live.__enter__()

    def close(self) -> None:
        """Close the live display if one is open; safe to call repeatedly."""
        self.__exit__(None, None, None)

    def __exit__(self, type, value, traceback):
        """Leave the live display, forwarding the exception details to it."""
        if self.live is not None:
            self.live.__exit__(type, value, traceback)
            self.live = None
class StreamlitSpinner(Spinner):
    """Streamlit spinner."""

    def __init__(self, description: str, layout: Optional["DeltaGenerator"] = None) -> None:
        """Remember where the spinner must be drawn; nothing is shown yet.

        Args:
            description: Text shown next to the spinner.
            layout: Streamlit container to draw in. Defaults to the value
                returned by ``get_streamlit_layout()``.
        """
        super().__init__(description)
        if layout is None:
            layout = get_streamlit_layout()
        self.layout = layout
        # Spinner context manager, set while the spinner is shown. Initialised
        # here so that closing before start() is a no-op instead of an error.
        self.spinner = None

    def start(self) -> None:
        """Show the spinner inside the configured layout."""
        from contextlib import nullcontext

        import streamlit as st

        # Close a spinner left over from an earlier start()
        self.close()

        # The default layout may be the streamlit module itself, which does
        # not implement the context manager protocol; draw at top level then.
        if hasattr(type(self.layout), "__enter__"):
            container = self.layout
        else:
            container = nullcontext()

        with container:
            spinner = st.spinner(text=self.description)
            spinner.__enter__()
        # Only remember the spinner once it has actually been entered
        self.spinner = spinner

    def close(self) -> None:
        """Hide the spinner if it is shown; safe to call repeatedly."""
        self.__exit__(None, None, None)

    def __exit__(self, type, value, traceback):
        """Leave the spinner context, forwarding the exception details to it."""
        if self.spinner is not None:
            self.spinner.__exit__(type, value, traceback)
            self.spinner = None
class ProgressBar(ProgressDisplay):
    """Base interface for progress bars.

    Like its parent, it can be used as a context manager. Concrete subclasses
    must implement ``update``.
    """

    def __init__(self, description: str, total: float = 100, show_datasize: bool = False) -> None:
        """Store the bar configuration.

        Args:
            description: Label of the tracked operation.
            total: Amount of work corresponding to a complete bar.
            show_datasize: Flag kept in ``self.show_datasize`` for subclasses
                to choose how progress is rendered.
        """
        super().__init__(description)
        self.show_datasize = show_datasize
        self.total = total

    @abstractmethod
    def update(self, advance: float = 0, text: Optional[str] = None) -> None:
        """Advance the bar and optionally change the text shown with it.

        Args:
            advance: Amount of work completed since the previous update.
            text: Optional extra text to display.
        """
class TextProgressBar(ProgressBar):
    """Progress bar using the rich library, for CLI and Jupyter notebooks."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the rich progress bar and register its single task.

        Accepts the same arguments as ``ProgressBar.__init__``.
        """
        super().__init__(*args, **kwargs)

        # rich markup prefixes applied to each part of the displayed text
        self.description_style = "[bold blue]"
        self.progress_style = "[blue]"
        self.text_style = "[white]"

        columns = [
            BarColumn(),
            TimeRemainingColumn(),
            TextColumn("{task.description}"),
        ]
        if self.show_datasize:
            # Show the size column right after the bar itself
            columns.insert(1, FileSizeColumn())

        self.bar = RichProgress(*columns)
        self.task = self.bar.add_task(self.description, total=self.total)

    def reset(self) -> None:
        """Reset the task, including the text displayed next to the bar."""
        # _set_display_info() overwrites the task description with counters and
        # extra text; restore the plain description so that a reused bar does
        # not show stale information from its previous run.
        self.bar.reset(self.task, description=self.description)

    def start(self) -> None:
        """Start rendering the bar."""
        self.bar.start()

    def _set_display_info(self, text: Optional[str] = None) -> None:
        """Set the display info (description, progress, text).

        Args:
            text: Optional extra text appended after the description and,
                when ``show_datasize`` is false, the progress counter.
        """
        full_text = f"{self.description_style}{self.description}"

        if not self.show_datasize:
            completed = self.bar.tasks[self.task].completed
            # Leading space keeps the counter apart from the description
            full_text += f" {self.progress_style}({completed}/{self.total})"

        if text is not None:
            full_text += f" {self.text_style}{text}"

        self.bar.update(self.task, description=full_text)

    def update(self, advance: float = 0, text: Optional[str] = None) -> None:
        """Advance the bar by ``advance`` and refresh the displayed text."""
        self.bar.update(self.task, advance=advance)
        self._set_display_info(text)

    def close(self) -> None:
        """Drop any extra text from the display, then stop rendering."""
        self._set_display_info()
        self.bar.stop()
class StreamlitProgressBar(ProgressBar):
    """Progress bar for Streamlit."""

    def __init__(self, *args, layout: Optional["DeltaGenerator"] = None, **kwargs) -> None:
        """Create the Streamlit progress element.

        Args:
            *args: Positional arguments forwarded to ``ProgressBar.__init__``.
            layout: Streamlit container to draw in. Defaults to the value
                returned by ``get_streamlit_layout()``.
            **kwargs: Keyword arguments forwarded to ``ProgressBar.__init__``.
        """
        super().__init__(*args, **kwargs)
        if layout is None:
            layout = get_streamlit_layout()
        self.layout = layout
        self.bar = layout.progress(0, text=self.description)
        self.current = 0  # amount of work completed so far

    def reset(self) -> None:
        """Replace the progress element with a fresh, empty one."""
        self.bar.empty()
        self.bar = self.layout.progress(0, text=self.description)
        self.current = 0

    def _percent(self) -> int:
        """Return the completed share as an integer between 0 and 100."""
        if self.total <= 0:
            # No meaningful ratio can be computed (and dividing would fail)
            return 0
        # The value is passed on as an int because ``advance`` and ``total``
        # may be floats, and a float such as 50.0 would not be accepted as a
        # percentage. It is clamped so that overshooting ``total`` (or
        # negative progress) never yields an out-of-range value.
        percent = int(100 * self.current // self.total)
        return max(0, min(100, percent))

    def update(self, advance: float = 0, text: Optional[str] = None) -> None:
        """Advance the bar by ``advance`` and refresh the displayed text.

        Args:
            advance: Amount of work completed since the previous update.
            text: Optional extra text appended after the progress counter.
        """
        self.current += advance

        full_text = f"**{self.description}**"
        if self.show_datasize:
            full_text += f" ({DataSize(self.current):.1a}/{DataSize(self.total):.1a})"
        else:
            full_text += f" ({self.current}/{self.total})"
        if text is not None:
            full_text += f" {text}"

        self.bar.progress(self._percent(), text=full_text)

    def close(self) -> None:
        """Remove the progress element from the page."""
        self.bar.empty()
def get_spinner(
    description: str,
    backend: Optional[str] = None,
    streamlit_layout: Optional["DeltaGenerator"] = None,
) -> Spinner:
    """Create the spinner implementation matching a display backend.

    Args:
        description: Text shown next to the spinner.
        backend: Backend name. When ``None``, ``backends.current_backend``
            from the ``viewers`` package is used.
        streamlit_layout: Layout handed to the Streamlit spinner; unused by
            the text spinner.

    Returns:
        A ``StreamlitSpinner`` for ``"streamlit"``, or a ``TextSpinner`` for
        ``"python"`` and ``"jupyter notebook"``.

    Raises:
        ValueError: If the backend is none of the names above.
    """
    from ..viewers import backends

    selected = backends.current_backend if backend is None else backend

    if selected == "streamlit":
        return StreamlitSpinner(description, layout=streamlit_layout)

    if selected in ("python", "jupyter notebook"):
        return TextSpinner(description)

    raise ValueError(f"Backend {selected} not supported")
def get_progress_bar(
    description: str,
    total: float = 100,
    show_datasize: bool = False,
    backend: Optional[str] = None,
    streamlit_layout: Optional["DeltaGenerator"] = None,
) -> ProgressBar:
    """Create the progress bar implementation matching a display backend.

    Args:
        description: Label of the tracked operation.
        total: Amount of work corresponding to a complete bar.
        show_datasize: Forwarded unchanged to the created progress bar.
        backend: Backend name. When ``None``, ``backends.current_backend``
            from the ``viewers`` package is used.
        streamlit_layout: Layout handed to the Streamlit progress bar; unused
            by the text progress bar.

    Returns:
        A ``StreamlitProgressBar`` for ``"streamlit"``, or a
        ``TextProgressBar`` for ``"python"`` and ``"jupyter notebook"``.

    Raises:
        ValueError: If the backend is none of the names above.
    """
    from ..viewers import backends

    selected = backends.current_backend if backend is None else backend
    bar_options = {"total": total, "show_datasize": show_datasize}

    if selected == "streamlit":
        return StreamlitProgressBar(description, layout=streamlit_layout, **bar_options)

    if selected in ("python", "jupyter notebook"):
        return TextProgressBar(description, **bar_options)

    raise ValueError(f"Backend {selected} not supported")