"""
    Diffusion-model denoiser for CIFAR-10, adapted from:
    https://github.com/ethz-privsec/diffusion_denoised_smoothing
"""
import os

# torch
import torch
import torch.nn as nn 

from utils.improved_diffusion.cifar10.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
)

class Args:
    """Fixed hyper-parameters for ``create_model_and_diffusion``.

    Only the attributes whose names appear in
    ``model_and_diffusion_defaults()`` are forwarded (via ``args_to_dict``).
    NOTE(review): these presumably have to match the configuration the
    checkpoint in ``_modelfile`` was trained with -- verify before editing.
    """

    image_size: int = 32
    num_channels: int = 128
    num_res_blocks: int = 3
    num_heads: int = 4
    num_heads_upsample: int = -1
    attention_resolutions: str = "16,8"
    dropout: float = 0.3
    learn_sigma: bool = True
    sigma_small: bool = False
    class_cond: bool = False
    diffusion_steps: int = 4000
    noise_schedule: str = "cosine"
    timestep_respacing: str = ""
    use_kl: bool = False
    predict_xstart: bool = False
    rescale_timesteps: bool = True
    rescale_learned_sigmas: bool = True
    use_checkpoint: bool = False
    use_scale_shift_norm: bool = True


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
# Whether a CUDA device is visible; decides where the denoiser is placed.
_use_cuda = torch.cuda.is_available()

# Pretrained checkpoint, given as a path relative to the working directory.
_modelfile = os.path.join(
    "models", "cifar10", "denoiser", "cifar10_uncond_50M_500K.pt"
)



# ------------------------------------------------------------------------------
#   Denoiser definition
# ------------------------------------------------------------------------------
class DiffusionDenoiser(nn.Module):
    """Denoise image batches with a pretrained improved-diffusion model.

    The model and diffusion process are built from ``Args``, the weights are
    loaded from ``_modelfile``, and the model is put in eval mode (and moved
    to the GPU when CUDA is available).
    """

    def __init__(self):
        super().__init__()
        model, diffusion = create_model_and_diffusion(
            **args_to_dict(Args(), model_and_diffusion_defaults().keys())
        )
        # Without CUDA, map every stored tensor onto the CPU so the checkpoint
        # loads even if it was saved from GPU tensors.
        state_dict = torch.load(
            _modelfile, map_location=None if _use_cuda else "cpu"
        )
        model.load_state_dict(state_dict)
        model.eval()
        if _use_cuda:
            model.cuda()

        self.model = model
        self.diffusion = diffusion

    def forward(self, x, t):
        """Noise ``x`` to timestep ``t`` and denoise it in a single step.

        Args:
            x: image batch, assumed to be scaled to [0, 1] (TODO confirm with
                callers); it is mapped to [-1, 1] before denoising.
            t: integer diffusion timestep at which ``x`` is noised.

        Returns:
            The one-step ``pred_xstart`` estimate from ``denoise``.
            NOTE(review): it is not mapped back from [-1, 1] to [0, 1];
            confirm that callers expect this range.
        """
        x_in = x * 2 - 1
        return self.denoise(x_in, t)

    @staticmethod
    def _timesteps(step, like):
        """Return a 1-D tensor repeating ``step`` once per item of ``like``, on ``like``'s device."""
        return torch.tensor([step] * len(like), device=like.device)

    def denoise(self, x_start, t, multistep=False):
        """Add diffusion noise at timestep ``t`` to ``x_start``, then remove it.

        Args:
            x_start: clean batch in the model's input range; the batch size is
                ``len(x_start)``. Timesteps and noise are created on its device.
            t: integer timestep used for ``q_sample``.
            multistep: if False, return the model's one-shot ``pred_xstart``
                estimate at step ``t``. If True, run ``p_sample`` for steps
                t, t-1, ..., 0 and return the final ``sample``.

        Returns:
            The denoised batch, computed under ``torch.no_grad()``.
        """
        with torch.no_grad():
            # randn_like already allocates on x_start's device.
            noise = torch.randn_like(x_start)
            t_batch = self._timesteps(t, x_start)
            x_t = self.diffusion.q_sample(x_start=x_start, t=t_batch, noise=noise)

            if not multistep:
                return self.diffusion.p_sample(
                    self.model,
                    x_t,
                    t_batch,
                    clip_denoised=True,
                )['pred_xstart']

            # q_sample produced the sample at timestep t, so the reverse chain
            # must start at t itself (not t - 1) and end at 0 inclusive.
            out = x_t
            for step in range(t, -1, -1):
                out = self.diffusion.p_sample(
                    self.model,
                    out,
                    self._timesteps(step, x_start),
                    clip_denoised=True,
                )['sample']
            return out
