import torch.nn as nn
import torch
import math
from tqdm import tqdm


class DDPM(nn.Module):
    """Diffusion process wrapper around a user-supplied network.

    The wrapped ``model`` is called as ``model(x_t, t)`` where ``t`` is a 1-D
    tensor of integer timestep indices; its output is treated as the predicted
    noise in every formula of this class.

    The noise schedule is produced by ``_cosine_variance_schedule`` and stored
    in the buffers ``betas``, ``alphas``, ``alphas_cumprod``,
    ``sqrt_alphas_cumprod`` and ``sqrt_one_minus_alphas_cumprod``.

    All reverse-step helpers work for any data rank: per-timestep coefficients
    are reshaped to ``(batch, 1, ..., 1)`` to broadcast against the data.
    """

    def __init__(self, data_shape, model, timesteps=1000):
        """Build the schedule buffers and store the network.

        Args:
            data_shape: Shape of one sample without the batch dimension;
                either an int or a tuple/list of ints.
            model: Callable invoked as ``model(x_t, t)``.
            timesteps: Number of diffusion steps.
        """
        super().__init__()
        self.timesteps = timesteps
        self.data_shape = data_shape  # data_shape without batch

        betas = self._cosine_variance_schedule(timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=-1)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1. - alphas_cumprod))

        self.model = model

    def forward(self, x, noise):
        """Noise ``x`` at one random timestep per sample and predict the noise.

        Args:
            x: Clean batch; the first dimension is the batch dimension.
            noise: Tensor with the same shape as ``x``.

        Returns:
            The output of ``model(x_t, t)`` for the noised batch ``x_t``.
        """
        # Drawn on the CPU generator and then moved, so seeded runs stay reproducible.
        t = torch.randint(0, self.timesteps, (x.shape[0],)).to(x.device)
        x_t = self._forward_diffusion(x, t, noise)
        return self.model(x_t, t)

    @torch.no_grad()
    def sampling(self, n_samples, clipped_reverse_diffusion=True, device="cuda"):
        """Draw ``n_samples`` samples by running the full reverse chain.

        Args:
            n_samples: Batch size of the generated samples.
            clipped_reverse_diffusion: Use the x_0-clipping reverse step when
                True, the plain reverse step otherwise.
            device: Device the samples are generated on; the module itself is
                expected to already live there.

        Returns:
            Tensor of shape ``(n_samples, *data_shape)``. No rescaling is
            applied to the result.
        """
        x_t = torch.randn(self._sample_shape(n_samples)).to(device)
        if clipped_reverse_diffusion:
            reverse_step = self._reverse_diffusion_with_clip_with_pred_mean
        else:
            reverse_step = self._reverse_diffusion_with_pred_mean

        for i in tqdm(range(self.timesteps - 1, -1, -1), desc="Sampling"):
            noise = torch.randn_like(x_t)
            t = torch.full((n_samples,), i, dtype=torch.long, device=device)
            x_t, _, _ = reverse_step(x_t, t, noise)

        return x_t

    def get_log_l_gradient(self, input_vector, loss_function, device):
        """Return the gradient of ``log(loss_function(x))`` with respect to ``x``.

        The log values are averaged and the resulting gradient is multiplied by
        the batch size, which undoes the averaging when ``loss_function``
        returns one value per sample (assumption - confirm against callers).

        Args:
            input_vector: Point at which the gradient is evaluated. It is
                cloned and detached, so the caller's tensor is untouched.
            loss_function: Callable applied to ``input_vector.float()``. If it
                exposes ``zero_grad`` (e.g. an ``nn.Module``) that is called
                first, because ``backward`` also accumulates into its
                parameters.
            device: Device the computation runs on.

        Returns:
            Tensor with the shape of ``input_vector``.
        """
        input_vector = input_vector.clone().detach().to(device)
        input_vector.requires_grad_(True)

        # Plain callables have no zero_grad; only modules need their grads reset.
        zero_grad = getattr(loss_function, "zero_grad", None)
        if callable(zero_grad):
            zero_grad()

        # Force autograd on so this also works inside a caller's no_grad block.
        with torch.enable_grad():
            pre_log_loss = loss_function(input_vector.float())
            loss = torch.log(pre_log_loss).mean()
            loss.backward()

        batch_size = input_vector.shape[0]
        return input_vector.grad * batch_size

    def importance_samplingdd(self, n_samples, loss_function, clipped_reverse_diffusion=True, device="cuda",
                              derivative_eps=1e-4):
        """Guided sampling that applies the guidance term for ``i < timesteps // 2``.

        Args:
            n_samples: Batch size of the generated samples.
            loss_function: Callable passed to ``get_log_l_gradient``.
            clipped_reverse_diffusion: Use the x_0-clipping reverse step and
                clamp the state to [-1, 1] after every step when True.
            device: Device the samples are generated on.
            derivative_eps: Step size of the finite difference used for the
                Hessian-vector approximation.

        Returns:
            Tensor of shape ``(n_samples, *data_shape)``.
        """
        x_t, _, _ = self._guided_sampling(
            n_samples, loss_function, clipped_reverse_diffusion, device,
            derivative_eps=derivative_eps,
            guidance_cutoff=self.timesteps // 2,
            record_every=None,
        )
        return x_t

    def importance_sampling(self, n_samples, loss_function, clipped_reverse_diffusion=True, device="cuda",
                            derivative_eps=1e-8):
        """Guided sampling that applies the guidance term for ``i < 0.2 * timesteps``.

        Args:
            n_samples: Batch size of the generated samples.
            loss_function: Callable passed to ``get_log_l_gradient``.
            clipped_reverse_diffusion: Use the x_0-clipping reverse step and
                clamp the state to [-1, 1] after every step when True.
            device: Device the samples are generated on.
            derivative_eps: Step size of the finite difference used for the
                Hessian-vector approximation. NOTE(review): the default 1e-8 is
                below float32 resolution for inputs of order one, so the
                perturbation can vanish; consider passing a larger value.

        Returns:
            ``(x_t, x_t_list, posterior_mean_t_list)``. ``x_t_list`` holds the
            initial noise followed by the state after every step with
            ``i % 10 == 0``; ``posterior_mean_t_list`` holds the matching
            posterior-mean estimates.
        """
        return self._guided_sampling(
            n_samples, loss_function, clipped_reverse_diffusion, device,
            derivative_eps=derivative_eps,
            guidance_cutoff=self.timesteps * 0.20,
            record_every=10,
        )

    def _guided_sampling(self, n_samples, loss_function, clipped_reverse_diffusion, device,
                         derivative_eps, guidance_cutoff, record_every=None):
        """Shared reverse loop of the two importance-sampling entry points.

        The guidance term is added on steps with ``i < guidance_cutoff``. When
        ``record_every`` is not None the state and the posterior mean are
        recorded on steps with ``i % record_every == 0``.
        """
        x_t = torch.randn(self._sample_shape(n_samples)).to(device)
        if clipped_reverse_diffusion:
            reverse_step = self._reverse_diffusion_with_clip_with_pred_mean
        else:
            reverse_step = self._reverse_diffusion_with_pred_mean

        x_t_list = [x_t]
        posterior_mean_t_list = []

        for i in tqdm(range(self.timesteps - 1, -1, -1), desc="Sampling"):
            noise = torch.randn_like(x_t)
            t = torch.full((n_samples,), i, dtype=torch.long, device=device)

            original_x_t = x_t.clone().detach()
            x_t, pred_original_sample, pred = reverse_step(x_t, t, noise)

            _, alpha_t_cumprod, _ = self._coefficient(original_x_t, t)
            score = (0.0 - pred) / torch.sqrt(1.0 - alpha_t_cumprod)

            # The guidance term costs a model forward plus a loss backward, so
            # it is only evaluated on the steps where it is actually applied.
            if i < guidance_cutoff:
                x_t = x_t + self._guidance_term(
                    original_x_t, t, score, pred_original_sample,
                    loss_function, derivative_eps, device,
                )

            if clipped_reverse_diffusion:
                x_t = torch.clamp(x_t, -1., 1.)

            if record_every is not None and i % record_every == 0:
                posterior_mean_t = 1.0 / torch.sqrt(alpha_t_cumprod) * (
                    original_x_t + (1.0 - alpha_t_cumprod) * score)
                posterior_mean_t_list.append(posterior_mean_t)
                x_t_list.append(x_t)

        return x_t, x_t_list, posterior_mean_t_list

    def _guidance_term(self, original_x_t, t, score, pred_original_sample, loss_function,
                       derivative_eps, device):
        """Return the correction added to the reverse-step output.

        Combines the log-loss gradient at the predicted clean sample with a
        finite-difference approximation of (score Jacobian) x (that gradient).
        """
        alpha_t, alpha_t_cumprod, _ = self._coefficient(original_x_t, t)
        dll = self.get_log_l_gradient(pred_original_sample, loss_function, device=device)

        # No graph is needed here; without no_grad the returned samples would
        # keep the network's autograd graph alive.
        with torch.no_grad():
            epsilon_output_x_eps_dll = self.model(original_x_t + derivative_eps * dll, t)
            score_x_eps_dll = (0.0 - epsilon_output_x_eps_dll) / torch.sqrt(1.0 - alpha_t_cumprod)
            hessian_vector_approx = (score_x_eps_dll - score) / derivative_eps

            total_coefficient = (1.0 - alpha_t) / torch.sqrt(alpha_t) / torch.sqrt(alpha_t_cumprod)
            return total_coefficient * (dll + (1.0 - alpha_t_cumprod) * hessian_vector_approx)

    def _sample_shape(self, n_samples):
        """Return ``(n_samples, *data_shape)`` for int, tuple or list ``data_shape``."""
        if isinstance(self.data_shape, (tuple, list)):
            return (n_samples,) + tuple(self.data_shape)
        return (n_samples, self.data_shape)

    def _extract(self, values, t, x):
        """Gather ``values[t]`` and reshape it to ``(batch, 1, ..., 1)`` to broadcast against ``x``."""
        shape = [x.shape[0]] + [1] * (len(x.shape) - 1)
        return values.gather(-1, t).reshape(shape)

    def _alphas_cumprod_prev(self, t, x):
        """Return ``alphas_cumprod[t - 1]`` per sample, using 1 where ``t == 0``.

        With the value 1 at ``t == 0`` the posterior formulas reduce to the
        dedicated ``t == 0`` case (zero std, mean taken from the x_0 estimate),
        so batches with mixed timesteps are handled per sample.
        """
        prev = self.alphas_cumprod.gather(-1, torch.clamp(t - 1, min=0))
        prev = torch.where(t > 0, prev, torch.ones_like(prev))
        shape = [x.shape[0]] + [1] * (len(x.shape) - 1)
        return prev.reshape(shape)

    def _cosine_variance_schedule(self, timesteps, epsilon=0.008):
        """Return ``timesteps`` betas from a squared-cosine schedule, clipped to [0, 0.999]."""
        steps = torch.linspace(0, timesteps, steps=timesteps + 1, dtype=torch.float32)
        f_t = torch.cos(((steps / timesteps + epsilon) / (1.0 + epsilon)) * math.pi * 0.5) ** 2
        betas = torch.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)

        return betas

    def _forward_diffusion(self, x_0, t, noise):
        """Return ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``."""
        assert x_0.shape == noise.shape

        return self._extract(self.sqrt_alphas_cumprod, t, x_0) * x_0 + \
            self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_0) * noise

    @torch.no_grad()
    def _reverse_step(self, x_t, t, noise):
        """One unclipped reverse step.

        The mean is computed directly from the predicted noise.

        Returns:
            ``(x_prev, pred_original_sample, pred)`` where ``pred`` is the raw
            model output and ``pred_original_sample`` the implied x_0 estimate.
        """
        pred = self.model(x_t, t)

        alpha_t = self._extract(self.alphas, t, x_t)
        alpha_t_cumprod = self._extract(self.alphas_cumprod, t, x_t)
        beta_t = self._extract(self.betas, t, x_t)
        sqrt_one_minus_alpha_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_t)
        alpha_t_cumprod_prev = self._alphas_cumprod_prev(t, x_t)

        mean = (1. / torch.sqrt(alpha_t)) * (x_t - ((1.0 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * pred)
        # Evaluates to zero wherever t == 0 because alpha_t_cumprod_prev is 1 there.
        std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))

        pred_original_sample = (1. / torch.sqrt(alpha_t_cumprod)) * (x_t - sqrt_one_minus_alpha_cumprod_t * pred)

        return mean + std * noise, pred_original_sample, pred

    @torch.no_grad()
    def _clipped_reverse_step(self, x_t, t, noise):
        """One reverse step through an x_0 estimate clipped to [-1, 1].

        Returns:
            ``(x_prev, pred_original_sample, pred)``; ``pred_original_sample``
            is the *unclipped* x_0 estimate.
        """
        pred = self.model(x_t, t)

        alpha_t = self._extract(self.alphas, t, x_t)
        alpha_t_cumprod = self._extract(self.alphas_cumprod, t, x_t)
        beta_t = self._extract(self.betas, t, x_t)
        sqrt_one_minus_alpha_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_t)
        alpha_t_cumprod_prev = self._alphas_cumprod_prev(t, x_t)

        x_0_pred = torch.sqrt(1. / alpha_t_cumprod) * x_t - torch.sqrt(1. / alpha_t_cumprod - 1.) * pred
        x_0_pred = torch.clamp(x_0_pred, -1., 1.)

        # Where t == 0 the x_t coefficient is zero and the mean is a multiple of x_0_pred.
        mean = (beta_t * torch.sqrt(alpha_t_cumprod_prev) / (1. - alpha_t_cumprod)) * x_0_pred + \
               ((1. - alpha_t_cumprod_prev) * torch.sqrt(alpha_t) / (1. - alpha_t_cumprod)) * x_t
        std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))

        pred_original_sample = (1. / torch.sqrt(alpha_t_cumprod)) * (x_t - sqrt_one_minus_alpha_cumprod_t * pred)

        return mean + std * noise, pred_original_sample, pred

    @torch.no_grad()
    def _deprecated_reverse_diffusion(self, x_t, t, noise):
        """Unclipped reverse step; returns only the next state."""
        return self._reverse_step(x_t, t, noise)[0]

    @torch.no_grad()
    def _reverse_diffusion_with_pred_mean(self, x_t, t, noise):
        """Unclipped reverse step; returns ``(x_prev, pred_original_sample, pred)``."""
        return self._reverse_step(x_t, t, noise)

    @torch.no_grad()
    def _reverse_diffusion_multi_component(self, x_t, t, noise):
        """Unclipped reverse step; returns ``(x_prev, pred_original_sample)``."""
        x_prev, pred_original_sample, _ = self._reverse_step(x_t, t, noise)
        return x_prev, pred_original_sample

    @torch.no_grad()
    def _coefficient(self, x_t, t):
        """Return ``(alpha_t, alpha_t_cumprod, beta_t)`` shaped to broadcast against ``x_t``."""
        alpha_t = self._extract(self.alphas, t, x_t)
        alpha_t_cumprod = self._extract(self.alphas_cumprod, t, x_t)
        beta_t = self._extract(self.betas, t, x_t)

        return alpha_t, alpha_t_cumprod, beta_t

    @torch.no_grad()
    def _deprecated_reverse_diffusion_with_clip(self, x_t, t, noise):
        """Clipped reverse step; returns only the next state."""
        return self._clipped_reverse_step(x_t, t, noise)[0]

    @torch.no_grad()
    def _reverse_diffusion_with_clip(self, x_t, t, noise):
        """Clipped reverse step; returns only the next state."""
        return self._clipped_reverse_step(x_t, t, noise)[0]

    @torch.no_grad()
    def _reverse_diffusion_with_clip_with_pred_mean(self, x_t, t, noise):
        """Clipped reverse step; returns ``(x_prev, pred_original_sample, pred)``."""
        return self._clipped_reverse_step(x_t, t, noise)
