# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import mmengine.dist as dist
import rich.progress as progress
from rich.live import Live

# Module-level switch: when truthy, ``track`` and ``track_on_main_process``
# yield the items of their sequence directly without drawing a progress bar.
# It is read when a tracked iteration starts, not at import time.
disable_progress_bar = False
# One Progress instance shared by every ``track`` call, so nested or
# interleaved loops are added as separate tasks to the same display.
# Columns: description, bar, percentage (with speed), estimated time left.
global_progress = progress.Progress(
    '{task.description}',
    progress.BarColumn(),
    progress.TaskProgressColumn(show_speed=True),
    progress.TimeRemainingColumn(),
)
# Live display that renders ``global_progress``. ``track`` starts it and stops
# it once every task registered on ``global_progress`` has finished.
global_live = Live(global_progress, refresh_per_second=10)


def track(sequence, description: str = '', total: Optional[float] = None):
    """Iterate over ``sequence`` while rendering a rich progress bar.

    All calls share ``global_progress`` and ``global_live``, so nested or
    interleaved loops appear as separate tasks in one live display. When the
    last running task finishes, the live display is stopped and every task
    is removed so the next loop starts from a clean display.

    Args:
        sequence: The iterable to wrap. Its items are yielded unchanged.
        description (str): Label shown next to the bar. Defaults to ''.
        total (float, optional): Expected number of items. When None, rich
            derives it from the sequence's length hint if one is available;
            otherwise the total is filled in when the loop ends.

    Yields:
        The items of ``sequence``, in order.
    """
    if disable_progress_bar:
        yield from sequence
        return

    # Live.start() is expected to be a no-op when an enclosing tracked loop
    # has already started the shared display.
    global_live.start()
    task_id = global_progress.add_task(description, total=total)
    # Use the public ``tasks`` list (same Task objects as rich stores
    # internally) instead of the private ``_tasks`` mapping.
    task = next(t for t in global_progress.tasks if t.id == task_id)
    try:
        # Forward ``total`` so rich does not replace a caller-supplied value
        # with ``len(sequence)``.
        yield from global_progress.track(
            sequence, total=total, task_id=task_id)
    finally:
        # A task only counts as finished once completed >= total. That never
        # happens when the total is unknown, when the loop is left early
        # (break, exception, generator close) or when the sequence is shorter
        # than ``total``. Clamp the total to the progress actually made so
        # the task finishes and the shared display can be shut down instead
        # of waiting forever on a stale task.
        if not task.finished:
            global_progress.update(task_id, total=task.completed)
        if all(t.finished for t in global_progress.tasks):
            global_live.stop()
            # ``task_ids`` returns a snapshot, so removing while looping is
            # safe.
            for finished_id in global_progress.task_ids:
                global_progress.remove_task(finished_id)


def track_on_main_process(sequence, description='', total=None):
    """Progress-tracked iteration that only draws a bar on the main process.

    On processes for which ``dist.is_main_process()`` is False, and on every
    process while ``disable_progress_bar`` is set, the items of ``sequence``
    are yielded directly. Otherwise iteration is delegated to :func:`track`.

    Args:
        sequence: The iterable to wrap; its items are yielded unchanged.
        description (str): Label for the progress bar. Defaults to ''.
        total (float, optional): Expected number of items, forwarded to
            :func:`track`. Defaults to None.

    Yields:
        The items of ``sequence``, in order.
    """
    show_bar = dist.is_main_process() and not disable_progress_bar
    if show_bar:
        yield from track(sequence, total=total, description=description)
    else:
        yield from sequence
