# Based on https://github.com/Jackson-Kang/Pytorch-Diffusion-Model-Tutorial/tree/main

import torch
from torch import nn
import torch.nn.functional as F

def cosine_beta_schedule(timesteps, s=0.008):
    """Return the cosine beta (noise-variance) schedule.

    As proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Args:
        timesteps: number of diffusion steps T.
        s: small offset that keeps beta_t from being too small near t = 0.

    Returns:
        float64 tensor of shape (timesteps,) holding beta_1..beta_T,
        clipped to the range [0, 0.999].
    """
    # T + 1 evenly spaced points 0, 1, ..., T
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    f_t = torch.cos(phase) ** 2
    # normalise so that alpha_bar(0) == 1
    alpha_bar = f_t / f_t[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)


'''
Implement the Gaussian diffusion process.

Args:
    model - denoising neural network
    image_resolution - dimensions of the input image as (C, H, W)
    n_times - number of steps in the diffusion process
    beta_minmax - start and end of the beta variance schedule (currently unused: the cosine schedule is always applied)
    device - where data and model are located
    target - prediction target; only 'noise' is supported

'''
class Diffusion(nn.Module):
    """Gaussian diffusion (DDPM) wrapper around a conditional noise-prediction network.

    ``forward`` perturbs clean samples to a random timestep and returns the
    network's noise prediction. ``sample`` runs the reverse process and mixes a
    ``pattern``-conditioned and a ``null_pattern``-conditioned noise prediction
    as ``w * eps_pattern + (1 - w) * eps_null``.

    Args:
        model: denoising network, called as
            ``model(x_t, x_c, start_age, age_deltas, pattern, t)``.
        image_resolution: ``(C, H, W)`` of the generated samples.
        n_times: number of diffusion steps T.
        beta_minmax: start and end of a beta schedule. Kept for backward
            compatibility but currently ignored, because the cosine schedule
            from ``cosine_beta_schedule`` is always used.
        device: device holding the schedule tensors and the sampled noise.
        target: prediction target; only ``'noise'`` is supported.

    Raises:
        NotImplementedError: if ``target`` is not ``'noise'``.
    """

    def __init__(self, model, image_resolution=(1, 9, 9), n_times=1000, beta_minmax=(1e-4, 2e-2), device='cuda',
                 target='noise'):
        super(Diffusion, self).__init__()

        if target != 'noise':
            raise NotImplementedError  # not sure atm how to implement x0 prediction with guidance

        self.n_times = n_times
        self.img_C, self.img_H, self.img_W = image_resolution
        self.model = model
        self.device = device
        self.target = target

        # cosine variance schedule (beta_t); beta_minmax plays no part in it
        self._set_schedule(cosine_beta_schedule(self.n_times).float().to(device))

    def _set_schedule(self, betas):
        """Precompute every per-timestep coefficient derived from the 1-D ``betas`` tensor."""
        self.sqrt_betas = torch.sqrt(betas)

        # alpha_t = 1 - beta_t and its running product alpha_bar_t (forward diffusion kernel)
        self.alphas = 1 - betas
        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1 - self.alpha_bars)
        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)

        # alpha_bar_{t-1}. The value "before" the first step is 1 (nothing noised yet),
        # so the x0-target posterior mean at t == 0 reduces to the predicted x_0.
        self.prev_alpha_bars = F.pad(self.alpha_bars[:-1], (1, 0), value=1.0)
        self.sqrt_prev_alpha_bars = torch.sqrt(self.prev_alpha_bars)

    def extract(self, a, t, x_shape):
        """Gather ``a[t]`` for a batch of timesteps, reshaped for broadcasting.

        Args:
            a: 1-D schedule tensor (e.g. ``self.alphas``).
            t: (B,) long tensor of timestep indices, on the same device as ``a``.
            x_shape: shape of the tensor the result will be broadcast against.

        Returns:
            Tensor of shape ``(B, 1, ..., 1)`` with ``len(x_shape)`` dimensions.
        """
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    def scale_to_minus_one_to_one(self, x):
        """Map data from [0, 1] to [-1, 1].

        Normalizing x seems to be crucial for training the reverse-process
        network, according to Ho et al. (2020).
        """
        return x * 2 - 1

    def reverse_scale_to_zero_to_one(self, x):
        """Map data from [-1, 1] back to [0, 1]."""
        return (x + 1) * 0.5

    def make_noisy(self, x_zeros, t):
        """Perturb x_0 to x_t in closed form (eq. 4 of the DDPM paper, Ho et al. 2020).

        Args:
            x_zeros: clean samples x_0, shape (B, C, H, W).
            t: (B,) long tensor of timesteps to perturb to.

        Returns:
            ``(x_t, epsilon)``. ``x_t`` is detached from the graph, and
            ``epsilon`` is the Gaussian noise that was added.
        """
        epsilon = torch.randn_like(x_zeros).to(self.device)

        sqrt_alpha_bar = self.extract(self.sqrt_alpha_bars, t, x_zeros.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars, t, x_zeros.shape)

        noisy_sample = x_zeros * sqrt_alpha_bar + epsilon * sqrt_one_minus_alpha_bar

        return noisy_sample.detach(), epsilon

    def forward(self, x_zeros, x_c, start_age, age_deltas, pattern):
        """Run one training step: noise x_0 at a random t and predict that noise.

        ``x_zeros``, ``x_c`` and ``pattern`` are expected in [0, 1] and are
        rescaled to [-1, 1] here.

        Returns:
            ``(perturbed_images, epsilon, pred)`` for the noise target.
        """
        x_zeros = self.scale_to_minus_one_to_one(x_zeros)
        x_c = self.scale_to_minus_one_to_one(x_c)
        pattern = self.scale_to_minus_one_to_one(pattern)

        B, _, _, _ = x_zeros.shape

        # randomly choose a diffusion timestep per sample
        t = torch.randint(low=0, high=self.n_times, size=(B,)).long().to(self.device)

        # perturb x_zeros with the fixed variance schedule (forward diffusion)
        perturbed_images, epsilon = self.make_noisy(x_zeros, t)

        # predict the noise to remove, given the perturbed data at timestep t
        pred = self.model(perturbed_images, x_c, start_age, age_deltas, pattern, t)

        if self.target == 'noise':
            return perturbed_images, epsilon, pred
        else:
            return perturbed_images, x_zeros, pred

    def denoise_at_t(self, x_t, x_c, start_age, age_deltas, pattern, null_pattern, w, timestep, t):
        """Take one reverse step x_t -> x_{t-1} (lines 3-4 of Algorithm 2 in DDPM).

        Args:
            x_t: noisy sample at step t, shape (B, C, H, W).
            w: (B,) guidance weight mixing the pattern and null-pattern predictions.
            timestep: ``t`` broadcast to a (B,) long tensor.
            t: 0-indexed integer time step in the reverse process.

        Returns:
            The sample after one denoising step, clamped to [-1, 1].
        """
        # Algorithm 2 adds fresh noise on every step except the last one. Its t is
        # 1-indexed (z = 0 only when t == 1). Here t is 0-indexed, so the last step is t == 0.
        if t > 0:
            z = torch.randn_like(x_t).to(self.device)
        else:
            z = torch.zeros_like(x_t).to(self.device)

        if self.target == 'noise':
            # use the guided noise prediction (epsilon) to restore the perturbed sample
            epsilon_pred_pattern = self.model(x_t, x_c, start_age, age_deltas, pattern, timestep)
            epsilon_pred_null = self.model(x_t, x_c, start_age, age_deltas, null_pattern, timestep)
            epsilon_pred = w.view(-1, 1, 1, 1) * epsilon_pred_pattern + (1 - w).view(-1, 1, 1, 1) * epsilon_pred_null

            alpha = self.extract(self.alphas, timestep, x_t.shape)
            sqrt_alpha = self.extract(self.sqrt_alphas, timestep, x_t.shape)
            sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars, timestep, x_t.shape)
            sqrt_beta = self.extract(self.sqrt_betas, timestep, x_t.shape)

            x_t_minus_1 = 1 / sqrt_alpha * (x_t - (1 - alpha) / sqrt_one_minus_alpha_bar * epsilon_pred) + sqrt_beta * z
        else:
            # NOTE(review): unreachable, because __init__ rejects target != 'noise'. This
            # model call also does not match the conditional signature used elsewhere.
            x0_pred = self.model(x_t, timestep)
            alpha = self.extract(self.alphas, timestep, x_t.shape)
            sqrt_alpha = self.extract(self.sqrt_alphas, timestep, x_t.shape)
            alpha_bar = self.extract(self.alpha_bars, timestep, x_t.shape)
            prev_alpha_bar = self.extract(self.prev_alpha_bars, timestep, x_t.shape)
            sqrt_prev_alpha_bar = self.extract(self.sqrt_prev_alpha_bars, timestep, x_t.shape)
            sqrt_beta = self.extract(self.sqrt_betas, timestep, x_t.shape)

            x_t_minus_1 = ((1 - prev_alpha_bar) * sqrt_alpha * x_t / (1 - alpha_bar)
                           + (1 - alpha) * sqrt_prev_alpha_bar * x0_pred / (1 - alpha_bar)
                           + sqrt_beta * z)

        # NOTE(review): this clamps every intermediate x_t, not only the x_0 estimate.
        # At large t, x_t starts as unit-variance noise, so part of it is truncated here.
        # Kept as-is to preserve the existing sampling behaviour.
        return x_t_minus_1.clamp(-1., 1)

    @torch.no_grad()
    def sample(self, n_samples, x_c, start_age, age_deltas, pattern, null_pattern, w):
        """Generate samples by ancestral sampling (Algorithm 2 of DDPM, Ho et al. 2020).

        Conditioning augmentation is guided by LDM:
        https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddpm.py

        Runs under ``torch.no_grad()``, so the reverse chain does not build an
        autograd graph across all ``n_times`` steps.

        Args:
            n_samples: number of samples to draw per conditioning input.
            x_c, start_age, age_deltas, pattern, null_pattern, w: per-item
                conditioning with leading batch dimension B. ``x_c``, ``pattern`` and
                ``null_pattern`` are expected in [0, 1].

        Returns:
            Samples in [0, 1] of shape (B, n_samples, C, H, W).
        """
        pattern = self.scale_to_minus_one_to_one(pattern)
        null_pattern = self.scale_to_minus_one_to_one(null_pattern)
        x_c = self.scale_to_minus_one_to_one(x_c)

        # prepare to generate n_samples samples for each of the B data points
        B = x_c.shape[0]
        batch_size = B * n_samples
        x_c_exp = x_c.repeat_interleave(n_samples, dim=0)
        start_age_exp = start_age.repeat_interleave(n_samples, dim=0)
        age_deltas_exp = age_deltas.repeat_interleave(n_samples, dim=0)
        pattern_exp = pattern.repeat_interleave(n_samples, dim=0)
        null_pattern_exp = null_pattern.repeat_interleave(n_samples, dim=0)
        w_exp = w.repeat_interleave(n_samples, dim=0)

        # start from a random noise vector x_T
        x_t = torch.randn((batch_size, self.img_C, self.img_H, self.img_W)).to(self.device)

        # iteratively denoise from x_T down to x_0
        for t in range(self.n_times - 1, -1, -1):
            timestep = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            x_t = self.denoise_at_t(x_t, x_c_exp, start_age_exp, age_deltas_exp, pattern_exp,
                                    null_pattern_exp, w_exp, timestep, t)

        # denormalize x_0 into the [0, 1] range
        x_0 = self.reverse_scale_to_zero_to_one(x_t)

        return x_0.view(B, n_samples, self.img_C, self.img_H, self.img_W)
    
class Diffusion_MoCap(nn.Module):
    """Gaussian diffusion (DDPM) wrapper for 3-D sequence data of shape (B, horizon, features).

    ``forward`` perturbs clean sequences to a random timestep and returns the
    network's noise prediction. ``sample`` runs the reverse process and mixes a
    ``pattern``-conditioned and a ``null_pattern``-conditioned noise prediction
    as ``w * eps_pattern + (1 - w) * eps_null``.

    Args:
        model: denoising network, called as ``model(x_t, x_c, rot, pattern, t)``.
        image_resolution: stored as ``img_C, img_H, img_W``. It is not used by ``sample``,
            which infers the output shape from ``pattern`` and ``rot``.
        n_times: number of diffusion steps T.
        beta_minmax: start and end of a beta schedule. Kept for backward
            compatibility but currently ignored, because the cosine schedule
            from ``cosine_beta_schedule`` is always used.
        device: device holding the schedule tensors and the sampled noise.
        target: prediction target; only ``'noise'`` is supported.

    Raises:
        NotImplementedError: if ``target`` is not ``'noise'``.
    """

    def __init__(self, model, image_resolution=(1, 9, 9), n_times=1000, beta_minmax=(1e-4, 2e-2), device='cuda',
                 target='noise'):
        super(Diffusion_MoCap, self).__init__()

        if target != 'noise':
            raise NotImplementedError  # not sure atm how to implement x0 prediction with guidance

        self.n_times = n_times
        self.img_C, self.img_H, self.img_W = image_resolution
        self.model = model
        self.device = device
        self.target = target

        # cosine variance schedule (beta_t); beta_minmax plays no part in it
        self._set_schedule(cosine_beta_schedule(self.n_times).float().to(device))

    def _set_schedule(self, betas):
        """Precompute every per-timestep coefficient derived from the 1-D ``betas`` tensor."""
        self.sqrt_betas = torch.sqrt(betas)

        # alpha_t = 1 - beta_t and its running product alpha_bar_t (forward diffusion kernel)
        self.alphas = 1 - betas
        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1 - self.alpha_bars)
        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)

        # alpha_bar_{t-1}. The value "before" the first step is 1 (nothing noised yet),
        # so the x0-target posterior mean at t == 0 reduces to the predicted x_0.
        self.prev_alpha_bars = F.pad(self.alpha_bars[:-1], (1, 0), value=1.0)
        self.sqrt_prev_alpha_bars = torch.sqrt(self.prev_alpha_bars)

    def extract(self, a, t, x_shape):
        """Gather ``a[t]`` for a batch of timesteps, reshaped for broadcasting.

        Args:
            a: 1-D schedule tensor (e.g. ``self.alphas``).
            t: (B,) long tensor of timestep indices, on the same device as ``a``.
            x_shape: shape of the tensor the result will be broadcast against.

        Returns:
            Tensor of shape ``(B, 1, ..., 1)`` with ``len(x_shape)`` dimensions.
        """
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    def scale_to_minus_one_to_one(self, x):
        """Map data from [0, 1] to [-1, 1].

        Normalizing x seems to be crucial for training the reverse-process
        network, according to Ho et al. (2020).
        """
        return x * 2 - 1

    def reverse_scale_to_zero_to_one(self, x):
        """Map data from [-1, 1] back to [0, 1]."""
        return (x + 1) * 0.5

    def make_noisy(self, x_zeros, t):
        """Perturb x_0 to x_t in closed form (eq. 4 of the DDPM paper, Ho et al. 2020).

        Args:
            x_zeros: clean samples x_0.
            t: (B,) long tensor of timesteps to perturb to.

        Returns:
            ``(x_t, epsilon)``. ``x_t`` is detached from the graph, and
            ``epsilon`` is the Gaussian noise that was added.
        """
        epsilon = torch.randn_like(x_zeros).to(self.device)

        sqrt_alpha_bar = self.extract(self.sqrt_alpha_bars, t, x_zeros.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars, t, x_zeros.shape)

        noisy_sample = x_zeros * sqrt_alpha_bar + epsilon * sqrt_one_minus_alpha_bar

        return noisy_sample.detach(), epsilon

    def forward(self, x_zeros, x_c, rot, pattern):
        """Run one training step: noise x_0 at a random t and predict that noise.

        ``x_zeros`` must be 3-D (B, horizon, features). ``x_zeros``, ``x_c`` and
        ``pattern`` are expected in [0, 1] and are rescaled to [-1, 1] here.

        Returns:
            ``(perturbed, epsilon, pred)`` for the noise target.
        """
        x_zeros = self.scale_to_minus_one_to_one(x_zeros)
        x_c = self.scale_to_minus_one_to_one(x_c)
        pattern = self.scale_to_minus_one_to_one(pattern)

        B, _, _ = x_zeros.shape

        # randomly choose a diffusion timestep per sample
        t = torch.randint(low=0, high=self.n_times, size=(B,)).long().to(self.device)

        # perturb x_zeros with the fixed variance schedule (forward diffusion)
        perturbed_images, epsilon = self.make_noisy(x_zeros, t)

        # predict the noise to remove, given the perturbed data at timestep t
        pred = self.model(perturbed_images, x_c, rot, pattern, t)

        if self.target == 'noise':
            return perturbed_images, epsilon, pred
        else:
            return perturbed_images, x_zeros, pred

    def denoise_at_t(self, x_t, x_c, rot, pattern, null_pattern, w, timestep, t):
        """Take one reverse step x_t -> x_{t-1} (lines 3-4 of Algorithm 2 in DDPM).

        Args:
            x_t: noisy sample at step t, shape (B, horizon, features).
            w: (B,) guidance weight mixing the pattern and null-pattern predictions.
            timestep: ``t`` broadcast to a (B,) long tensor.
            t: 0-indexed integer time step in the reverse process.

        Returns:
            The sample after one denoising step, clamped to [-1, 1].
        """
        # Algorithm 2 adds fresh noise on every step except the last one. Its t is
        # 1-indexed (z = 0 only when t == 1). Here t is 0-indexed, so the last step is t == 0.
        if t > 0:
            z = torch.randn_like(x_t).to(self.device)
        else:
            z = torch.zeros_like(x_t).to(self.device)

        if self.target == 'noise':
            # use the guided noise prediction (epsilon) to restore the perturbed sample
            epsilon_pred_pattern = self.model(x_t, x_c, rot, pattern, timestep)
            epsilon_pred_null = self.model(x_t, x_c, rot, null_pattern, timestep)
            epsilon_pred = w.view(-1, 1, 1) * epsilon_pred_pattern + (1 - w).view(-1, 1, 1) * epsilon_pred_null

            alpha = self.extract(self.alphas, timestep, x_t.shape)
            sqrt_alpha = self.extract(self.sqrt_alphas, timestep, x_t.shape)
            sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars, timestep, x_t.shape)
            sqrt_beta = self.extract(self.sqrt_betas, timestep, x_t.shape)

            x_t_minus_1 = 1 / sqrt_alpha * (x_t - (1 - alpha) / sqrt_one_minus_alpha_bar * epsilon_pred) + sqrt_beta * z
        else:
            # NOTE(review): unreachable, because __init__ rejects target != 'noise'. This
            # model call also does not match the conditional signature used elsewhere.
            x0_pred = self.model(x_t, timestep)
            alpha = self.extract(self.alphas, timestep, x_t.shape)
            sqrt_alpha = self.extract(self.sqrt_alphas, timestep, x_t.shape)
            alpha_bar = self.extract(self.alpha_bars, timestep, x_t.shape)
            prev_alpha_bar = self.extract(self.prev_alpha_bars, timestep, x_t.shape)
            sqrt_prev_alpha_bar = self.extract(self.sqrt_prev_alpha_bars, timestep, x_t.shape)
            sqrt_beta = self.extract(self.sqrt_betas, timestep, x_t.shape)

            x_t_minus_1 = ((1 - prev_alpha_bar) * sqrt_alpha * x_t / (1 - alpha_bar)
                           + (1 - alpha) * sqrt_prev_alpha_bar * x0_pred / (1 - alpha_bar)
                           + sqrt_beta * z)

        # NOTE(review): this clamps every intermediate x_t, not only the x_0 estimate.
        # At large t, x_t starts as unit-variance noise, so part of it is truncated here.
        # Kept as-is to preserve the existing sampling behaviour.
        return x_t_minus_1.clamp(-1., 1)

    @torch.no_grad()
    def sample(self, n_samples, x_c, rot, pattern, null_pattern, w):
        """Generate sequences by ancestral sampling (Algorithm 2 of DDPM, Ho et al. 2020).

        Conditioning augmentation is guided by LDM:
        https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddpm.py

        Runs under ``torch.no_grad()``, so the reverse chain does not build an
        autograd graph across all ``n_times`` steps.

        Args:
            n_samples: number of samples to draw per conditioning input.
            x_c: conditioning input with leading batch dimension B, expected in [0, 1].
            rot: 3-D tensor whose last dimension (rot_dim) is appended to the feature size.
            pattern, null_pattern: 3-D tensors (B, n_horizon, flat_dim), expected in [0, 1].
            w: (B,) guidance weight.

        Returns:
            Samples in [0, 1] of shape (B, n_samples, n_horizon, flat_dim + rot_dim).
        """
        _, n_horizon, flat_dim = pattern.shape
        _, _, rot_dim = rot.shape
        pattern = self.scale_to_minus_one_to_one(pattern)
        null_pattern = self.scale_to_minus_one_to_one(null_pattern)
        x_c = self.scale_to_minus_one_to_one(x_c)

        # prepare to generate n_samples samples for each of the B data points
        B = x_c.shape[0]
        batch_size = B * n_samples
        x_c_exp = x_c.repeat_interleave(n_samples, dim=0)
        rot_exp = rot.repeat_interleave(n_samples, dim=0)
        pattern_exp = pattern.repeat_interleave(n_samples, dim=0)
        null_pattern_exp = null_pattern.repeat_interleave(n_samples, dim=0)
        w_exp = w.repeat_interleave(n_samples, dim=0)

        # start from a random noise vector x_T
        x_t = torch.randn((batch_size, n_horizon, flat_dim + rot_dim)).to(self.device)

        # iteratively denoise from x_T down to x_0
        for t in range(self.n_times - 1, -1, -1):
            timestep = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            x_t = self.denoise_at_t(x_t, x_c_exp, rot_exp, pattern_exp, null_pattern_exp, w_exp, timestep, t)

        # denormalize x_0 into the [0, 1] range
        x_0 = self.reverse_scale_to_zero_to_one(x_t)

        return x_0.view(B, n_samples, n_horizon, flat_dim + rot_dim)