"""
Like image_sample.py, but use a noisy image classifier to guide the sampling
process towards more realistic images.
"""

import argparse
import os

import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    classifier_defaults,
    create_model_and_diffusion,
    create_classifier,
    create_imagenet_classifier,
    add_dict_to_argparser,
    args_to_dict,
)

# NOTE: `torch` is the same module that is imported as `th` near the top of
# the file; both aliases are used below.
import torch
from torch import nn
# Process-wide PyTorch settings, applied as a side effect of importing this
# module:
#   * cudnn.benchmark lets cuDNN time the available convolution algorithms and
#     reuse the fastest one (per the PyTorch docs).
#   * 'medium' float32 matmul precision permits reduced-precision internal
#     computation for float32 matrix multiplies, trading accuracy for speed.
torch.backends.cudnn.benchmark=True
torch.set_float32_matmul_precision('medium')
# `nn`, `torch.utils.checkpoint` and `make_grid` are not referenced in the
# visible part of this file; only `save_image` is used (for preview grids).
import torch.utils.checkpoint
from torchvision.utils import make_grid, save_image 

def main():
    """
    Sample images from a class-conditional diffusion model under classifier
    guidance and store them, together with their class labels, in an ``.npz``.

    Steps performed:

    1. Build three networks from hard-coded checkpoints under ``workdirs/``:
       an unconditional diffusion model (``model``), a noisy-image classifier
       (``classifier``) and a class-conditional diffusion model
       (``model_cond``). ``model_cond`` is the network that is actually
       sampled from; ``model`` is only evaluated inside the guidance function
       when ``--denoise_augment`` is set.
    2. Repeatedly draw ``batch_size`` random class labels and run the sampling
       loop (DDIM if ``--use_ddim``) with ``cond_fn`` as guidance.
    3. Periodically write a PNG preview (``<filename>.png``) and a resumable
       checkpoint (``<filename>``); if that checkpoint already exists at
       start-up, sampling resumes from it.
    4. On rank 0, save the first ``num_samples`` images and labels to
       ``<filename>.npz``.

    ``--model_path`` and ``--classifier_path`` are accepted by the argument
    parser but are not read here; the checkpoint paths are fixed in the code.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    # Load onto the CPU first so the checkpoint does not have to be restored to
    # whichever device it was saved from; load_state_dict copies the tensors
    # into the module's own parameters.
    model.load_state_dict(
        torch.load('workdirs/256x256_diffusion_uncond.pt', map_location="cpu")
    )
    model.to(dist_util.dev())
    model.convert_to_fp16()
    model.eval()

    logger.log("loading classifier...")
    classifier = create_imagenet_classifier(args.denoise_augment)
    # The denoise-augmented classifier has its own checkpoint.
    if args.denoise_augment:
        classifier_ckpt = 'workdirs/im_clf_DA/model050500.pt'
    else:
        classifier_ckpt = 'workdirs/256x256_classifier.pt'
    classifier.load_state_dict(
        dist_util.load_state_dict(classifier_ckpt, map_location="cpu", logger=logger)
    )
    classifier.to(dist_util.dev())
    classifier.convert_to_fp16()
    classifier.eval()

    # Second model/diffusion pair: class-conditional network plus a diffusion
    # object built with an empty `timestep_respacing`. `diffusion_unspaced` is
    # only used inside cond_fn to look up the variance at timestep `t`.
    cond_kwargs = args_to_dict(args, model_and_diffusion_defaults().keys())
    cond_kwargs['class_cond'] = True
    cond_kwargs['timestep_respacing'] = ''
    model_cond, diffusion_unspaced = create_model_and_diffusion(**cond_kwargs)
    model_cond.to(dist_util.dev())
    model_cond.load_state_dict(
        torch.load('workdirs/256x256_diffusion.pt', map_location="cpu")
    )
    model_cond = model_cond.eval()
    model_cond.convert_to_fp16()

    def cond_fn(x, t, y=None):
        """
        Return ``classifier_scale * d/dx log p(y | x, t)`` for guidance.

        With ``--denoise_augment`` the classifier input is ``x`` concatenated
        (along dim 1) with ``x - std * model(x, t)[:, :3]``, and the gradient
        is taken through that whole expression, including the unconditional
        model.
        """
        assert y is not None
        with th.enable_grad():
            # Second element of the returned triple is used as the variance at
            # timestep t. NOTE(review): assumes `t` indexes the unspaced
            # schedule -- confirm against the diffusion implementation.
            _, variance, _ = diffusion_unspaced.q_mean_variance(x, t)

            x_in = x.detach().requires_grad_(True)
            if args.denoise_augment:
                # First three output channels of the unconditional model;
                # presumably its noise prediction -- TODO confirm.
                score_unc = model(x_in, t)[:, :3]
                std = variance ** 0.5
                denoise_x = x_in - std * score_unc
                inputs = torch.cat([x_in, denoise_x], dim=1)
            else:
                inputs = x_in
            logits = classifier(inputs, t)
            log_probs = F.log_softmax(logits, dim=-1)
            # Pick, for every batch element, the log-probability of its label.
            selected = log_probs[range(len(logits)), y.view(-1)]
            return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

    def model_fn(x, t, y=None):
        """Evaluate the class-conditional model; a label is mandatory."""
        assert y is not None
        return model_cond(x, t, y)

    logger.log("sampling...")
    filename = f'{logger.get_dir()}/{args.timestep_respacing}_{args.denoise_augment}_{args.classifier_scale}_{args.ID}'
    if os.path.exists(filename):
        # Resume from a previous run's checkpoint (lists of numpy arrays).
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may refuse these pickled numpy arrays --
        # confirm against the installed version.
        all_images, all_labels = torch.load(filename)
    else:
        all_images, all_labels = [], []
    # Rolling buffer of recent samples in [0, 1], used only for the preview.
    all_samples = None

    # Loop invariants: sampler choice and output shape.
    sample_fn = diffusion.ddim_sample_loop if args.use_ddim else diffusion.p_sample_loop
    sample_shape = (args.batch_size, 3, args.image_size, args.image_size)

    # `all_images` holds one array per batch, hence the multiplication.
    while len(all_images) * args.batch_size < args.num_samples:
        classes = th.randint(
            low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
        )
        model_kwargs = {"y": classes}
        sample = sample_fn(
            model_fn,
            sample_shape,
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            device=dist_util.dev(),
            progress=True,
        )

        # Map [-1, 1] -> [0, 1] for the preview grid.
        preview = 0.5 * (sample + 1).detach().cpu()
        if all_samples is None:
            all_samples = preview
        else:
            all_samples = torch.cat([all_samples, preview])
        if len(all_images) % 5 == 0:
            save_image(all_samples, f"{filename}.png")
        # Keep the preview buffer bounded to the 100 most recent images.
        if all_samples.shape[0] > 100:
            all_samples = all_samples[-100:]

        # Map [-1, 1] -> uint8 [0, 255] and move channels last (NCHW -> NHWC).
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        all_images.append(sample.cpu().numpy())
        # One length-1 array per label, so the final concatenate yields 1-D.
        all_labels.extend(classes.cpu().numpy().reshape(-1, 1))
        if len(all_images) % 15 == 0:
            # Write to a temporary file and rename it into place so that an
            # interruption mid-write cannot corrupt the resume checkpoint.
            tmp_path = f"{filename}.tmp{os.getpid()}"
            torch.save((all_images, all_labels), tmp_path)
            os.replace(tmp_path, filename)
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    arr = np.concatenate(all_images, axis=0)[: args.num_samples]
    label_arr = np.concatenate(all_labels, axis=0)[: args.num_samples]
    if dist.get_rank() == 0:
        out_path = f"{filename}.npz"
        logger.log(f"saving to {out_path}")
        np.savez(out_path, arr, label_arr)

    dist.barrier()
    logger.log("sampling complete")


def create_argparser():
    """
    Build the command-line parser for this script.

    The script-specific options below are merged with the dictionaries
    returned by ``model_and_diffusion_defaults()`` and
    ``classifier_defaults()`` (later sources override earlier ones on a key
    clash), and the merged mapping is handed to ``add_dict_to_argparser``.
    """
    script_options = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "use_ddim": False,
        "model_path": "",
        "classifier_path": "",
        "denoise_augment": False,
        "classifier_scale": 1.0,
        "ID": 0,
    }
    # Same precedence and key order as successive dict.update() calls.
    merged = {
        **script_options,
        **model_and_diffusion_defaults(),
        **classifier_defaults(),
    }
    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged)
    return arg_parser


# Run the sampling pipeline only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()