from typing import Optional
import os
import platform
import wandb
from omegaconf import DictConfig, OmegaConf

from pado.utils.dist_utils import is_master, broadcast_objects, is_distributed_set

# Public API of this module: only ``set_wandb`` is exported by ``from <module> import *``.
__all__ = ["set_wandb"]


def set_wandb(cfg: DictConfig, force_mode: Optional[str] = None) -> Optional[str]:
    """Start a Weights & Biases run on the master process and return its directory.

    ``cfg`` should be the FULL configuration: it is resolved with
    ``OmegaConf.to_container(cfg, resolve=True)`` and logged as the run config.

    Args:
        cfg: Full run configuration. Reads ``run_type``, ``save_dir`` and
            ``wandb.mode`` on every process; on the master process it also reads
            ``project``, ``name`` and the optional ``wandb.notes`` / ``wandb.id``.
        force_mode: If not None, overrides ``cfg["wandb"]["mode"]``. Compared
            case-insensitively (lower-cased). Note that ``cfg["wandb"]["mode"]`` is
            still read and lower-cased even when this is given.

    Returns:
        On the master process: ``wandb.run.dir``, or ``cfg["save_dir"]`` when the
        resolved mode is ``"disabled"``. On other processes: ``None`` unless
        ``is_distributed_set()`` is true, in which case the value returned by
        ``broadcast_objects(..., src_rank=0)`` (presumably the master's path).

    Raises:
        ValueError: If the resolved mode is not "online", "offline" or "disabled".
    """
    job_type = cfg["run_type"]
    root_dir = cfg["save_dir"]  # root save dir, also passed to wandb as ``dir``

    # The configured mode is always evaluated; force_mode merely takes precedence.
    configured_mode = cfg["wandb"]["mode"].lower()
    mode = configured_mode if force_mode is None else force_mode.lower()
    if mode not in ("online", "offline", "disabled"):
        raise ValueError(f"WandB mode {mode} invalid.")

    run_dir = None  # non-master processes keep None until the broadcast below
    if is_master():  # only the master process initialises WandB
        os.makedirs(root_dir, exist_ok=True)

        project = cfg["project"]
        run_name = cfg["name"]

        wandb_section = cfg["wandb"]
        notes = wandb_section["notes"] if "notes" in wandb_section else None
        run_id = wandb_section["id"] if "id" in wandb_section else None

        # Prefix notes with the host name so runs can be traced to a machine.
        host = platform.node()
        notes = host if notes is None else f"{host}-{notes}"

        wandb.init(
            project=project,
            job_type=job_type,
            name=run_name,
            dir=root_dir,
            resume="allow",
            mode=mode,
            notes=notes,
            config=OmegaConf.to_container(cfg, resolve=True),
            id=run_id,
        )

        # wandb.run.dir is only consulted when WandB is not disabled.
        run_dir = root_dir if mode == "disabled" else wandb.run.dir

    if is_distributed_set():
        received = broadcast_objects([run_dir], src_rank=0)
        run_dir = received[0]

    return run_dir
