"""A custom progress bar with rich."""
from typing import Iterable, Optional, Any

from rich.text import Text
from rich.progress import (
    Task,
    BarColumn,
    MofNCompleteColumn,
    Progress,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    ProgressColumn
)


class IterationsPerSecondColumn(ProgressColumn):
    """A column for displaying the number of iterations per second."""

    def render(self, task: Task) -> Text:
        """Show iterations per second.

        Uses the task's final speed once it has finished, otherwise its current
        speed estimate, and shows ``?`` when no speed is available. Speeds of
        at least 1 it/s are rounded to a whole number with thousands
        separators; slower speeds keep two decimals so that slow loops do not
        all render as ``0 it/s``.
        """
        speed = task.finished_speed or task.speed
        if speed is None:
            return Text('?', style='progress.data.speed')
        if speed >= 1:
            data_speed = f'{speed:,.0f}'
        else:
            # Below 1 it/s a whole-number format would always print "0".
            data_speed = f'{speed:.2f}'
        return Text(f'{data_speed} it/s', style='progress.data.speed')


class ProgressBar:
    """A progress bar with support for custom per-iteration fields.

    Wraps a ``rich.progress.Progress`` whose columns are: description,
    percentage, bar, ``completed/total`` count, elapsed and remaining time,
    iterations per second, and finally a free-form message built from the
    fields written with :meth:`write_field`.

    Use the bar as a context manager so the wrapped ``Progress`` is entered on
    entry and exited on exit. Attribute lookups the bar does not handle itself
    are forwarded to the wrapped ``Progress`` object.

    Args:
        sequence: The iterable to iterate over.
        total: The total number of iterations. If not provided, it is inferred
            from ``len(sequence)``; if the iterable has no length (e.g. a
            generator), the total is left unknown (``None``).
        description: The description to display on the progress bar.

    Example:

        >>> from time import sleep
        >>> with ProgressBar(range(10), description='Sleeping...') as bar:
        ...     for i in bar:
        ...         sleep(0.1)
        ...         bar.write_field('i_squared', i**2)
    """

    sequence: Iterable
    total: Optional[float]
    description: str

    # Private Instance Attributes:
    #   _progress: The wrapped rich progress display.
    #   _task: The id returned by ``Progress.add_task`` for this bar's task
    #       (rich's ``TaskID``, a NewType of int).
    #   _current_fields: The fields written during the current iteration.
    _progress: Progress
    _task: int
    _current_fields: dict[str, Any]

    def __init__(
        self,
        sequence: Iterable,
        total: Optional[float] = None,
        description: str = 'Working...'
    ) -> None:
        self.sequence = sequence
        # Only infer the total when it was not given, so an explicit 0 is kept.
        self.total = total if total is not None else self._infer_total(sequence)
        self.description = description

        self._progress = Progress(
            TextColumn('{task.description}', justify='right'),
            TextColumn('[progress.percentage]{task.percentage:>3.0f}%'),
            BarColumn(bar_width=None),
            MofNCompleteColumn(),
            # [ {time_elapsed} < {time_remaining}, {iterations_per_second} it/s ]
            TextColumn(' ['),
            TimeElapsedColumn(),
            TextColumn('<'),
            TimeRemainingColumn(),
            TextColumn(','),
            IterationsPerSecondColumn(),
            TextColumn('] '),
            # NOTE(review): empty column; appears to only add spacing before
            # the message column.
            TextColumn(''),
            TextColumn('{task.fields[message]}', justify='right'),
        )
        self._task = self._progress.add_task(
            self.description,
            total=self.total,
            message=''
        )

        self._current_fields = {}

    @staticmethod
    def _infer_total(sequence: Iterable) -> Optional[float]:
        """Return ``len(sequence)``, or None when the iterable has no length."""
        try:
            return len(sequence)
        except TypeError:
            # e.g. generators: let rich treat the total as unknown.
            return None

    def write_field(self, name: str, value: Any) -> None:
        """Write a field to the progress bar, overwriting the previous value.

        The field is shown in the message column (as ``name=value``) once the
        current iteration finishes and the next item is requested.

        Args:
            name: The name of the field.
            value: The value of the field.
        """
        self._current_fields[name] = value

    def __iter__(self) -> Iterable:
        """Yield the items of ``sequence`` while advancing the progress bar.

        Fields written with :meth:`write_field` while the caller processes an
        item are rendered into the message column when control returns here,
        i.e. when the next item is requested (including the final request that
        finds the sequence exhausted). If no fields were written during an
        iteration, the previous message is left unchanged.
        """
        tracked = self._progress.track(
            self.sequence,
            total=self.total,
            task_id=self._task
        )
        for item in tracked:
            yield item

            if not self._current_fields:
                continue

            message = ', '.join(
                f'{k}={v}' for k, v in self._current_fields.items()
            )
            self._progress.update(self._task, message=message)

            # Clear the current field values for the next iteration
            self._current_fields = {}

    def __enter__(self) -> 'ProgressBar':
        """Enter the wrapped ``Progress`` and return this bar."""
        self._progress.__enter__()
        return self

    def __exit__(self, *args) -> None:
        """Exit the wrapped ``Progress``, forwarding the exception info."""
        self._progress.__exit__(*args)

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attribute lookups to the wrapped ``Progress``.

        Python only calls this after normal attribute lookup has failed.
        ``_progress`` is read directly from the instance dict so that an
        instance whose ``__init__`` never ran (e.g. one created by ``copy`` or
        ``pickle``) raises AttributeError instead of recursing infinitely
        through ``__getattr__``.
        """
        try:
            progress = self.__dict__['_progress']
        except KeyError:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}'
            ) from None
        return getattr(progress, name)
