#!python

import logging
import shutil
from pathlib import Path
from typing import List

import typer
import wandb

from nn_core.ui import WANDB_DIR

# Module-level logger, named after this module.
pylogger = logging.getLogger(__name__)


def download(entity: str, project: str, run_ids: List[str]) -> None:
    """Place one checkpoint per W&B run in the current working directory.

    For each run id, a directory inside ``WANDB_DIR`` whose name ends with the
    id is used when it contains checkpoint files; otherwise the checkpoint is
    downloaded through the W&B API. In both cases the checkpoint that sorts
    last (by path locally, by file name remotely) is selected and stored as
    ``<run_id><original suffixes>`` in the current working directory.

    Args:
        entity: W&B entity owning the project.
        project: W&B project containing the runs.
        run_ids: Ids of the runs whose checkpoints should be fetched.

    Raises:
        ValueError: If more than one local run directory matches a run id, or
            if the remote run has no file whose name contains "checkpoint".
    """
    cwd = Path.cwd()
    api = None  # created lazily so fully-local invocations never build an API client

    for run_id in run_ids:
        matching_runs: List[Path] = [
            item for item in WANDB_DIR.iterdir() if item.is_dir() and item.name.endswith(run_id)
        ]

        if len(matching_runs) > 1:
            raise ValueError(f"More than one run matching unique id {run_id}! Are you sure about that?")

        # Only regular files can be copied; a local run directory without any
        # checkpoint file falls through to the remote download below.
        local_ckpts: List[Path] = []
        if matching_runs:
            local_ckpts = [path for path in matching_runs[0].rglob("checkpoints/*") if path.is_file()]

        if local_ckpts:
            ckpt_path: Path = max(local_ckpts)
            typer.secho(f"Reusing local copy {entity}/{project}/{run_id} checkpoint: {ckpt_path.name}")
            shutil.copy(ckpt_path, cwd / f"{run_id}{''.join(ckpt_path.suffixes)}")
            continue

        if api is None:
            api = wandb.Api()
        run = api.run(path=f"{entity}/{project}/{run_id}")

        remote_ckpts = [file for file in run.files() if "checkpoint" in file.name]
        if not remote_ckpts:
            raise ValueError(f"No checkpoint file found in {entity}/{project}/{run_id}")
        # Compare by name: the W&B file objects themselves are not relied upon
        # to be orderable.
        ckpt_wandb_file = max(remote_ckpts, key=lambda file: file.name)
        typer.secho(f"Downloading from {entity}/{project}/{run_id} checkpoint: {ckpt_wandb_file}")

        # `download` hands back an open file handle; close it before moving the file.
        with ckpt_wandb_file.download(root=str(cwd)) as handle:
            downloaded = Path(handle.name)
        # `replace` overwrites an existing target on every platform, matching
        # the overwrite behaviour of `shutil.copy` in the local branch.
        downloaded.replace(cwd / f"{run_id}{''.join(downloaded.suffixes)}")

        # Remove the now-empty directories the download created below the cwd
        # (e.g. "run_files/checkpoints"), innermost first.
        leftover = downloaded.parent
        while leftover != cwd and cwd in leftover.parents:
            try:
                leftover.rmdir()
            except OSError:
                # Directory still holds other content: leave it (and its parents) alone.
                break
            leftover = leftover.parent


if __name__ == "__main__":
    # Script entry point: let Typer build the command-line interface from
    # `download`'s signature and invoke it with the parsed arguments.
    typer.run(download)
