# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for DiffPure. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------

import os
import random

import torch
import torchvision.utils as tvu

from defence.guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults


class GuidedDiffusion(torch.nn.Module):
    """Diffusion-based purifier built on a guided-diffusion model.

    Wraps the model/diffusion pair returned by ``create_model_and_diffusion``
    and exposes :meth:`image_editing_sample`, which noises an input batch to a
    fixed timestep and then denoises it back with the reverse process.
    """

    def __init__(self, args, config, device=None, model_dir='pretrained/guided_diffusion'):
        """Build the model and diffusion process and load pretrained weights.

        Args:
            args: run options; ``image_editing_sample`` reads ``args.t`` and
                ``args.sample_step`` from it.
            config: configuration object whose ``config.model`` namespace is
                merged over ``model_and_diffusion_defaults()``.
            device: device for the model and the beta schedule; defaults to
                CUDA when available, otherwise CPU.
            model_dir: directory searched first for
                ``256x256_diffusion_uncond.pt``.
        """
        super().__init__()
        self.args = args
        self.config = config
        if device is None:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.device = device

        # load model
        model_config = model_and_diffusion_defaults()
        model_config.update(vars(self.config.model))
        print(f'model_config: {model_config}')
        model, diffusion = create_model_and_diffusion(**model_config)

        ckpt_name = '256x256_diffusion_uncond.pt'
        ckpt_path = os.path.join(model_dir, ckpt_name)
        if not os.path.isfile(ckpt_path):
            # Backward compatibility: this class used to ignore ``model_dir``
            # and always load the checkpoint from the current working directory.
            ckpt_path = ckpt_name
        model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        model.eval().to(self.device)

        if model_config['use_fp16']:
            model.convert_to_fp16()

        self.model = model
        self.diffusion = diffusion
        # Noise schedule as a float32 tensor on ``self.device``.
        self.betas = torch.from_numpy(diffusion.betas).float().to(self.device)

    def image_editing_sample(self, img, bs_id=0, tag=None):
        """Purify ``img`` by diffusing it to step ``args.t`` and denoising back.

        Each of the ``args.sample_step`` rounds draws fresh Gaussian noise and
        jumps straight to timestep index ``t - 1`` with the closed-form forward
        process ``x = sqrt(abar) * x0 + sqrt(1 - abar) * eps``, where ``abar``
        is the cumulative product of ``1 - betas`` at index ``t - 1``. It then
        runs ``p_sample`` for timesteps ``t - 1, ..., 0``. Every round starts
        from the previous round's output.

        Args:
            img: 4-D tensor batch; it is moved to ``self.device``. Presumably
                scaled to [-1, 1] (the disabled debug dumps map it with
                ``(x + 1) * 0.5``) -- confirm with callers.
            bs_id: batch index; only referenced by the disabled debug dumps.
            tag: label for the debug output directory; a random one is
                generated when ``None``. Only referenced by the disabled debug
                dumps.

        Returns:
            The outputs of all rounds concatenated along dim 0 (so
            ``args.sample_step * img.shape[0]`` rows, assuming ``p_sample``
            preserves the batch shape).

        Raises:
            ValueError: if ``args.t`` is outside ``[1, len(self.betas)]``.
        """
        assert isinstance(img, torch.Tensor)
        batch_size = img.shape[0]
        if tag is None:
            tag = 'rnd' + str(random.randint(0, 10000))
        # NOTE: debug dumps below are intentionally disabled; they rely on
        # ``self.args.log_dir`` and ``tvu``.
        #out_dir = os.path.join(self.args.log_dir, 'bs' + str(bs_id) + '_' + tag)
        assert img.ndim == 4, img.ndim
        img = img.to(self.device)
        x0 = img
        #if bs_id < 2:
        #    os.makedirs(out_dir, exist_ok=True)
        #    tvu.save_image((x0 + 1) * 0.5, os.path.join(out_dir, f'original_input.png'))

        total_noise_levels = self.args.t
        # Loop-invariant: cumulative product of (1 - beta), i.e. alpha-bar.
        alphas_cumprod = (1 - self.betas).cumprod(dim=0)
        num_steps = alphas_cumprod.shape[0]
        if not 1 <= total_noise_levels <= num_steps:
            # t <= 0 would otherwise wrap to a negative index and silently
            # return (nearly) pure noise without any denoising.
            raise ValueError(
                f'args.t must be in [1, {num_steps}], got {total_noise_levels}')
        alpha_bar = alphas_cumprod[total_noise_levels - 1]
        signal_scale = alpha_bar.sqrt()
        noise_scale = (1.0 - alpha_bar).sqrt()

        xs = []
        for it in range(self.args.sample_step):
            e = torch.randn_like(x0)
            x = x0 * signal_scale + e * noise_scale
            #if bs_id < 2:
            #    tvu.save_image((x + 1) * 0.5, os.path.join(out_dir, f'init_{it}.png'))
            for i in reversed(range(total_noise_levels)):
                # Same timestep for every batch element (int64, like torch.tensor([i] * B)).
                t = torch.full((batch_size,), i, dtype=torch.long, device=self.device)
                x = self.diffusion.p_sample(self.model, x, t,
                                            clip_denoised=True,
                                            denoised_fn=None,
                                            cond_fn=None,
                                            model_kwargs=None)["sample"]
                # added intermediate step vis
                #if (i - 99) % 100 == 0 and bs_id < 2:
                #    tvu.save_image((x + 1) * 0.5, os.path.join(out_dir, f'noise_t_{i}_{it}.png'))
            # The next round re-noises this round's output.
            x0 = x
            #if bs_id < 2:
            #    torch.save(x0, os.path.join(out_dir, f'samples_{it}.pth'))
            #    tvu.save_image((x0 + 1) * 0.5, os.path.join(out_dir, f'samples_{it}.png'))
            xs.append(x0)
        return torch.cat(xs, dim=0)

