#!/usr/bin/env python
import argparse
import json
import os
from pprint import pprint

import numpy as np
import torch
import torch.distributed as dist
from guided_diffusion import logger
from guided_diffusion.script_util import NUM_CLASSES, add_dict_to_argparser
from PIL import Image

from tada.dist_util import setup_dist
from tada.script_util import (create_model_and_diffusion, find_config,
                              find_env_vars, int2str, set_seed, update_config, save_env_vars)


@torch.no_grad()
def main():
    """Sample images from a diffusion checkpoint and save them as an ``.npz``.

    Steps, in order:

    1. Parse command-line arguments (unrecognised ones are forwarded to
       ``update_config``), set up the distributed environment, the logger,
       the seed and the cuDNN / TF32 switches.
    2. Restore the environment variables stored next to the checkpoint; if
       that file is missing, write the current ones to ``env.json`` in the
       log directory instead.
    3. Load the config found for the checkpoint, build the model and the
       diffusion object, load the weights and switch the model to eval mode.
    4. Sample batches until ``args.num_samples`` images exist across all
       ranks, optionally writing each image as a PNG (``args.save_png``),
       and all-gather the batches (and labels for class-conditional models).
    5. On rank 0, store the images (and labels) in a single ``.npz`` file in
       the log directory.

    The whole function runs under ``torch.no_grad()``.
    """
    args, unknown_args = create_argparser().parse_known_args()

    device = setup_dist()
    logger.configure(dir=args.log_dir)

    seed = set_seed(args.seed, max_seed=10000)
    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    torch.backends.cudnn.allow_tf32 = args.use_tf32
    torch.backends.cuda.matmul.allow_tf32 = args.use_tf32

    logger.log("creating model and diffusion...")
    try:
        with open(find_env_vars(args.model_path)) as f:
            env_vars = json.load(f)
            for k, v in env_vars.items():
                os.environ[k] = str(v)
    except FileNotFoundError:
        # No stored environment for this checkpoint: record the current one.
        save_env_vars(os.path.join(logger.get_dir(), "env.json"))
    with open(find_config(args.model_path)) as f:
        config = json.load(f)
    config = update_config(config, args, unknown_args)
    pprint(config, sort_dicts=False)
    model_config = config["model"]
    diffusion_config = config["diffusion"]
    model, diffusion = create_model_and_diffusion(**model_config, **diffusion_config)
    model = load_checkpoint(model, args.model_path)
    model.to(device)
    if model_config["use_fp16"]:
        model.convert_to_fp16()
    model.eval()

    logger.log("sampling...")
    n, rank, world_size = args.start_idx, dist.get_rank(), dist.get_world_size()
    class_cond = model_config["class_cond"]

    # Tag for the PNG directory: the text between the last "_" and the first
    # following "." of the checkpoint *file name*. Parsing the basename (not
    # the whole path) keeps "_" / "." in directory names from leaking path
    # separators or a wrong/empty tag into the directory name.
    checkpoint_file = os.path.basename(str(args.model_path))
    model_iter = checkpoint_file.split("_")[-1].split(".")[0]

    # Loop-invariant sampling setup.
    use_ddim = diffusion_config["timestep_respacing"].startswith("ddim")
    sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
    image_size = model_config["image_size"]
    sample_shape = (args.batch_size, 3, image_size, image_size)

    png_pattern = None
    if args.save_png:
        png_dir = os.path.join(logger.get_dir(), "png" + model_iter)
        png_pattern = os.path.join(png_dir, "sample_{:05d}.png")
        os.makedirs(png_dir, exist_ok=True)

    all_images = []
    all_labels = []
    # Every iteration appends `world_size` batches of `batch_size` images.
    while len(all_images) * args.batch_size < args.num_samples:
        model_kwargs = {}
        if class_cond:
            classes = torch.randint(
                low=0, high=NUM_CLASSES, size=(args.batch_size,), device=device
            )
            model_kwargs["y"] = classes
        sample = sample_fn(
            model,
            sample_shape,
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
        )
        # Rescale to 8-bit pixel values (assumes the sampler output lies in
        # [-1, 1] -- TODO confirm) and move channels last for PIL / NumPy.
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        if args.save_png:
            # Each rank owns a contiguous slice of the global index range
            # covered by this iteration.
            first_idx = n + rank * args.batch_size
            for offset, image in enumerate(sample.cpu().numpy()):
                idx = first_idx + offset
                # NOTE(review): the bound ignores args.start_idx, so a
                # non-zero start index reduces the number of PNGs written --
                # confirm whether that is intended.
                if idx >= args.num_samples:
                    break
                Image.fromarray(image, "RGB").save(png_pattern.format(idx))
            n += args.batch_size * world_size
        dist.barrier()

        gathered_samples = [torch.zeros_like(sample) for _ in range(world_size)]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend([gathered.cpu().numpy() for gathered in gathered_samples])
        if class_cond:
            gathered_labels = [torch.zeros_like(classes) for _ in range(world_size)]
            dist.all_gather(gathered_labels, classes)
            all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    # Drop the surplus produced by the last (possibly overshooting) iteration.
    arr = np.concatenate(all_images, axis=0)
    arr = arr[: args.num_samples]
    if class_cond:
        label_arr = np.concatenate(all_labels, axis=0)
        label_arr = label_arr[: args.num_samples]
    if rank == 0:
        model_name = os.path.splitext(os.path.basename(args.model_path))[0]
        n_str = f"n={int2str(args.num_samples, 'k')}"
        if diffusion_config["timestep_respacing"] != "":
            t_str = f"t={diffusion_config['timestep_respacing']}"
        else:
            t_str = f"t={diffusion_config['diffusion_steps']}"
        sample_file = f"{model_name}_{n_str}_{t_str}_seed={seed:04d}.npz"

        out_path = os.path.join(logger.get_dir(), sample_file)
        logger.log(f"saving to {out_path}")
        if class_cond:
            np.savez(out_path, arr, label_arr)
        else:
            np.savez(out_path, arr)

    # Keep the other ranks alive until rank 0 has finished writing.
    dist.barrier()
    logger.log("sampling complete")


def load_checkpoint(model, checkpoint_path, device="cpu"):
    """Load the weights stored at ``checkpoint_path`` into ``model``.

    The state dict is read with ``torch.load`` (mapped to ``device``) and
    applied non-strictly, so key mismatches do not raise; any missing or
    unexpected keys are written to the log instead.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        checkpoint_path: Path of the serialized state dict.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The same ``model`` object, with the weights loaded.
    """
    weights = torch.load(checkpoint_path, map_location=device)
    missing, unexpected = model.load_state_dict(weights, strict=False)
    if missing:
        logger.log(f"Missing keys when loading pretrained model: {missing}")
    if unexpected:
        logger.log(f"Unexpected keys when loading pretrained model: {unexpected}")
    return model


def create_argparser():
    """Build the command-line parser for the sampling script.

    Every entry of the defaults table below is handed to
    ``add_dict_to_argparser`` together with a fresh ``ArgumentParser``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "use_ddim": False,
        "model_path": "",
        "log_dir": "./samples",
        "cudnn_benchmark": True,
        "use_tf32": True,
        "save_png": False,
        "seed": None,
        "start_idx": 0,
    }
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


# Run the sampling entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
