import wandb


def wandb_watch(model):
    """Register *model* with ``wandb.watch(log="all")`` when a run is active.

    Does nothing if no W&B run has been started (``wandb.run`` is None).
    """
    if wandb.run is None:
        return
    wandb.watch(model, log="all")


def log_trainer_metrics(rec_loss, commit_loss, val_loss, scheduler, epoch):
    """Send per-epoch training/validation metrics to W&B.

    No-op when no W&B run is active (``wandb.run`` is None). Otherwise logs
    one dict through ``wandb_log``.

    Args:
        rec_loss: epoch reconstruction loss, logged as ``train_recon_loss_epoch``.
        commit_loss: epoch commitment loss, logged as ``train_commit_loss_epoch``.
        val_loss: validation loss, logged as ``val_loss``.
        scheduler: LR scheduler; the first entry of ``get_last_lr()`` is
            logged as ``lr``.
        epoch: epoch index, logged as ``epoch``.
    """
    if wandb.run is None:
        return
    # Call this module's own ``wandb_log`` instead of re-importing it through
    # an absolute package path. If this module was imported under another
    # name, that path would load a second copy with a separate ``logged``
    # history buffer.
    metrics = {
        "val_loss": val_loss,
        "train_recon_loss_epoch": rec_loss,
        "train_commit_loss_epoch": commit_loss,
        "lr": scheduler.get_last_lr()[0],
        "epoch": epoch,
    }
    wandb_log(metrics)


def gen_name(args):
    """Build a W&B run name from core VQ settings plus command-line overrides.

    The base name encodes ``model``, ``codebook``, ``threshold_ema_dead_code``,
    ``commit_weight``, ``codebook_size``, ``codebook_dim`` and ``vq_decay``
    from *args*. Every other ``--option`` on ``sys.argv`` is appended:

    * ``--opt=value``        -> ``_opt=value``
    * ``--opt v1 v2 ...``    -> ``_opt=v1_opt=v2...`` (one part per value)
    * ``--switch`` (no value) -> ``_switch``

    Options already in the base name, plus ``batch_size`` and ``seed``, are
    omitted. Hyphenated spellings (``--batch-size``) are matched the way
    argparse normalizes them (``batch_size``). Arguments that do not follow
    an option (e.g. a leading positional) are ignored.

    Args:
        args: parsed argument namespace exposing the attributes listed above.

    Returns:
        The run name string.
    """
    import sys

    name = f"{args.model}_{args.codebook}_thr{args.threshold_ema_dead_code}_cw{args.commit_weight}_{args.codebook_size}x{args.codebook_dim}d_dec{args.vq_decay}"
    # Already encoded in the base name, or fixed across runs.
    skip = {
        "model",
        "codebook",
        "threshold_ema_dead_code",
        "commit_weight",
        "codebook_size",
        "codebook_dim",
        "vq_decay",
        "batch_size",
        "seed",
    }

    def _skipped(option):
        # argparse stores ``--foo-bar`` as ``foo_bar``; compare in that form.
        return option.replace("-", "_") in skip

    parts = []
    option = None  # current "--option" name (without dashes) awaiting values
    option_has_value = False

    def _flush_switch():
        # An option followed by no value is a switch (e.g. store_true).
        if option is not None and not option_has_value and not _skipped(option):
            parts.append(option)

    for arg in sys.argv[1:]:
        if arg.startswith("--"):
            _flush_switch()
            option, option_has_value = None, False
            body = arg[2:]
            if "=" in body:
                if not _skipped(body.partition("=")[0]):
                    parts.append(body)
            elif body:
                option = body
            continue
        if option is None:
            # Positional argument (or value after "--opt=v" / "--"): not
            # attributable to an option, so it is not part of the name.
            continue
        option_has_value = True
        if not _skipped(option):
            parts.append(f"{option}={arg}")
    _flush_switch()

    return "_".join([name, *parts])


def init_wandb(args, *, entity="your_wandb_name", project="vqvae_ablation"):
    """Start a W&B run for this experiment.

    The full argument namespace (``vars(args)``) is stored as the run config,
    and the run name comes from ``gen_name(args)``.

    Args:
        args: parsed argument namespace.
        entity: W&B entity (user or team). The default appears to be a
            placeholder; pass a real entity or edit the default.
        project: W&B project name.

    Returns:
        Whatever ``wandb.init`` returns (the run object).
    """
    return wandb.init(
        entity=entity,
        project=project,
        config=vars(args),
        name=gen_name(args),
    )


# History of every ``wandb_log`` call made on the logging rank, stored as
# ``(args, kwargs, global_step)`` tuples and periodically saved to disk.
logged = []


def _is_logging_rank():
    """Return True on distributed rank 0, or when torch.distributed is not in use."""
    import torch.distributed as dist

    # get_rank() raises when no process group exists, so single-process runs
    # must be treated as rank 0 explicitly.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0


def wandb_log(*args, **kwargs):
    """Log a metrics dict to W&B at the trainer's global step (logging rank only).

    The first positional argument must be a dict of metrics. A copy of it,
    extended with ``eff_step = trainer.global_step / wandb.config.n_accumulate``,
    is passed to ``wandb.log`` together with any remaining ``args``/``kwargs``
    and ``step=trainer.global_step``. The call is appended to ``logged``.
    Every 1000 calls, the whole history is saved with ``torch.save`` to
    ``<trainer.save_path>/logged.pt``. Non-zero ranks do nothing.

    Raises:
        TypeError: if the first positional argument is missing or not a dict.
    """
    import src.train_utils.trainer as trainer

    if not _is_logging_rank():
        return
    if not args or not isinstance(args[0], dict):
        raise TypeError("wandb_log expects a metrics dict as its first positional argument")

    step = trainer.global_step
    # Shallow-copy so the caller's dict is not mutated. This also means the
    # stored history keeps this step's values even if the caller later
    # reuses or mutates its dict.
    data = {**args[0], "eff_step": step / wandb.config.n_accumulate}
    args = (data,) + args[1:]
    wandb.log(*args, **kwargs, step=step)

    logged.append((args, kwargs, step))
    if len(logged) % 1000 == 0:
        import torch
        from pathlib import Path

        save_path = Path(trainer.save_path) / "logged.pt"
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(logged, save_path)
