import shutil
import os

import argparse
import yaml
import torch

from utilities.data.dataset import AudioDataset, AudioImageDataset

from torch.utils.data import DataLoader
from pytorch_lightning import seed_everything
from audioldm_train.utilities.tools import get_restore_step
from audioldm_train.utilities.model_util import instantiate_from_config
from audioldm_train.utilities.tools import build_dataset_json_from_list


def infer(dataset_json, configs, config_yaml_path, exp_group_name, exp_name):
    """Generate samples with a trained latent diffusion model on the test split.

    Seeds the RNGs, builds a test-split ``AudioImageDataset`` (batch size 1),
    resolves a checkpoint, loads it into the model built from
    ``configs["model"]``, moves the model to CUDA and calls its
    ``generate_sample`` with attention visualisation enabled.

    Args:
        dataset_json: Optional dataset description forwarded unchanged to
            ``AudioImageDataset``; the CLI entry point currently passes None.
        configs: Parsed YAML config dict. Reads ``log_directory``, ``data``,
            ``model`` (including ``model.params.evaluation_params``) and the
            optional keys ``seed``, ``precision`` and ``reload_from_ckpt``.
        config_yaml_path: Path to the config YAML; it is copied into the
            experiment directory ``<log_directory>/<exp_group_name>/<exp_name>``.
        exp_group_name: Experiment group directory name under ``log_directory``.
        exp_name: Experiment directory name under the group directory.

    Raises:
        FileNotFoundError: If the experiment's ``checkpoints`` directory is
            empty and no ``reload_from_ckpt`` is configured, i.e. there are no
            weights to run inference with.
    """
    if "seed" in configs:
        seed_everything(configs["seed"])
    else:
        print("SEED EVERYTHING TO 0")
        seed_everything(0)

    if "precision" in configs:
        torch.set_float32_matmul_precision(configs["precision"])

    log_path = configs["log_directory"]
    dataloader_add_ons = configs["data"].get("dataloader_add_ons", [])

    # NOTE(review): customize_attn=True presumably pairs with
    # visualize_attn=True passed to generate_sample below -- confirm.
    val_dataset = AudioImageDataset(
        configs,
        split="test",
        add_ons=dataloader_add_ons,
        dataset_json=dataset_json,
        customize_attn=True,
    )
    val_loader = DataLoader(val_dataset, batch_size=1)

    # Optional key: explicit checkpoint path to fall back on.
    config_reload_from_ckpt = configs.get("reload_from_ckpt")

    exp_path = os.path.join(log_path, exp_group_name, exp_name)
    checkpoint_path = os.path.join(exp_path, "checkpoints")

    os.makedirs(checkpoint_path, exist_ok=True)
    # Keep a copy of the config used for this run next to its outputs.
    shutil.copy(config_yaml_path, exp_path)

    # Checkpoints already saved for this experiment take priority over the
    # checkpoint named in the config.
    if os.listdir(checkpoint_path):
        print("Load checkpoint from path: %s" % checkpoint_path)
        restore_step, _ = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
    elif config_reload_from_ckpt is not None:
        resume_from_checkpoint = config_reload_from_ckpt
        print("Reload ckpt specified in the config file %s" % resume_from_checkpoint)
    else:
        # Inference needs trained weights; fail early with a clear message
        # instead of passing None to torch.load.
        raise FileNotFoundError(
            "No checkpoint found in %s and no 'reload_from_ckpt' given; "
            "cannot run inference without trained weights." % checkpoint_path
        )

    latent_diffusion = instantiate_from_config(configs["model"])
    latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name)

    eval_params = configs["model"]["params"]["evaluation_params"]
    guidance_scale = eval_params["unconditional_guidance_scale"]
    ddim_sampling_steps = eval_params["ddim_sampling_steps"]
    n_candidates_per_samples = eval_params["n_candidates_per_samples"]

    # Map to CPU: the model is moved to CUDA only after loading, so this
    # avoids a second, temporary copy of the weights on the GPU and does not
    # depend on which device the checkpoint was saved from.
    checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
    latent_diffusion.load_state_dict(checkpoint["state_dict"])
    del checkpoint  # release the raw checkpoint before moving to the GPU

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.cuda()

    latent_diffusion.generate_sample(
        val_loader,
        unconditional_guidance_scale=guidance_scale,
        ddim_steps=ddim_sampling_steps,
        n_gen=n_candidates_per_samples,
        visualize_attn=True,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required: every later step derives paths and settings from this file.
    parser.add_argument(
        "-c",
        "--config_yaml",
        type=str,
        required=True,
        help="path to config .yaml file",
    )

    # NOTE(review): accepted for CLI compatibility but currently unused --
    # dataset_json is always None below; confirm whether it should be wired
    # through build_dataset_json_from_list.
    parser.add_argument(
        "-l",
        "--list_inference",
        type=str,
        required=False,
        help="The filelist that contain captions (and optionally filenames)",
    )
    parser.add_argument(
        "-reload_from_ckpt",
        "--reload_from_ckpt",
        type=str,
        required=False,
        help="the checkpoint path for the model",
    )

    args = parser.parse_args()

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")

    config_yaml_path = args.config_yaml
    # Experiment name: the config file name up to its first ".". Taking the
    # basename first keeps dots in directory names (e.g. "./configs/...")
    # from truncating the path.
    exp_name = os.path.basename(config_yaml_path).split(".")[0]
    # Experiment group: name of the directory containing the config file.
    exp_group_name = os.path.basename(os.path.dirname(config_yaml_path))

    # NOTE(review): FullLoader can construct some Python objects; only load
    # trusted config files (yaml.safe_load would be stricter).
    with open(config_yaml_path, "r") as config_file:
        configs = yaml.load(config_file, Loader=yaml.FullLoader)

    # A checkpoint given on the command line overrides the config's value.
    if args.reload_from_ckpt is not None:
        configs["reload_from_ckpt"] = args.reload_from_ckpt

    dataset_json = None
    infer(dataset_json, configs, config_yaml_path, exp_group_name, exp_name)
