"""
Generate PAG dataset
"""

import argparse
import torch
import torch as th

from improved_diffusion import dist_util, logger
from improved_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)

from image_train import load_data


def _load_model(args, checkpoint_path):
    """Build a model/diffusion pair from ``args`` and load ``checkpoint_path`` into it.

    The state dict is read onto the CPU first; the model is then moved to
    ``dist_util.dev()`` and switched to eval mode.

    Returns:
        ``(model, diffusion)`` as produced by ``create_model_and_diffusion``.
    """
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(checkpoint_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    model.eval()
    return model, diffusion


def _score_differences(x0, diffusion, cond_model, uncond_model, t_step, num_classes, null_label):
    """Compute conditional-minus-null-label score estimates for one batch.

    For every class label ``0 .. num_classes - 1``:
    1. The *clean* batch ``x0`` is noised to ``t_step`` with ``diffusion.q_sample``,
       with a fresh draw per class.
    2. Both models are evaluated on that noisy batch.
    3. The difference of their negated first three output channels is recorded.

    Args:
        x0: Clean image batch. Assumed ``(B, C, H, W)`` on the model device
            (TODO confirm against ``load_data``).
        diffusion: Object providing ``q_sample`` and ``_scale_timesteps``.
        cond_model: Model queried with label ``cls``.
        uncond_model: Model queried with ``null_label``.
        t_step: Integer diffusion timestep applied to every sample.
        num_classes: Number of class labels to iterate over.
        null_label: Label passed to ``uncond_model``.

    Returns:
        CPU tensor of shape ``(num_classes, B, 3, H, W)``.
    """
    batch_size = x0.shape[0]
    device = x0.device
    t = th.full((batch_size,), t_step, dtype=th.long, device=device)
    null_y = th.full((batch_size,), null_label, dtype=th.long, device=device)

    diffs = []
    for cls in range(num_classes):
        cls_y = th.full((batch_size,), cls, dtype=th.long, device=device)
        # Always noise the clean batch. Re-noising the previous class's noisy
        # batch would stack noise across classes.
        x_t = diffusion.q_sample(x0, t)
        uncond_output = uncond_model(x_t, diffusion._scale_timesteps(t), y=null_y)[:, :3, :, :]
        cond_output = cond_model(x_t, diffusion._scale_timesteps(t), y=cls_y)[:, :3, :, :]
        # Negated model outputs are treated as unnormalized score estimates.
        uncond_score_unnormalized = -1 * uncond_output
        cond_score_unnormalized = -1 * cond_output
        diffs.append((cond_score_unnormalized - uncond_score_unnormalized).cpu())
    return th.stack(diffs)


@torch.no_grad()
def main():
    """Generate and save the STL PAG tensors at a fixed diffusion timestep.

    Each saved sample has shape ``(1 + num_classes, C, H, W)``:
    - index 0 is the clean image;
    - indices ``1..num_classes`` are the per-class differences from
      ``_score_differences``.

    Samples and their labels are written to
    ``pag_gt/stl_data_tensor_gt_SM_{t_step}.pt`` and
    ``pag_gt/stl_label_tensor_gt_SM_{t_step}.pt``.

    Raises:
        ValueError: if ``--batch_size`` yields no full batch.
    """
    import os

    t_step = 550
    num_classes = 10
    # NOTE(review): presumably the checkpoint reserves this index as its
    # null/"unconditional" label — confirm against the training setup.
    null_label = 11
    # NOTE(review): presumably the number of images in stl_train. Any partial
    # final batch is dropped by the floor division below.
    num_images = 5000
    default_checkpoint = 'openai-2022-07-01-12-10-44-324708/model1880000.pt'  # trained model

    args = create_argparser().parse_args()
    if not 0 < args.batch_size <= num_images:
        raise ValueError(
            f"batch_size must be in [1, {num_images}], got {args.batch_size}"
        )
    num_batches = num_images // args.batch_size
    # An explicit --model_path overrides the historical hard-coded checkpoint.
    checkpoint = args.model_path or default_checkpoint

    dist_util.setup_dist()
    logger.configure()
    # Create the output directory up front so saving cannot fail after all the compute.
    os.makedirs("pag_gt", exist_ok=True)

    logger.log("creating conditional model and diffusion...")
    cond_model, diffusion = _load_model(args, checkpoint)

    logger.log("creating unconditional model and diffusion...")
    # The null-label model is built class-conditional so it accepts `y`; the
    # "unconditional" prediction comes from passing `null_label`.
    args.class_cond = True
    uncond_model, diffusion = _load_model(args, checkpoint)

    logger.log("sampling...")
    all_images = []
    all_labels = []
    data_iter = load_data(
        data_dir='stl_train',
        batch_size=args.batch_size,
        image_size=64,
        class_cond=True,
    )
    for _ in range(num_batches):
        x, y = next(data_iter)
        # Use the same device as the label tensors (dist_util.dev()), not the
        # current default CUDA device.
        x0 = x.to(dist_util.dev())
        pag = _score_differences(
            x0, diffusion, cond_model, uncond_model, t_step, num_classes, null_label
        )
        # (1 + num_classes, B, ...) -> (B, 1 + num_classes, ...)
        sample = th.cat([x0.unsqueeze(0).cpu(), pag], dim=0).permute(1, 0, 2, 3, 4)
        all_images.append(sample)
        all_labels.append(y['y'])
        logger.log(f"created {len(all_images) * args.batch_size} samples")
    arr = th.cat(all_images, dim=0)
    label_arr = th.cat(all_labels)
    th.save(arr, f'pag_gt/stl_data_tensor_gt_SM_{t_step}.pt')
    th.save(label_arr, f'pag_gt/stl_label_tensor_gt_SM_{t_step}.pt')
    logger.log("sampling complete")


def create_argparser():
    """Build the command-line parser for this script.

    Script-level options are merged with ``model_and_diffusion_defaults()``.
    On a key clash the model/diffusion default wins. Every entry is
    registered on the parser through ``add_dict_to_argparser``.
    """
    options = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "use_ddim": False,
        "model_path": "",
        **model_and_diffusion_defaults(),
    }
    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, options)
    return arg_parser


if __name__ == "__main__":
    main()
