"""
Like image_sample.py, but use a universal (clean) image classifier to guide the sampling
process towards more realistic images.
"""

import argparse
import os

from time import time, localtime, strftime
import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from functools import partial
from torchvision.utils import make_grid, save_image

from guided_diffusion.transfer_learning import create_pretrained_model
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    classifier_defaults,
    create_model_and_diffusion,
    create_classifier,
    add_dict_to_argparser,
    args_to_dict,
    pretrained_ImageNet_defaults,
)



class OptimizerDetails:
    """Plain settings container describing one guided-sampling "operation".

    ``main`` fills an instance in and hands it to the diffusion sampler through
    its ``operation=`` argument. How each field is interpreted is decided by
    that sampler, which is not defined in this file, so the notes below only
    record what this script stores in each field.
    """

    def __init__(self):
        self.num_recurrences = None  # set from --num_recurrences
        self.operation_func = None  # set to the reward model
        self.optimizer = None # handle it on string level
        self.lr = None  # set from --optim_lr
        self.loss_func = None  # callable taking two arguments, returning the loss
        self.backward_steps = 0  # set from --backward_steps
        self.loss_cutoff = None  # set from --optim_loss_cutoff
        self.lr_scheduler = None  # not set by this script
        self.warm_start = None  # set from --optim_warm_start
        self.old_img = None  # not set by this script
        self.fact = 0.5  # not set by this script; meaning defined by the sampler
        self.print = False  # set from --optim_print
        self.print_every = None  # this script uses 10
        self.folder = None  # output directory (this script uses the logger directory)
        self.tv_loss = None  # set from --optim_tv_loss
        self.use_forward = False  # set from --use_forward
        self.forward_guidance_wt = 0  # set from --forward_guidance_wt
        self.other_guidance_func = None  # left as None by this script
        self.other_criterion = None  # left as None by this script
        self.original_guidance = False  # set from --optim_original_guidance
        self.sampling_type = None  # set from --sampling_type
        self.loss_save = None  # not set by this script
        # Assigned by ``main`` from --optim_aug; declared here so that reading
        # ``.Aug`` on a freshly constructed instance does not raise AttributeError.
        self.Aug = None


def _build_operation(args, reward, log_sigmoid):
    """Fill in an ``OptimizerDetails`` from the parsed command-line arguments.

    Args:
        args: Namespace produced by ``create_argparser().parse_args()``.
        reward: The reward model; stored as ``operation_func``.
        log_sigmoid: Callable applied to the first loss argument to build the
            guidance loss.

    Returns:
        The populated ``OptimizerDetails`` instance.
    """
    operation = OptimizerDetails()
    operation.num_recurrences = args.num_recurrences
    operation.operation_func = reward
    operation.other_guidance_func = None

    operation.optimizer = 'Adam'
    operation.lr = args.optim_lr
    # The loss is the negative log-sigmoid of the first argument; the second
    # argument is accepted but ignored.
    operation.loss_func = lambda x, y: -log_sigmoid(x)
    operation.other_criterion = None

    operation.backward_steps = args.backward_steps
    operation.loss_cutoff = args.optim_loss_cutoff
    operation.tv_loss = args.optim_tv_loss

    operation.use_forward = args.use_forward
    operation.forward_guidance_wt = args.forward_guidance_wt

    operation.original_guidance = args.optim_original_guidance
    operation.sampling_type = args.sampling_type

    operation.warm_start = args.optim_warm_start
    operation.print = args.optim_print
    operation.print_every = 10
    operation.folder = logger.get_dir()
    if args.optim_print:
        os.makedirs(f'{operation.folder}/samples', exist_ok=True)
    operation.Aug = args.optim_aug
    return operation


def _save_sample_grids(arr, save_dir, grid_size=100, nrow=10):
    """Write the samples in ``arr`` to ``save_dir`` as PNG image grids.

    ``arr`` is the uint8 array assembled by ``main`` (channels last). It is
    rescaled to floats in [0, 1], moved to channels-first layout, and written in
    chunks of ``grid_size`` images as ``batch_{i}.png``. The last chunk may hold
    fewer than ``grid_size`` images, so no sample is dropped.
    """
    sample_tensor = th.tensor(np.transpose(arr.astype(np.float32) / 255., (0, 3, 1, 2)))
    for i, first in enumerate(range(0, sample_tensor.size(0), grid_size)):
        img_grid = make_grid(sample_tensor[first:first + grid_size], nrow=nrow)
        save_image(img_grid, os.path.join(save_dir, f"batch_{i}.png"))


def main():
    """Sample images from a diffusion model guided by a reward model.

    Loads the diffusion model and the reward model, runs the guided sampling
    loop until ``--num_samples`` images exist, gathers them across all
    distributed ranks, and on rank 0 saves them as an ``.npz`` archive and as
    PNG grids.
    """
    # 24-hour clock (%H): a 12-hour %I without AM/PM gives ambiguous timestamps.
    start = strftime("%m%d_%H:%M:%S", localtime(time()))
    args = create_argparser().parse_args()

    dist_util.setup_dist()

    log_dir = os.path.join(args.log_dir, args.model_name + f"_{args.temp_count}")

    logger.configure(log_dir)
    # Text after the last '_' in the reward checkpoint's file name; used to tag
    # the output files.
    reward_tag = args.reward_path.split('/')[-1].split('_')[-1]

    logger.log(vars(args))
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading reward...")

    reward = create_pretrained_model(
        **args_to_dict(args, pretrained_ImageNet_defaults().keys())
    )
    # Checkpoint loading is skipped when the model name contains "ensemble"
    # (presumably the factory loads ensemble weights itself -- TODO confirm).
    if "ensemble" not in args.model_name:
        reward.load_state_dict(
            dist_util.load_state_dict(args.reward_path, map_location="cpu")
        )
    reward.to(dist_util.dev())
    reward.eval()

    log_sigmoid = th.nn.LogSigmoid()

    def model_fn(x, t, y=None, args=None, model=None):
        # Pass the class labels on only for class-conditional models.
        return model(x, t, y if args.class_cond else None)

    operation = _build_operation(args, reward, log_sigmoid)

    logger.log("sampling... ")
    all_images = []
    sample_shape = (args.batch_size, args.image_channels, args.image_size, args.image_size)
    # Each iteration appends one batch per rank, so this counts images across all ranks.
    while len(all_images) * args.batch_size < args.num_samples:
        # See https://github.com/arpitbansal297/Universal-Guided-Diffusion/blob/b3af48f78d7bec105f3ea1579faf8602c520ed1e/Guided_Diffusion_Imagenet/Guided/helpers.py#L245
        if args.target_class is None:
            classes = th.randint(
                low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
            )
        else:
            classes = th.full(
                (args.batch_size,),
                int(args.target_class),
                dtype=th.int64,
                device=dist_util.dev(),
            )
        model_kwargs = {"y": classes}

        sample = diffusion.ddim_sample_loop_operation(
            partial(model_fn, model=model, args=args),
            sample_shape,
            operated_image=None,
            operation=operation,
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=None,  # guidance comes from `operation`, not from a cond_fn
            device=dist_util.dev(),
            progress=args.progressive,
        )
        # Rescale from [-1, 1] to uint8 [0, 255] and move channels last.
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend([gathered.cpu().numpy() for gathered in gathered_samples])
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    hparam_tag = f"[bwd{args.backward_steps}_lr_{args.optim_lr}_fwd_wt_{args.forward_guidance_wt}]"
    arr = np.concatenate(all_images, axis=0)
    arr = arr[: args.num_samples]
    if dist.get_rank() == 0:
        shape_str = "x".join([str(x) for x in arr.shape])
        out_path = os.path.join(logger.get_dir(), f"samples_{reward_tag}_{hparam_tag}_{shape_str}.npz")
        logger.log(f"saving to {out_path}")
        np.savez(out_path, arr)

    logger.log("sampling complete")

    sample_save_dir = os.path.join(args.log_dir, f"sample_imgs_{reward_tag}_{hparam_tag}")
    if dist.get_rank() == 0:
        os.makedirs(sample_save_dir, exist_ok=True)

    dist.barrier()
    end = strftime("%m%d_%H:%M:%S", localtime(time()))
    logger.log(f"start : {start}, end : {end}")
    if dist.get_rank() == 0:
        # Only rank 0 writes the grids, so only rank 0 builds the float tensor.
        _save_sample_grids(arr, sample_save_dir)


def create_argparser():
    """Build the command-line parser for this sampling script.

    The script-specific defaults are listed first; the model/diffusion defaults
    and then the pretrained-model defaults are layered on top, so a later group
    overrides an earlier one when both define the same key.

    Returns:
        An ``argparse.ArgumentParser`` with one option per default entry.
    """
    script_defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "model_path": "",
        "reward_path": "",
        "image_channels": 1,
        "target_class": None,
        "log_dir": "",
        # --- guidance operation settings ---
        "num_recurrences": 1,
        "optim_lr": 0.01,
        # Multi-GPU sampling is supported only for 0; any other value
        # requires single-GPU sampling.
        "backward_steps": 0,
        "optim_loss_cutoff": 0.0,
        "optim_tv_loss": False,
        "use_forward": True,
        "optim_original_guidance": False,
        "forward_guidance_wt": 1.0,
        "sampling_type": 'ddpm',
        "optim_warm_start": False,
        "optim_print": False,  # save samples
        "progressive": False,  # print tqdm
        "optim_aug": None,
        "temp_count": 0,
    }
    merged_defaults = {
        **script_defaults,
        **model_and_diffusion_defaults(),
        **pretrained_ImageNet_defaults(),
    }
    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged_defaults)
    return arg_parser


# Run the sampling pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()