"""
Like image_sample.py, but use a universal (clean) image classifier to guide the
sampling process towards more realistic images.
"""
import sys, os
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), "latent_guided_diffusion"))
import argparse
import ctypes
from time import time, localtime, strftime

import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from functools import partial
from torchvision.utils import make_grid, save_image

from ldm.models.diffusion.ddim_with_grad import DDIMSamplerWithGrad

from guided_diffusion.transfer_learning import create_pretrained_model
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
    pretrained_ImageNet_defaults,
)

class OptimizerDetails:
    """Plain settings holder describing how guidance is applied while sampling.

    ``main()`` creates one instance, overwrites most fields from the parsed
    command-line arguments and passes it to the sampler as ``operation``.
    The class has no behaviour of its own; every field is just a default.
    How each field is interpreted is decided by the sampler, which is not
    part of this file, so the per-field notes below are hedged.
    """

    def __init__(self):
        # Guidance model(s) and objective.
        self.operation_func = None       # callable producing the guidance outputs
        self.loss_func = None            # callable turning those outputs into a loss
        self.other_guidance_func = None
        self.other_criterion = None

        # Optimisation settings (presumably for the backward-guidance steps
        # -- confirm against the sampler).
        self.num_recurrences = None
        self.optimizer = None            # handled at string level, e.g. 'Adam'
        self.lr = None
        self.lr_scheduler = None
        self.backward_steps = 0
        self.loss_cutoff = None
        self.tv_loss = None
        self.fact = 0.5

        # Guidance mode switches.
        self.use_forward = False
        self.forward_guidance_wt = 0
        self.original_guidance = False
        self.sampling_type = None

        # Warm start state.
        self.warm_start = None
        self.old_img = None

        # Logging / debugging output.
        self.print = False
        self.print_every = None
        self.folder = None
        self.loss_save = None

        # main() assigns ``operation.Aug``; declare it here as well so that
        # every instance carries the attribute even when a caller does not
        # set it explicitly.
        self.Aug = None


def main():
    """Sample images under reward-model guidance and write them to disk.

    Steps:
      1. Parse the command line, set up distributed state and the logger.
      2. Build the diffusion model and load one reward model per entry of
         ``--reward_paths``.
      3. Fill an ``OptimizerDetails`` instance from the arguments and run
         ``DDIMSamplerWithGrad.sample_operation`` until at least
         ``args.num_samples`` images have been gathered from all ranks.
      4. On rank 0, save the uint8 samples as an ``.npz`` file and as PNG
         grids of up to 100 images (10 per row).

    Raises:
        ValueError: if no reward checkpoint path was supplied.
    """
    args = create_argparser().parse_args()

    # Without this check a missing --reward_paths only surfaces later as an
    # opaque "NoneType is not iterable", after the diffusion model was built.
    if not args.reward_paths:
        raise ValueError("at least one --reward_paths entry is required")

    dist_util.setup_dist()
    logger.configure(args.log_dir)

    # NOTE(review): %I is a 12-hour clock without AM/PM, so tags of runs 12
    # hours apart can collide; kept unchanged to preserve output file names.
    time_tag = strftime("%m%d_%I:%M:%S", localtime(time()))

    logger.log("creating model and diffusion...")
    _model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    logger.log("loading reward...")
    # The constructor arguments are the same for every reward model.
    reward_kwargs = args_to_dict(args, pretrained_ImageNet_defaults().keys())
    reward_list = []
    for reward_path in args.reward_paths:
        reward = create_pretrained_model(**reward_kwargs)
        reward.load_state_dict(
            dist_util.load_state_dict(reward_path, map_location="cpu")
        )
        reward.to(dist_util.dev())
        reward.eval()
        reward_list.append(reward)

    log_sigmoid = th.nn.LogSigmoid()

    def operation_func(x, t=None):
        """Return the output of every reward model for ``x`` (and ``t`` if given)."""
        if t is None:
            return [reward_model(x) for reward_model in reward_list]
        return [reward_model(x, t) for reward_model in reward_list]

    def loss_func(reward_vals, *_unused):
        """Mean over the reward models of ``-log(sigmoid(reward))``."""
        return sum(-log_sigmoid(rval) for rval in reward_vals) / len(reward_list)

    ##### operation #####
    operation = OptimizerDetails()
    operation.num_recurrences = args.num_recurrences
    operation.operation_func = operation_func
    operation.other_guidance_func = None

    operation.optimizer = 'Adam'
    operation.lr = args.optim_lr
    operation.loss_func = loss_func
    operation.other_criterion = None

    operation.backward_steps = args.backward_steps
    operation.loss_cutoff = args.optim_loss_cutoff  # 0.00001
    operation.tv_loss = args.optim_tv_loss

    operation.use_forward = args.use_forward
    operation.forward_guidance_wt = args.forward_guidance_wt

    operation.original_guidance = args.original_guidance
    operation.sampling_type = args.sampling_type

    operation.warm_start = args.optim_warm_start  # False
    operation.print = args.optim_print
    operation.print_every = 10
    operation.folder = logger.get_dir()  # results_folder
    if args.optim_print:
        os.makedirs(f'{operation.folder}/samples', exist_ok=True)
    operation.Aug = args.optim_aug
    #####################

    logger.log("sampling... ")
    all_images = []

    # Loop-invariant setup: device placement and the latent shape do not
    # change between batches.
    diffusion = diffusion.to(dist_util.dev())
    ddim_steps = 400  # custom_steps
    latent_shape = (
        diffusion.model.diffusion_model.in_channels,
        diffusion.model.diffusion_model.image_size,
        diffusion.model.diffusion_model.image_size,
    )

    while len(all_images) * args.batch_size < args.num_samples:
        # See https://github.com/arpitbansal297/Universal-Guided-Diffusion/blob/b3af48f78d7bec105f3ea1579faf8602c520ed1e/Guided_Diffusion_Imagenet/Guided/helpers.py#L245
        # A fresh sampler is built per batch, exactly as before.
        sample_fn = DDIMSamplerWithGrad(diffusion).sample_operation
        with diffusion.ema_scope():
            # LSUN
            samples_ddim, _ = sample_fn(
                ddim_steps,
                args.batch_size,
                latent_shape,
                operated_image=None,
                operation=operation,
                eta=0.,
                verbose=False,
            )
        sample = diffusion.decode_first_stage(samples_ddim)

        # Map from [-1, 1] to uint8 [0, 255] and move channels last.
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend([gathered.cpu().numpy() for gathered in gathered_samples])
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    hparam_tag = f"[bwd{args.backward_steps}_lr_{args.optim_lr}]"
    if args.use_forward:
        hparam_tag = hparam_tag + f"_[fwd_wt_{args.forward_guidance_wt}]"
    if args.original_guidance:
        hparam_tag = hparam_tag + f"_[org_wt_{args.original_guidance_wt}]"

    is_main_rank = dist.get_rank() == 0

    arr = np.concatenate(all_images, axis=0)
    arr = arr[: args.num_samples]
    if is_main_rank:
        shape_str = "x".join([str(x) for x in arr.shape])
        out_path = os.path.join(logger.get_dir(), f"samples_{args.expr_name}_{hparam_tag}_{time_tag}_{shape_str}.npz")
        logger.log(f"saving to {out_path}")
        np.savez(out_path, arr)

    logger.log("sampling complete")

    # NOTE(review): the .npz goes to logger.get_dir() while the PNG grids go
    # under args.log_dir; confirm these always refer to the same directory.
    sample_save_dir = os.path.join(args.log_dir, f"sample_imgs_{args.expr_name}_{hparam_tag}_{time_tag}")
    if is_main_rank:
        os.makedirs(sample_save_dir, exist_ok=True)

    dist.barrier()
    if is_main_rank:
        # Only rank 0 writes images, so only rank 0 builds the float tensor.
        sample_tensor = th.tensor(np.transpose(arr.astype(np.float32) / 255., (0, 3, 1, 2)))
        grid_size = 100
        # Step through the samples in chunks so that a final partial chunk
        # (or a run with fewer than 100 samples) is saved as well.
        for grid_idx, start in enumerate(range(0, sample_tensor.size(0), grid_size)):
            img_grid = make_grid(sample_tensor[start: start + grid_size], nrow=10)
            save_image(img_grid, os.path.join(sample_save_dir, f"batch_{grid_idx}.png"))


def create_argparser():
    """Build the command-line parser for this sampling script.

    The script-specific defaults below are merged with
    ``model_and_diffusion_defaults()`` and ``pretrained_ImageNet_defaults()``
    (later sources override earlier ones on duplicate keys) and registered
    through ``add_dict_to_argparser``. ``--reward_paths`` is added separately
    because it takes one or more values.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    script_defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "temp_count": 0,
        "image_channels": 1,
        "target_class": None,
        "log_dir": "",
        # --- operation ---
        "num_recurrences": 1,
        "optim_lr": 0.01,
        # multi-GPU sampling is supported only for 0; any other value is
        # supported on a single GPU only
        "backward_steps": 0,
        "optim_loss_cutoff": 0.0,
        "optim_tv_loss": False,
        "use_forward": True,
        "original_guidance": False,
        "original_guidance_wt": 0.0,
        "forward_guidance_wt": 1.0,
        "sampling_type": 'ddpm',  # one of ['ddpm', 'ddim']
        "optim_warm_start": False,
        "optim_print": False,  # save samples
        "progressive": False,  # print tqdm
        "optim_aug": None,
        # --- latent diffusion ---
        "use_ldm": False,
        "expr_name": "ensemble",
    }
    merged_defaults = {
        **script_defaults,
        **model_and_diffusion_defaults(),
        **pretrained_ImageNet_defaults(),
    }

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged_defaults)
    arg_parser.add_argument(
        "--reward_paths",
        type=str,
        nargs='+',
        help="list of path to each reward model",
    )
    return arg_parser


# Script entry point: run the sampler only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()