"""
Generate a large batch of samples from a super resolution model, given a batch
of samples from a regular model from image_sample.py.
"""

import argparse
import os

import blobfile as bf
import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F

import matplotlib.pyplot as plt
plt.ion()

from improved_diffusion import dist_util, logger
from improved_diffusion.image_datasets import load_pair_data, load_paired_mat_data, load_paired_npy_data
from improved_diffusion.script_util import (
    sr_model_and_diffusion_defaults,
    sr_create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)


def _to_uint8_images(images):
    """Convert a float NCHW tensor to a contiguous uint8 channels-last tensor.

    Values are mapped with ``(x + 1) * 127.5`` and clamped to ``[0, 255]``
    (assumes the input is scaled to ``[-1, 1]`` -- TODO confirm against the
    data loader).

    Only a trailing singleton channel axis is squeezed. A bare ``squeeze()``
    would also drop the batch axis when the batch holds a single image, which
    breaks the later concatenation along axis 0.
    """
    images = ((images + 1) * 127.5).clamp(0, 255).to(th.uint8)
    return images.permute(0, 2, 3, 1).squeeze(-1).contiguous()


def _gather_as_numpy(tensor):
    """All-gather ``tensor`` from every rank; return one NumPy array per rank."""
    gathered = [th.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor)  # gather not supported with NCCL
    return [t.cpu().numpy() for t in gathered]


def main():
    """Sample from a super-resolution model and save the results.

    Steps performed:
      1. Parse command-line arguments and create ``args.save_dir``.
      2. Build the model/diffusion pair, load weights from ``args.model_path``
         and switch the model to eval mode.
      3. Repeatedly draw ``(batch, cond, model_kwargs)`` from the data
         generator and run ``diffusion.p_sample_loop`` conditioned on ``cond``
         until at least ``args.num_samples`` samples were produced.
      4. On rank 0, write the outputs to ``samples_<shape>.npz`` and save each
         input / output / target image as ``<i>_LR.png`` / ``<i>_SR.png`` /
         ``<i>_HR.png`` inside ``args.save_dir``.

    Inputs, outputs and targets are all gathered across ranks in the same
    order, so index ``i`` refers to the same example in all three arrays
    regardless of the world size.
    """
    args = create_argparser().parse_args()
    os.makedirs(args.save_dir, exist_ok=True)

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model...")
    model, diffusion = sr_create_model_and_diffusion(
        **args_to_dict(args, sr_model_and_diffusion_defaults().keys()),
        in_channels=3, out_channels=3
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    model.eval()

    logger.log("loading data...")
    data = load_lowres_data(args.input_dir, args.target_dir, args.batch_size, args.large_size, args.small_size, args.class_cond)

    logger.log("creating samples...")
    all_inputs, all_outputs, all_targets = [], [], []
    while len(all_outputs) * args.batch_size < args.num_samples:
        batch, cond, model_kwargs = next(data)
        # Move everything to the sampling device so that inputs and targets
        # can be all-gathered exactly like the generated samples.
        batch = batch.to(dist_util.dev())
        cond = cond.to(dist_util.dev())
        model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
        sample = diffusion.p_sample_loop(
            model,
            cond,
            (args.batch_size, 3, args.large_size, args.large_size),
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
        )

        # Gather the three image sets identically to keep them index-aligned.
        all_inputs.extend(_gather_as_numpy(_to_uint8_images(cond)))
        all_targets.extend(_gather_as_numpy(_to_uint8_images(batch)))
        all_outputs.extend(_gather_as_numpy(_to_uint8_images(sample)))
        logger.log(f"created {len(all_outputs) * args.batch_size} samples")

    inp = np.concatenate(all_inputs, axis=0)[: args.num_samples]
    out = np.concatenate(all_outputs, axis=0)[: args.num_samples]
    tag = np.concatenate(all_targets, axis=0)[: args.num_samples]
    if dist.get_rank() == 0:
        shape_str = "x".join([str(x) for x in out.shape])
        out_path = os.path.join(args.save_dir, f"samples_{shape_str}.npz")
        logger.log(f"saving to {out_path}")
        np.savez(out_path, out)
        # Iterate over what was actually collected instead of assuming that
        # exactly ``num_samples`` images exist in every array.
        for i, (lr_img, sr_img, hr_img) in enumerate(zip(inp, out, tag)):
            plt.imsave(os.path.join(args.save_dir, '%d_LR.png'%i), lr_img)
            plt.imsave(os.path.join(args.save_dir, '%d_SR.png'%i), sr_img)
            plt.imsave(os.path.join(args.save_dir, '%d_HR.png'%i), hr_img)

    # Keep the other ranks alive until rank 0 has finished writing.
    dist.barrier()
    logger.log("sampling complete")


def load_lowres_data(input_dir, target_dir, batch_size, large_size, small_size, class_cond=False):
    """Yield items from the paired-image loader for the given directories.

    Thin generator wrapper around ``load_pair_data``: images are requested at
    ``large_size`` with ``deterministic=True``. ``small_size`` is accepted for
    signature compatibility but is not used here.
    """
    loader_options = {
        "input_dir": input_dir,
        "target_dir": target_dir,
        "batch_size": batch_size,
        "image_size": large_size,
        "class_cond": class_cond,
        "deterministic": True,
    }
    yield from load_pair_data(**loader_options)

def create_argparser():
    """Build the command-line parser with this script's default options.

    The script-specific defaults are merged with (and, on key collisions,
    overridden by) ``sr_model_and_diffusion_defaults()`` before being
    registered on the parser.

    NOTE: the ``save_dir`` and ``model_path`` defaults are derived from the
    default ``model_dir`` / ``model_name`` / ``ckpt`` values below. Overriding
    only ``--model_name`` on the command line does not change them; pass
    ``--model_path`` / ``--save_dir`` explicitly or edit the values here.
    """
    model_dir = "models"
    model_name = "BBDM-20231109-2245"  # DESIGNATE THE MODEL NAME TO TEST
    ckpt = "ema_0.9999_050000.pt"  # DESIGNATE THE CHECKPOINT TO TEST

    defaults = {
        "input_dir": "",  # ENTER YOUR INPUT IMAGE DIRECTORY HERE
        "target_dir": "",  # ENTER YOUR TARGET IMAGE DIRECTORY HERE
        "clip_denoised": False,
        "num_samples": 30,  # ADJUST THE TOTAL NUMBER OF SAMPLES HERE
        "batch_size": 15,  # ADJUST BATCH SIZE ACCORDING TO YOUR HARDWARE
        "use_ddim": False,
        "model_dir": model_dir,
        "model_name": model_name,
        "ckpt": ckpt,
        # Outputs are grouped per model under the base "outputs" directory.
        "save_dir": os.path.join("outputs", model_name),
        "model_path": os.path.join(model_dir, model_name, ckpt),
    }
    defaults.update(sr_model_and_diffusion_defaults())

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
