"""
Like image_sample.py, but use a universal(clean) image classifier to guide the sampling
process towards more realistic images.
"""

import argparse
import os
from time import time, localtime, strftime

import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from functools import partial
from torchvision.utils import make_grid, save_image

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    classifier_defaults,
    create_model_and_diffusion,
    create_classifier,
    add_dict_to_argparser,
    args_to_dict,
)


class OptimizerDetails:
    """Mutable bag of guidance/optimisation settings.

    ``main()`` fills an instance from the command-line arguments and passes it
    to ``diffusion.ddim_sample_loop_operation`` as ``operation``. How each
    field is consumed is decided by that sampler, which lives outside this
    file; the comments below only record what this file assigns.
    """

    def __init__(self):
        # --- guidance model and loss -------------------------------------
        self.num_recurrences = None  # set from args.num_recurrences
        self.operation_func = None  # set to the reward network in main()
        self.optimizer = None # handle it on string level
        self.lr = None  # set from args.optim_lr
        self.loss_func = None  # callable (x, y) -> loss; main() uses -log_sigmoid(x)
        self.backward_steps = 0  # set from args.backward_steps
        self.loss_cutoff = None  # set from args.optim_loss_cutoff
        self.lr_scheduler = None  # not assigned by main()
        self.warm_start = None  # set from args.optim_warm_start
        self.old_img = None  # not assigned by main()
        self.fact = 0.5  # not assigned by main(); meaning defined by the sampler
        # --- debugging / output -------------------------------------------
        self.print = False  # set from args.optim_print
        self.print_every = None  # main() sets this to 10
        self.folder = None  # main() sets this to logger.get_dir()
        self.tv_loss = None  # set from args.optim_tv_loss
        # --- guidance switches ----------------------------------------------
        self.use_forward = False  # set from args.use_forward
        self.forward_guidance_wt = 0  # set from args.forward_guidance_wt
        self.other_guidance_func = None  # main() leaves this as None
        self.other_criterion = None  # main() leaves this as None
        self.original_guidance = False  # set from args.original_guidance
        self.sampling_type = None  # set from args.sampling_type
        self.loss_save = None  # not assigned by main()
        # main() assigns ``operation.Aug``; declare it here so every instance
        # carries the attribute even when it is not configured externally.
        self.Aug = None  # set from args.optim_aug


def main():
    """Sample images from a diffusion model with reward (and optional class) guidance.

    Steps, in order:

    1. Parse CLI arguments, set up distributed state and logging.
    2. Load the diffusion model and the single-output reward network.
    3. If a target class is given and ``classifier_scale`` is non-negligible,
       also load a classifier for class-oriented guidance.
    4. Configure an ``OptimizerDetails`` instance and repeatedly call
       ``diffusion.ddim_sample_loop_operation`` until ``num_samples`` images
       have been gathered across all ranks.
    5. On rank 0, save the samples as an ``.npz`` array and as PNG grids of up
       to 100 images each (10 per row).
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure(args.log_dir)
    time_tag = strftime("%m%d_%I:%M:%S", localtime(time()))

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading reward...")
    reward_args = args_to_dict(args, classifier_defaults().keys())
    reward_args["output_dim"] = 1
    reward = create_classifier(**reward_args)
    reward.load_state_dict(
        dist_util.load_state_dict(args.reward_path, map_location="cpu")
    )
    reward.to(dist_util.dev())
    if args.classifier_use_fp16:
        reward.convert_to_fp16()
    reward.eval()

    # Single source of truth for "is class-oriented guidance active?". The
    # classifier is only loaded under this condition, so cond_fn must test the
    # very same flag; otherwise it would reference an unbound ``classifier``.
    use_classifier = args.target_class is not None and args.classifier_scale > 1e-6
    classifier = None
    if use_classifier:
        logger.log("loading classifier for class-oriented guidance...")
        classifier_args = args_to_dict(args, classifier_defaults().keys())
        classifier_args["output_dim"] = args.classifier_output_dim
        classifier = create_classifier(**classifier_args)
        classifier.load_state_dict(
            dist_util.load_state_dict(args.classifier_path, map_location="cpu")
        )
        classifier.to(dist_util.dev())
        if args.classifier_use_fp16:
            classifier.convert_to_fp16()
        classifier.eval()
    elif args.classifier_scale > 1e-6:
        logger.log("no target_class given; class-oriented guidance is disabled")

    log_sigmoid = th.nn.LogSigmoid()

    def cond_fn(x, t, y=None):
        """Return the gradient w.r.t. ``x`` of the weighted guidance objective."""
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            out = log_sigmoid(reward(x_in, t)).flatten() * args.original_guidance_wt

            if use_classifier:
                logits = classifier(x_in, t)
                log_probs = F.log_softmax(logits, dim=-1)
                # Log-probability of each sample's own label.
                selected = log_probs[range(len(logits)), y.view(-1)]
                out = out + selected * args.classifier_scale

            return th.autograd.grad(out.sum(), x_in)[0]

    def model_fn(x, t, y=None, args=None, model=None):
        """Call ``model``, forwarding the label only for class-conditional models."""
        return model(x, t, y if args.class_cond else None)

    ##### operation #####
    operation = OptimizerDetails()
    operation.num_recurrences = args.num_recurrences
    operation.operation_func = reward
    operation.other_guidance_func = None

    operation.optimizer = 'Adam'
    operation.lr = args.optim_lr
    operation.loss_func = lambda x, y: -log_sigmoid(x)
    operation.other_criterion = None

    operation.backward_steps = args.backward_steps
    operation.loss_cutoff = args.optim_loss_cutoff # 0.00001
    operation.tv_loss = args.optim_tv_loss

    operation.use_forward = args.use_forward
    operation.forward_guidance_wt = args.forward_guidance_wt

    operation.original_guidance = args.original_guidance
    operation.sampling_type = args.sampling_type

    operation.warm_start = args.optim_warm_start #False
    operation.print = args.optim_print
    operation.print_every = 10
    operation.folder = logger.get_dir() # results_folder
    if args.optim_print:
        os.makedirs(f'{operation.folder}/samples', exist_ok=True)
    operation.Aug = args.optim_aug
    #####################

    logger.log("sampling... ")
    all_images = []
    # all_labels = []
    while len(all_images) * args.batch_size < args.num_samples:
        # See https://github.com/arpitbansal297/Universal-Guided-Diffusion/blob/b3af48f78d7bec105f3ea1579faf8602c520ed1e/Guided_Diffusion_Imagenet/Guided/helpers.py#L245
        model_kwargs = {}
        if args.target_class is None:
            classes = th.randint(
                low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
            )
        else:
            classes = int(args.target_class) * th.ones(size=(args.batch_size,), device=dist_util.dev(), dtype=th.int64)
        model_kwargs["y"] = classes

        sample_fn = diffusion.ddim_sample_loop_operation
        sample = sample_fn(
            partial(model_fn, model=model, args=args),
            (args.batch_size, args.image_channels, args.image_size, args.image_size), # self.shape,
            operated_image=None,
            operation=operation,
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            # cond_fn=partial(cond_fn, classifier=classifier, args=args),
            device=dist_util.dev(),
            progress=args.progressive
        )
        # Rescale with (x + 1) * 127.5, clamp to [0, 255], and move the
        # channel axis last before gathering.
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
        # gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())]
        # dist.all_gather(gathered_labels, classes)
        # all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    hparam_tag = f"[bwd{args.backward_steps}_lr_{args.optim_lr}]"
    if args.use_forward:
        hparam_tag = hparam_tag + f"_[fwd_wt_{args.forward_guidance_wt}]"
    if args.original_guidance:
        hparam_tag = hparam_tag + f"_[org_wt_{args.original_guidance_wt}]"

    arr = np.concatenate(all_images, axis=0)
    arr = arr[: args.num_samples]
    # label_arr = np.concatenate(all_labels, axis=0)
    # label_arr = label_arr[: args.num_samples]
    if dist.get_rank() == 0:
        shape_str = "x".join([str(x) for x in arr.shape])
        out_path = os.path.join(logger.get_dir(), f"samples_{hparam_tag}_{time_tag}_{shape_str}.npz")
        logger.log(f"saving to {out_path}")
        np.savez(out_path, arr)
        # np.savez(out_path, arr, label_arr)

    logger.log("sampling complete")

    sample_save_dir = os.path.join(args.log_dir, f"sample_imgs_{hparam_tag}_{time_tag}")
    if dist.get_rank() == 0:
        os.makedirs(sample_save_dir, exist_ok=True)

    dist.barrier()
    if dist.get_rank() == 0:
        # Only rank 0 writes previews, so only rank 0 builds the float tensor.
        sample_tensor = th.tensor(np.transpose(arr.astype(np.float32) / 255., (0, 3, 1, 2)))
        grid_size = 100
        # Ceiling division so a trailing partial batch (or a run with fewer
        # than grid_size samples) still produces a preview grid.
        num_grids = (sample_tensor.size(0) + grid_size - 1) // grid_size
        for i in range(num_grids):
            img_grid = make_grid(sample_tensor[grid_size * i: grid_size * (i + 1)], nrow=10)
            save_image(img_grid, os.path.join(sample_save_dir, f"batch_{i}.png"))


def create_argparser():
    """Build the command-line parser for this script.

    The script-specific defaults below are merged with
    ``model_and_diffusion_defaults()`` and then ``classifier_defaults()``;
    on a key clash the later dictionary wins. Every resulting key is
    registered as an argument through ``add_dict_to_argparser``.
    """
    script_defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "model_path": "",
        "reward_path": "",
        "classifier_path": "",
        "classifier_output_dim": 1000,
        "classifier_scale": 0.5,
        "image_channels": 1,
        "target_class": None,
        "log_dir": "",
        # operation
        "num_recurrences": 1,
        "optim_lr": 0.01,
        # multi-gpu sampling is supported only for 0
        # (otherwise, single gpu sampling is supported)
        "backward_steps": 0,
        "optim_loss_cutoff": 0.0,
        "optim_tv_loss": False,
        "use_forward": True,
        "original_guidance": False,
        "original_guidance_wt": 0.0,
        "forward_guidance_wt": 1.0,
        "sampling_type": 'ddpm',  # 'ddpm' or 'ddim'
        "optim_warm_start": False,
        "optim_print": False,  # save samples
        "progressive": False,  # print tqdm
        "optim_aug": None,
    }
    # Unpacking merge keeps each key at its first position and takes the last
    # value, matching successive dict.update() calls.
    merged_defaults = {
        **script_defaults,
        **model_and_diffusion_defaults(),
        **classifier_defaults(),
    }
    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged_defaults)
    return arg_parser


# Script entry point: run the sampler only when this file is executed directly.
if __name__ == "__main__":
    main()