# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for DiffPure. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------

import os
import random
import numpy as np
import clip
import torch
import torchvision.utils as tvu
import torchsde
import torch.nn.functional as F
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
from score_sde.losses import get_optimizer
from score_sde.models import utils as mutils
from score_sde.models.ema import ExponentialMovingAverage
from diffusers.schedulers import DDPMScheduler
from score_sde import sde_lib

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _extract_into_tensor(arr_or_func, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array or a func.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    if callable(arr_or_func):
        res = arr_or_func(timesteps).float()
    else:
        res = arr_or_func.to(device=device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)


def restore_checkpoint(ckpt_dir, state, device):
    """Populate ``state`` in place from the checkpoint stored at ``ckpt_dir``.

    Args:
      ckpt_dir: path handed to ``torch.load``.
      state: dict holding ``'optimizer'``, ``'model'`` and ``'ema'`` objects
        (each exposing ``load_state_dict``) plus a ``'step'`` entry.
      device: ``map_location`` used when deserializing the checkpoint.

    Returns:
      None. ``state`` is updated in place.
    """
    checkpoint = torch.load(ckpt_dir, map_location=device)
    # Same order as always: optimizer, then model, then EMA. Only the model is
    # loaded non-strictly, so missing/unexpected keys are tolerated there.
    load_plan = (
        ('optimizer', {}),
        ('model', {'strict': False}),
        ('ema', {}),
    )
    for key, extra_kwargs in load_plan:
        state[key].load_state_dict(checkpoint[key], **extra_kwargs)
    state['step'] = checkpoint['step']


class RevVPSDE(torch.nn.Module):
    """Reverse-time variance-preserving SDE exposing ``f``/``g`` for ``torchsde``.

    The state handed to ``f`` and ``g`` is a flattened batch of shape
    (batch_size, c*h*w); the solver time ``t`` is mapped to diffusion time
    ``1 - t`` before the forward-SDE coefficients are evaluated.
    """

    def __init__(self, model, score_type='guided_diffusion', beta_min=0.1, beta_max=20, N=1000,
                 img_shape=(3, 256, 256), model_kwargs=None):
        """Construct a Variance Preserving SDE.

        Args:
          model: diffusion model used to estimate the score.
          score_type: how ``model`` is queried; ``rvpsde_fn`` handles
            ``'guided_diffusion'`` and ``'score_sde'`` and rejects anything else.
          beta_min: value of beta(0)
          beta_max: value of beta(1)
          N: number of steps of the discrete beta schedule.
          img_shape: (channels, height, width) of a single image.
          model_kwargs: extra keyword arguments forwarded to ``model`` in the
            ``'guided_diffusion'`` branch.
        """
        super().__init__()
        self.model = model
        self.score_type = score_type
        self.model_kwargs = model_kwargs
        self.img_shape = img_shape

        self.beta_0, self.beta_1 = beta_min, beta_max
        self.N = N

        # Discrete linear schedule and the quantities derived from it.
        self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
        self.alphas = 1. - self.discrete_betas
        self.alphas_cumprod = self.alphas.cumprod(dim=0)
        self.sqrt_alphas_cumprod = self.alphas_cumprod.sqrt()
        self.sqrt_1m_alphas_cumprod = (1. - self.alphas_cumprod).sqrt()

        # Continuous-time counterparts, evaluated on demand for a float t.
        def alphas_cumprod_cont(t):
            return torch.exp(-0.5 * (beta_max - beta_min) * t**2 - beta_min * t)

        def sqrt_1m_alphas_cumprod_neg_recip_cont(t):
            return -1. / torch.sqrt(1. - self.alphas_cumprod_cont(t))

        self.alphas_cumprod_cont = alphas_cumprod_cont
        self.sqrt_1m_alphas_cumprod_neg_recip_cont = sqrt_1m_alphas_cumprod_neg_recip_cont

        # Attributes read by torchsde to pick the solver behaviour.
        self.noise_type = "diagonal"
        self.sde_type = "ito"

    def _scale_timesteps(self, t):
        """Convert float times in [0, 1] to integer steps in [0, N]."""
        assert torch.all((t >= 0) & (t <= 1)), f't has to be in [0, 1], but get {t} with shape {t.shape}'
        scaled = t.float() * self.N
        return scaled.long()

    def vpsde_fn(self, t, x):
        """Return the forward VP-SDE ``(drift, diffusion)`` at times ``t`` for state ``x``."""
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        return -0.5 * beta_t[:, None] * x, torch.sqrt(beta_t)

    def _guided_diffusion_score(self, t, x, x_img):
        """Score from a guided-diffusion style network whose output is epsilon."""
        if self.model_kwargs is None:
            self.model_kwargs = {}

        # (batch_size, ), from float in [0, 1] to int in [0, N]
        steps = self._scale_timesteps(t)
        raw = self.model(x_img, steps, **self.model_kwargs)
        # The network also predicts a variance; only the first chunk is used.
        eps, _ = torch.split(raw, self.img_shape[0], dim=1)
        assert x_img.shape == eps.shape, f'{x_img.shape}, {eps.shape}'
        eps = eps.view(x.shape[0], -1)
        scale = _extract_into_tensor(self.sqrt_1m_alphas_cumprod_neg_recip_cont, t, x.shape)
        return scale * eps

    def _score_sde_score(self, t, x, x_img):
        """Score obtained through the ``score_sde`` continuous score function."""
        forward_sde = sde_lib.VPSDE(beta_min=self.beta_0, beta_max=self.beta_1, N=self.N)
        score_fn = mutils.get_score_fn(forward_sde, self.model, train=False, continuous=True)
        score = score_fn(x_img, t)
        assert x_img.shape == score.shape, f'{x_img.shape}, {score.shape}'
        return score.view(x.shape[0], -1)

    def rvpsde_fn(self, t, x, return_type='drift'):
        """Evaluate the reverse SDE: its drift when ``return_type == 'drift'``,
        otherwise its diffusion coefficient."""
        fwd_drift, diffusion = self.vpsde_fn(t, x)
        if return_type != 'drift':
            return diffusion

        assert x.ndim == 2 and np.prod(self.img_shape) == x.shape[1], x.shape
        x_img = x.view(-1, *self.img_shape)

        if self.score_type == 'guided_diffusion':
            score = self._guided_diffusion_score(t, x, x_img)
        elif self.score_type == 'score_sde':
            score = self._score_sde_score(t, x, x_img)
        else:
            raise NotImplementedError(f'Unknown score type in RevVPSDE: {self.score_type}!')

        return fwd_drift - diffusion[:, None] ** 2 * score

    def f(self, t, x):
        """Drift seen by the solver: the negated reverse drift at time ``1 - t``.

        sdeint only supports a 2D tensor (batch_size, c*h*w).
        """
        batch_t = t.expand(x.shape[0])  # (batch_size, )
        rev_drift = self.rvpsde_fn(1 - batch_t, x, return_type='drift')
        assert rev_drift.shape == x.shape
        return -rev_drift

    def g(self, t, x):
        """Diffusion seen by the solver: g(1 - t) broadcast over the state.

        sdeint only supports a 2D tensor (batch_size, c*h*w).
        """
        batch_t = t.expand(x.shape[0])  # (batch_size, )
        coeff = self.rvpsde_fn(1 - batch_t, x, return_type='diffusion')
        assert coeff.shape == (x.shape[0], )
        return coeff[:, None].expand(x.shape)


# class RevGuidedDiffusion(torch.nn.Module):
#     def __init__(self, args, config, model, ImageNet=False):
#         super().__init__()
#         self.args = args
#         self.config = config
        
#         self.model = model
#         self.ImageNet = ImageNet
#         if self.ImageNet:
#             self.img_shape = (3, 256, 256)
#         else:
#             self.img_shape = (3, 32, 32)
        
#         self.rev_vpsde = RevVPSDE(model=model, score_type=args.score_type, img_shape=self.img_shape,
#                                   model_kwargs=None).to(device)
#         self.betas = self.rev_vpsde.discrete_betas.float().to(device)

#     def image_editing_sample(self, img):
#         assert isinstance(img, torch.Tensor)
#         batch_size = img.shape[0]
#         state_size = int(np.prod(img.shape[1:]))  # c*h*w

#         assert img.ndim == 4, img.ndim
#         img = img.to(device)
#         x0 = img

#         xs = []
#         for it in range(self.args.sample_step):

#             # e = torch.randn_like(x0).to(device)
#             # total_noise_levels = self.Atack_T
#             # if self.args.rand_t:
#             #     total_noise_levels = self.Atack_T + np.random.randint(-self.args.t_delta, self.args.t_delta)
#             # a = (1 - self.betas).cumprod(dim=0).to(device)
#             # x = x0 * a[total_noise_levels - 1].sqrt() + e * (1.0 - a[total_noise_levels - 1]).sqrt()
            
#             x = x0
            
#             epsilon_dt0, epsilon_dt1 = 0, 1e-5
#             t0, t1 = 1 - self.args.t * 1. / 1000 + epsilon_dt0, 1 - epsilon_dt1
#             t_size = 2
#             ts = torch.linspace(t0, t1, t_size).to(device)

#             x_ = x.view(batch_size, -1)  # (batch_size, state_size)
#             # xs_ = torchsde.sdeint_adjoint(self.rev_vpsde, x_, ts, method='euler')
#             xs_ = torchsde.sdeint(self.rev_vpsde, x_, ts, method='euler')
#             x0 = xs_[-1].view(x.shape)  # (batch_size, c, h, w)

#             xs.append(x0)

#         return torch.cat(xs, dim=0)
def load_clip_model(self=None):
    """Load CLIP ViT-L/14 in eval mode together with tokenized class prompts.

    The model and the tokens are placed on the module-level ``device`` (CUDA
    when available, otherwise CPU) instead of a hard-coded ``'cuda'``, so the
    loader also works on CPU-only hosts.

    Args:
      self: unused; kept (now optional) so existing calls that pass a
        positional argument keep working.

    Returns:
      A ``(model, tokens, preprocess)`` tuple: the CLIP model, the tokenized
      prompts for the ten class labels below (they look like the CIFAR-10
      class names), and the preprocessing transform returned by ``clip.load``.
    """
    model, preprocess = clip.load("ViT-L/14", device=device)
    model.eval()
    labels = [
        "a photo of an airplane", "a photo of an automobile", "a photo of a bird",
        "a photo of a cat", "a photo of a deer", "a photo of a dog",
        "a photo of a frog", "a photo of a horse", "a photo of a ship", "a photo of a truck"
    ]
    tokens = clip.tokenize(labels).to(device)
    return model, tokens, preprocess


class RevGuidedDiffusion(torch.nn.Module):
    """Noise an image batch, then denoise it by integrating the reverse VP-SDE.

    ``diffusion_step`` controls how much noise ``add_noise`` injects, while
    ``denoising_step`` controls how far back in diffusion time the reverse SDE
    solve starts (``denoising_step / N`` of the unit time interval).
    """

    def __init__(self, model, diffusion_step, denoising_step, ImageNet=False):
        """
        Args:
          model: score network handed to ``RevVPSDE`` (queried as 'score_sde').
          diffusion_step: discrete step whose noise level is applied to inputs.
          denoising_step: length of the reverse solve, in discrete steps.
          ImageNet: selects a (3, 256, 256) working shape instead of (3, 32, 32).
        """
        super().__init__()

        self.diffusion_step = diffusion_step
        self.denoising_step = denoising_step

        self.model = model
        self.ImageNet = ImageNet
        self.img_shape = (3, 256, 256) if self.ImageNet else (3, 32, 32)

        self.rev_vpsde = RevVPSDE(model=model, score_type='score_sde', img_shape=self.img_shape,
                                  model_kwargs=None).to(device)
        # discrete_betas is a plain tensor attribute, so it is moved explicitly.
        self.betas = self.rev_vpsde.discrete_betas.float().to(device)

    def add_noise(self, image_inputs, num_steps_noise_add):
        """Resize, rescale to [-1, 1] and diffuse ``image_inputs`` forward.

        Args:
          image_inputs: 4-D image batch; values are mapped with ``(x - 0.5) * 2``
            and clamped, so inputs are presumably expected in [0, 1].
          num_steps_noise_add: number of discrete diffusion steps to apply.
            ``0`` returns the resized/rescaled images without any noise.

        Returns:
          The noised batch at the spatial size given by ``self.img_shape``.

        Raises:
          ValueError: if ``num_steps_noise_add`` is negative.
          IndexError: if it exceeds the length of the beta schedule.
        """
        if num_steps_noise_add < 0:
            raise ValueError(f'num_steps_noise_add must be >= 0, got {num_steps_noise_add}')

        # Resize to the working resolution of the SDE (32 for the default,
        # 256 with ImageNet=True) so the flattened state matches RevVPSDE.
        image_inputs = F.interpolate(image_inputs, self.img_shape[1:], mode='bilinear', align_corners=False)
        image_inputs = ((image_inputs - 0.5) * 2).clamp(-1, 1)
        if num_steps_noise_add == 0:
            # Index -1 would silently select the *largest* noise level.
            return image_inputs

        noise = torch.randn_like(image_inputs)
        alphas_cumprod = (1 - self.betas).cumprod(dim=0).to(image_inputs.device)
        alpha_bar = alphas_cumprod[num_steps_noise_add - 1]
        return image_inputs * alpha_bar.sqrt() + noise * (1.0 - alpha_bar).sqrt()

    def image_editing_sample(self, img):
        """Purify ``img``: add noise once, then run one reverse-SDE solve.

        Earlier revisions repeated the whole noise-and-solve procedure
        ``denoising_step`` times, always restarting from ``img`` and returning
        only the last result; the discarded solves did not influence the
        output, so a single solve is performed here.

        Args:
          img: 4-D image batch (batch_size, c, h, w).

        Returns:
          The denoised batch with the same shape as the noised input.

        Raises:
          ValueError: if ``self.denoising_step`` is not positive.
        """
        assert isinstance(img, torch.Tensor)
        assert img.ndim == 4, img.ndim
        if self.denoising_step <= 0:
            raise ValueError(f'denoising_step must be positive, got {self.denoising_step}')

        batch_size = img.shape[0]
        img = img.to(device)
        x = self.add_noise(img, self.diffusion_step)

        # Solver time runs forward from t0 to t1 while RevVPSDE evaluates the
        # SDE at 1 - t; t1 stops just short of 1 to stay away from time zero.
        epsilon_dt1 = 1e-5
        t0 = 1 - self.denoising_step / self.rev_vpsde.N
        t1 = 1 - epsilon_dt1
        ts = torch.linspace(t0, t1, 2).to(device)

        x_ = x.reshape(batch_size, -1)  # (batch_size, state_size)
        # No dt is passed, so torchsde's default Euler step size applies.
        xs_ = torchsde.sdeint_adjoint(self.rev_vpsde, x_, ts, method='euler')
        return xs_[-1].reshape(x.shape)  # (batch_size, c, h, w)
    