from pathlib import Path

import fire
from beartype import beartype
import wandb
from wandb.errors import CommError
from wandb.apis.public.runs import Run

from helpers import logger


@beartype
def download_file(file_name: str, run: Run, run_dir: Path):
    """Download one file from a wandb run into ``run_dir``, best-effort.

    Failures are logged as warnings and swallowed rather than raised, so one
    missing file does not abort the caller's loop over runs.

    Args:
        file_name: name of the file as stored in the wandb run.
        run: public-API handle of the run to download from.
        run_dir: local directory the file is written into; ``replace=True``
            overwrites an existing local copy.
    """
    try:
        file = run.file(file_name)
        file.download(root=str(run_dir), replace=True)
    except CommError as err:
        # CommError is how the wandb public API reports request failures in
        # general (missing file, but also network/auth/server errors), so log
        # the underlying reason instead of assuming "not found".
        logger.warn(f"could not download {file_name} for run {run.name}: {err}")
    else:
        logger.warn(f"\U0001F525downloaded {file_name} for run {run.name}")
        logger.warn(f" @@ {run_dir.resolve()}")


@beartype
def retrieve_from_wandb(wandb_id: str,
                        wandb_project: str,
                        group_name: str,
                        nickname: str,
                        download_dir: str):
    """Retrieve the progress.json and progress.csv files of every run in a group.

    Files are written to
    ``<download_dir>/<a>/<b>/<c>/<nickname>/<run_suffix>/``, where ``a``,
    ``b`` and ``c`` are the 3rd, 4th and 5th dot-separated fields of
    ``group_name`` and ``run_suffix`` is the last dot-separated field of each
    run's name.

    Args:
        wandb_id: wandb entity (user or team) that owns the project; used as
            the first half of the ``"<entity>/<project>"`` path.
        wandb_project: wandb project name.
        group_name: run group to fetch; must have at least five
            dot-separated fields.
        nickname: extra path component placed under the group-derived dirs.
        download_dir: root directory for the downloaded files.

    Raises:
        ValueError: if ``group_name`` has fewer than five dot-separated fields.
    """
    # validate the group name before touching the API so a malformed name
    # fails fast with a readable message
    fields = group_name.split(".")
    if len(fields) < 5:
        raise ValueError(
            "group_name must have at least 5 dot-separated fields, "
            f"got {len(fields)}: {group_name!r}")
    _, _, a, b, c, *_ = fields
    new_path_root = Path(download_dir) / a / b / c / nickname

    # initialize wandb API
    api = wandb.Api()

    # search for runs in the specified group
    runs = api.runs(f"{wandb_id}/{wandb_project}", filters={"group": group_name})

    n_runs = 0
    for run in runs:
        n_runs += 1
        # NOTE(review): if a run name ends with ".", this suffix is "" and the
        # files land directly in new_path_root, colliding across runs.
        run_dir = new_path_root / run.name.split(".")[-1]
        run_dir.mkdir(parents=True, exist_ok=True)

        # download the specific files for each run
        for file_name in ["progress.json", "progress.csv"]:
            download_file(file_name, run, run_dir)

    if n_runs == 0:
        # otherwise an empty or misspelled group finishes silently
        logger.warn(f"no runs found in group {group_name} "
                    f"of {wandb_id}/{wandb_project}")


if __name__ == "__main__":
    # Send log output to stdout; directory handling for directory=None is up
    # to helpers.logger (not shown here).
    logger.configure(format_strs=["stdout"], directory=None)
    # All messages in this script are emitted via logger.warn, so WARN keeps
    # them visible.
    logger.set_level(logger.WARN)
    # Expose retrieve_from_wandb's parameters as command-line arguments.
    fire.Fire(component=retrieve_from_wandb)
