import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
from torchvision import datasets, transforms

# When True, training_losses / test_losses map the incoming data with 2*x - 1
# (presumably data in [0, 1] -> [-1, 1]) before noising, and sample() maps the
# generated batch back with (x + 1) / 2 before returning it.
NORMALIZE = True
# When True, sample() takes discrete DDPM steps (or DDIM steps when
# deterministic=True) instead of integrating the reverse SDE/ODE.
# Only supported for VP processes (sample() asserts this).
DDPM_SAMPLING = False


#############################################
# Helper function 
#############################################

def match_last_dims(data, size):
    """
    Broadcast a per-batch 1D tensor to a full batched shape.

    Args:
        data: 1D tensor holding one value per batch element.
        size: target shape (e.g. ``x.shape``); ``size[0]`` is expected to equal
            ``len(data)`` but is not used.

    Returns:
        A new tensor of shape ``(len(data), *size[1:])`` in which every slice
        ``out[i]`` is filled with ``data[i]``.
    """
    assert len(data.size()) == 1, "Data must be 1-dimensional (one value per batch)"
    trailing = tuple(size[1:])
    # Give `data` one singleton axis per trailing dimension, then tile along them.
    column = data.reshape(data.shape[0], *([1] * len(trailing)))
    return column.repeat(1, *trailing)

class DiffusionProcess:
    def __init__(self,
                 device=None,
                 T=1.0,
                 process_type='VP',
                 schedule='cosine',
                 rescale_timesteps=True,
                 learn_variance=False,
                 **kwargs):
        """
        Configure a continuous-time diffusion process.

        Args:
            device: torch device used for sampled tensors; 'cpu' when None.
            T: time horizon. Actual times live in [0, T]; the network always
                receives the normalized time t / T in [0, 1].
            process_type: 'VP' (variance preserving) or 'VE' (variance exploding).
            schedule: VP noise schedule, 'linear' or 'cosine'.
            rescale_timesteps: must be True (only normalized-time inputs are supported).
            learn_variance: must be False (learned variances are not implemented).
            **kwargs: schedule hyper-parameters:
                VP/linear: beta_min (default 0.1), beta_max (default 20.0);
                VP/cosine: s (default 0.008), the small offset of Nichol & Dhariwal's
                    cosine schedule;
                VE: sigma_min (default 0.01), sigma_max (default 50.0).

        Raises:
            AssertionError: if rescale_timesteps != True or learn_variance != False.
            ValueError: for an unknown process_type or an unknown VP schedule.
        """
        self.device = device if device is not None else 'cpu'
        self.T = T
        self.process_type = process_type
        self.schedule = schedule
        self.learn_variance = learn_variance
        self.rescale_timesteps = rescale_timesteps

        # Equality (not truthiness) checks are intentional: they keep the original semantics.
        assert self.rescale_timesteps == True, "rescale_timesteps must be set to True"  # noqa: E712
        assert self.learn_variance == False, "learn_variance is not implemented yet"  # noqa: E712

        if process_type == 'VE':
            self.sigma_min = kwargs.get('sigma_min', 0.01)
            self.sigma_max = kwargs.get('sigma_max', 50.0)
        elif process_type == 'VP':
            if schedule == 'cosine':
                self.s = kwargs.get('s', 0.008)
            elif schedule == 'linear':
                self.beta_min = kwargs.get('beta_min', 0.1)
                self.beta_max = kwargs.get('beta_max', 20.0)
            else:
                raise ValueError("Unknown VP schedule type")
        else:
            raise ValueError(f"Unknown process type: {process_type}")
    
    def alpha_bar(self, t_norm):
        """
        For VP processes, returns the cumulative product (or survival probability) at normalized time t_norm.
        For linear: ᾱ(t) = exp( - [β_min t + 0.5 (β_max - β_min)t^2] )
        For cosine: ᾱ(t) = cos( ((t + s)/(1+s))*(π/2) )^2
        """
        if self.process_type == 'VP':
            if self.schedule == 'linear':
                integrated_beta = self.beta_min * t_norm + 0.5 * (self.beta_max - self.beta_min) * t_norm**2
                return torch.exp(-integrated_beta)
            elif self.schedule == 'cosine':
                return torch.cos((t_norm + self.s) / (1 + self.s) * (torch.pi / 2))**2
        else:
            return None

    def sigma_fn(self, t_norm):
        """
        For VE processes, returns the noise scale at normalized time t_norm.
        Using an exponential schedule: σ(t) = σ_min * (σ_max/σ_min)^t
        """
        if self.process_type == 'VE':
            return self.sigma_min * (self.sigma_max / self.sigma_min)**(t_norm)
        else:
            return None
        
    def get_timesteps(self, N):
        return torch.linspace(self.T, 1e-4, N + 1)
    
    def get_score_from_eps(self, eps, t):
        """
        Given the predicted noise eps, returns the score (i.e. ∇_x log p_t(x)).
        For VP: score = - (predicted noise) / sqrt(1 - ᾱ(t))
        For VE: score = - (predicted noise) / σ(t)
        """
        t_norm = t / self.T
        if self.process_type == 'VP':
            alpha_bar = self.alpha_bar(t_norm).view(-1, *([1] * (eps.dim() - 1)))
            score = -eps / torch.sqrt(1 - alpha_bar)
            return score
        elif self.process_type == 'VE':
            sigma_t = self.sigma_fn(t_norm).view(-1, *([1] * (eps.dim() - 1)))
            score = -eps / sigma_t
            return score
    
    def score_fn(self, model, x, t, **model_kwargs):
        """
        Given the noise-predicting model, returns the score (i.e. ∇_x log p_t(x))
        at actual time t. Note that the model expects a normalized time (t/T).
        For VP: score = - (predicted noise) / sqrt(1 - ᾱ(t))
        For VE: score = - (predicted noise) / σ(t)
        """
        t_norm = t / self.T  # normalize to [0,1]
        if self.process_type == 'VP':
            alpha_bar = self.alpha_bar(t_norm).view(-1, *([1] * (x.dim() - 1)))
            epsilon = model(x, t_norm, **model_kwargs) #.view(-1, 1))
            score = -epsilon / torch.sqrt(1 - alpha_bar)
            return score
        elif self.process_type == 'VE':
            sigma_t = self.sigma_fn(t_norm).view(-1, *([1] * (x.dim() - 1)))
            epsilon = model(x, t_norm, **model_kwargs)#.view(-1, 1))
            score = -epsilon / sigma_t
            return score

    def forward(self, x_start, t_norm):
        """
        Forward (diffusion) process: given a clean sample x_start and time t (in [0,T]),
        returns the noised version x_t.
        For VP: x_t = sqrt(ᾱ(t)) x_start + sqrt(1-ᾱ(t)) noise
        For VE: x_t = x_start + σ(t)*noise
        """
        noise = torch.randn_like(x_start)
        if self.process_type == 'VP':
            alpha_bar = self.alpha_bar(t_norm).view(-1, *([1] * (x_start.dim() - 1)))
            x_t = torch.sqrt(alpha_bar) * x_start + torch.sqrt(1 - alpha_bar) * noise # torch.sqrt(1 - alpha_bar**2) vs torch.sqrt(1 - alpha_bar)...
        elif self.process_type == 'VE':
            sigma_t = self.sigma_fn(t_norm).view(-1, *([1] * (x_start.dim() - 1)))
            x_t = x_start + sigma_t * noise
        return x_t, noise

    def training_losses(self, models, x_start, model_kwargs=None, normalize=False, **kwargs):
        """
        Noise-prediction training loss for one batch.

        Draws one time per example, t = U * (T - 1e-4) + 1e-5 with U ~ Uniform[0, 1)
        (so t lies in [1e-5, T - 9e-5)). It then noises x_start with the forward
        process and compares the model's prediction with the true noise.

        Args:
            models: dict whose 'default' entry is the noise-prediction network,
                called as model(x_t, t / T, **model_kwargs).
            x_start: clean data batch. It is moved to self.device and, when the
                module-level NORMALIZE flag is set, mapped with 2*x - 1.
            model_kwargs: extra keyword arguments for the model (default: none).
            normalize: currently ignored; normalization is controlled by NORMALIZE.
            **kwargs: unused.

        Returns:
            dict with
              'loss': element-wise (unreduced) squared error between predicted and
                  true noise, shaped like x_start;
              'score_loss': scalar MSE between the scores implied by the predicted
                  and the true noise.
        """
        net = models['default']
        x0 = x_start.to(self.device)
        if NORMALIZE:
            x0 = 2 * x0 - 1  # normalize to [-1, 1]

        n_examples = x0.size(0)
        # Keep t off both endpoints of [0, T].
        t = torch.rand(n_examples, device=self.device) * (self.T - 1e-4) + 1e-5
        t_norm = t / self.T
        x_t, noise = self.forward(x0, t_norm)

        extra = {} if model_kwargs is None else model_kwargs
        eps_hat = net(x_t, t_norm, **extra)  # the network sees normalized time
        return {
            'loss': F.mse_loss(eps_hat, noise, reduction='none'),
            'score_loss': F.mse_loss(
                self.get_score_from_eps(eps_hat, t),
                self.get_score_from_eps(noise, t),
            ),
        }

    def test_losses(self, models, x_start, model_kwargs=None, *, num_mc=5, **kwargs):
        """
        Monte Carlo estimate of the evaluation losses, averaged over `num_mc` time draws.

        Each draw samples one t per example uniformly from [0, T), clamped below at 1e-5.
        It noises x_start with the forward process and compares the model's noise
        prediction with the true noise.

        Args:
            models: dict whose 'default' entry is the noise-prediction network,
                called as model(x_t, t / T, **model_kwargs).
            x_start: clean data batch. It is moved to self.device and, when the
                module-level NORMALIZE flag is set, mapped with 2*x - 1.
            model_kwargs: extra keyword arguments for the model (default: none).
            num_mc: number of Monte Carlo draws to average over (default 5).
            **kwargs: unused.

        Returns:
            dict with 'loss' (mean noise MSE) and 'score_loss' (mean MSE between the
            scores implied by the predicted and the true noise), each averaged over
            the num_mc draws.

        Raises:
            ValueError: if num_mc < 1.
        """
        if num_mc < 1:
            raise ValueError("num_mc must be >= 1, got {}".format(num_mc))

        model = models['default']
        x_start = x_start.to(self.device)
        if NORMALIZE:
            x_start = 2 * x_start - 1  # normalize to [-1, 1]

        batch_size = x_start.size(0)
        if model_kwargs is None:
            model_kwargs = {}

        # NOTE(review): no torch.no_grad() here, so an autograd graph is built for every
        # draw unless the caller disables gradients.
        def one_mc_draw():
            # A single Monte Carlo sample: one fresh time per example.
            t = torch.rand(batch_size, device=self.device) * self.T
            # Keep t away from 0. For the linear VP schedule 1 - ᾱ(t) is exactly 0 at t = 0
            # (and can round to 0 in float32 for tiny t). The score -eps / sqrt(1 - ᾱ) then
            # becomes inf and score_loss NaN. 1e-5 is the lower bound training_losses uses.
            t = t.clamp_min(1e-5)
            t_norm = t / self.T
            x_t, noise = self.forward(x_start, t_norm)
            predicted_noise = model(x_t, t_norm, **model_kwargs)
            loss_t = F.mse_loss(predicted_noise, noise)
            score_loss_t = F.mse_loss(
                self.get_score_from_eps(predicted_noise, t),
                self.get_score_from_eps(noise, t),
            )
            return loss_t, score_loss_t

        total_loss = 0.0
        total_score_loss = 0.0
        for _ in range(num_mc):
            loss_t, score_loss_t = one_mc_draw()
            total_loss = total_loss + loss_t
            total_score_loss = total_score_loss + score_loss_t

        return {'loss': total_loss / num_mc, 'score_loss': total_score_loss / num_mc}
    
    def sample(self,
               models,
               shape,
               reverse_steps=1000,
               epsilon=1e-4,
               guidance_scale=1.0,
               get_sample_history=False,
               progress=True,
               deterministic=False,
               clip_denoised=False,
               model_kwargs=None,
               ei_integrator=False,
               **kwargs):
        """
        Generate a batch by integrating the reverse-time dynamics from t = 0.995*T to `epsilon`.

        Integrators:
          * DDPM ancestral steps, or DDIM when deterministic=True (VP only). Selected by the
            module-level DDPM_SAMPLING flag.
          * Euler–Maruyama on the reverse SDE
                dx = [f(x,t) - g(t)^2 * score(x,t)] dt + g(t) dW̄,
            or, with deterministic=True, Euler on the probability-flow ODE
                dx = [f(x,t) - 0.5 * g(t)^2 * score(x,t)] dt.
          * An exponential integrator (VP only, ei_integrator=True). It solves the linear
            drift exactly while holding β and the score fixed over each step.

        Coefficients, with t the actual time in [0, T]:
          VP: f(x,t) = -0.5 β(t) x,  g(t) = sqrt(β(t)),  β(t) = -d log ᾱ(t) / dt
          VE: f(x,t) = 0,            g(t) = sqrt(dσ(t)^2 / dt) = σ(t) sqrt(2 ln(σ_max/σ_min) / T)

        Args:
            models: dict whose 'default' entry is the noise-prediction network.
            shape: shape of the generated batch; shape[0] is the batch size.
            reverse_steps: number of integration steps.
            epsilon: final (smallest) time reached.
            guidance_scale: currently unused.
            get_sample_history: if True, return the stacked state after every step
                (before the final rescaling/clipping) instead of the final sample.
            progress: show a tqdm progress bar.
            deterministic: drop the stochastic term (DDIM / probability-flow ODE).
            clip_denoised: clamp the final sample to [0, 1] (NORMALIZE) or [-1, 1].
            model_kwargs: extra keyword arguments forwarded to the model.
            ei_integrator: use the exponential integrator for VP (ignored for VE).
            **kwargs: unused.

        Returns:
            The final sample (rescaled with (x + 1) / 2 when NORMALIZE is set), or the
            stacked history when get_sample_history is True.

        Raises:
            AssertionError: DDPM_SAMPLING is set for a VE process.
            ValueError: unknown VP schedule.
        """
        import math

        assert not (DDPM_SAMPLING and (self.process_type == 'VE')), "DDPM sampling is only implemented for VP processes"
        model = models['default']
        if model_kwargs is None:
            model_kwargs = {}

        # VP: β(t) = -d log ᾱ(t) / dt in *actual* time units. ᾱ is parameterized in normalized
        # time t / T, hence the 1/T factor in both schedules.
        get_beta_t = None
        if self.process_type == 'VP':
            if self.schedule == 'linear':
                def get_beta_t(t_norm_batch):
                    return (self.beta_min + t_norm_batch * (self.beta_max - self.beta_min)) / self.T
            elif self.schedule == 'cosine':
                def get_beta_t(t_norm_batch):
                    # Clamped because tan(...) diverges as t_norm -> 1.
                    return torch.clamp(
                        torch.pi / (self.T * (1 + self.s))
                        * torch.tan(((t_norm_batch + self.s) / (1 + self.s)) * (torch.pi / 2)),
                        0., 500.)
            else:
                raise ValueError("Unknown schedule type: {}".format(self.schedule))

        # VE: σ(t) = σ_min (σ_max/σ_min)^(t/T)  =>  g(t) = sqrt(dσ²/dt) = σ(t) * ve_g_scale.
        ve_g_scale = None
        if self.process_type == 'VE':
            ve_g_scale = math.sqrt(2.0 * math.log(self.sigma_max / self.sigma_min) / self.T)

        # Initialize x_T (the prior sample)
        if self.process_type == 'VP':
            xt = torch.randn(shape, device=self.device)
        elif self.process_type == 'VE':
            t_norm = torch.tensor(1.0, device=self.device)
            sigma_T = self.sigma_fn(t_norm).view(*([1] * len(shape)))
            xt = sigma_T * torch.randn(shape, device=self.device)

        samples = []
        # Integration grid from slightly below T down to epsilon > 0. Presumably this avoids the
        # β blow-up of the cosine schedule at t = T and the 1/sqrt(1 - ᾱ) singularity at t = 0.
        t_seq = torch.linspace(self.T * 0.995, epsilon, reverse_steps + 1, device=self.device)
        steps = tqdm(range(reverse_steps)) if progress else range(reverse_steps)

        # Sample in eval mode, then restore the caller's mode. Otherwise sampling during
        # training would silently leave dropout / batch-norm in eval mode.
        was_training = getattr(model, 'training', False)
        model.eval()
        try:
            with torch.inference_mode():
                for i in steps:
                    t_current = t_seq[i]
                    t_next = t_seq[i + 1]
                    dt = t_next - t_current  # negative: integrating backwards in time
                    t_batch = torch.full((shape[0],), t_current, device=self.device)
                    t_next_batch = torch.full((shape[0],), t_next, device=self.device)
                    t_norm_batch = t_batch / self.T
                    t_next_norm_batch = t_next_batch / self.T

                    if DDPM_SAMPLING:
                        # Discrete step between the two grid times, treated as adjacent DDPM levels.
                        alpha_bar_t = self.alpha_bar(t_norm_batch)
                        alpha_bar_t_next = self.alpha_bar(t_next_norm_batch)
                        beta_t = 1 - alpha_bar_t / alpha_bar_t_next
                        beta_tilde_t = beta_t * (1 - alpha_bar_t_next) / (1 - alpha_bar_t)
                        beta_t = match_last_dims(beta_t, xt.shape)
                        beta_tilde_t = match_last_dims(beta_tilde_t, xt.shape)
                        alpha_bar_t = match_last_dims(alpha_bar_t, xt.shape)
                        alpha_bar_t_next = match_last_dims(alpha_bar_t_next, xt.shape)
                        alpha_t = 1 - beta_t

                        eps_pred = model(xt, t_norm_batch, **model_kwargs)

                        if deterministic:
                            # DDIM (η = 0): reconstruct x0 using ᾱ_t, then re-noise to ᾱ_next.
                            x0_pred = (xt - torch.sqrt(1.0 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)
                            xt = torch.sqrt(alpha_bar_t_next) * x0_pred + torch.sqrt(1.0 - alpha_bar_t_next) * eps_pred
                        else:
                            # Ancestral DDPM step with posterior variance β̃_t; no noise on the last step.
                            xt = (xt - beta_t * eps_pred / torch.sqrt(1 - alpha_bar_t)) / torch.sqrt(alpha_t)
                            if i < (reverse_steps - 1):
                                z = torch.randn_like(xt)
                                xt = xt + torch.sqrt(beta_tilde_t) * z
                    else:
                        if self.process_type == 'VE':
                            sigma_t = self.sigma_fn(t_norm_batch).view(-1, *([1] * (xt.dim() - 1)))
                            f = 0.0
                            g = sigma_t * ve_g_scale
                        elif self.process_type == 'VP':
                            beta_t = get_beta_t(t_norm_batch).view(-1, *([1] * (xt.dim() - 1)))
                            f = -0.5 * beta_t * xt
                            g = torch.sqrt(beta_t)

                        # Score from the noise-predicting network.
                        score = self.score_fn(model, xt, t_batch, **model_kwargs)

                        if ei_integrator and self.process_type == 'VP':
                            # In reverse time τ (step h = -dt > 0) the VP dynamics are
                            #   dx/dτ = ½β x + c β score  (+ sqrt(β) dW for the SDE),
                            # with c = 1 for the SDE and c = ½ for the probability-flow ODE.
                            # With β and score frozen over the step, the exact solution is
                            #   x <- e^{βh/2} x + 2c (e^{βh/2} - 1) score  [+ sqrt(e^{βh} - 1) z].
                            h = -dt
                            growth_m1 = torch.expm1(0.5 * beta_t * h)  # e^{βh/2} - 1
                            if deterministic:
                                xt = (1 + growth_m1) * xt + growth_m1 * score
                            else:
                                xt = (1 + growth_m1) * xt + 2 * growth_m1 * score
                                z = torch.randn_like(xt)
                                xt = xt + torch.sqrt(torch.expm1(beta_t * h)) * z
                        elif deterministic:
                            # Probability-flow ODE: x += [f - ½ g² score] dt
                            xt = xt + (f - (g ** 2) * score / 2) * dt
                        else:
                            # Euler–Maruyama: x += [f - g² score] dt + g sqrt(-dt) z,  z ~ N(0, I)
                            z = torch.randn_like(xt)
                            xt = xt + (f - (g ** 2) * score) * dt + g * torch.sqrt(-dt) * z

                    # Record the state after every step, whichever integrator produced it.
                    if get_sample_history:
                        samples.append(xt.clone())
        finally:
            if was_training:
                model.train()

        if NORMALIZE:
            xt = (xt + 1) / 2  # rescale to [0, 1]
        if clip_denoised:
            xt = torch.clamp(xt, 0.0, 1.0) if NORMALIZE else torch.clamp(xt, -1.0, 1.0)
        return torch.stack(samples) if get_sample_history else xt



        # we will compute the whole integral over [0, T], using an integral solver
        
        # choose my own time subdivision
        # timesteps_0 = torch.linspace(1e-6, self.T / 10000, 10, device=self.device)
        # timesteps_1 = torch.linspace(self.T / 10000, self.T / 100, 10, device=self.device)
        # timesteps_2 = torch.linspace(self.T / 100, self.T, 50, device=self.device)
        # timesteps = torch.cat([timesteps_0, timesteps_1, timesteps_2])
        
        # total_loss = 0.0
        # total_score_loss = 0.0

        # for i in range(timesteps.shape[0] - 1):
        #     t = timesteps[i]
        #     dt = timesteps[i + 1] - timesteps[i]
            
        #     # t = (i + 0.5) * dt  # current time value in [0, T]
            
        #     t_norm = torch.full((batch_size,), t / self.T, device=self.device)
            
        #     # Compute the forward process at time t (which gives x_t and the ground-truth noise)
        #     x_t, noise = self.forward(x_start, t_norm)
            
        #     if model_kwargs is None:
        #         model_kwargs = {}
            
        #     # Pass x_t and normalized time through the model to predict the noise component
        #     predicted_noise = model(x_t, t_norm, **model_kwargs)
            
        #     # Compute the MSE loss for noise prediction at time t
        #     loss_t = F.mse_loss(predicted_noise, noise)
            
        #     # Compute the score loss at time t.
        #     # Note: Here we convert the constant time t into a tensor of the same shape as the batch.
        #     t_tensor = torch.full((batch_size,), t, device=self.device)
        #     score_loss_t = F.mse_loss(
        #         self.get_score_from_eps(predicted_noise, t_tensor),
        #         self.get_score_from_eps(noise, t_tensor)
        #     )
            
        #     # Add the loss weighted by the time step (dt) to approximate the integral
        #     total_loss += loss_t * dt
        #     total_score_loss += score_loss_t * dt
        
        # return {'loss': total_loss, 'score_loss': total_score_loss}
        
        
        
        # if True:
        #         class ODESystem(torch.nn.Module):
        #             def forward(self, t, x):
        #                 t_actual = t  # t ∈ [0,1] corresponds to [0,T]
        #                 # get device of x
        #                 score = score_fn(model, x, t_actual * torch.ones(x.shape[0], device = x.device), **model_kwargs)
        #                 beta_t = get_beta_t(t)
        #                 drift = -0.5 * beta_t * x - 0.5 * beta_t * score
        #                 return drift

        #         # Solve from t=1 to t=0 with adaptive Dopri5
        #         solution = odeint(
        #             ODESystem().to(self.device),
        #             xt,
        #             torch.tensor([0.995, 0.0001], device=self.device),  # [start, end] times
        #             method='rk4',  # 'dopri5' is also available
        #             # rtol=1e-3,  # Looser tolerance → fewer steps
        #             # atol=1e-3,
        #             # options={
        #             #     'max_num_steps': 4 * reverse_steps,  # Allow some overshoot
        #             #     'first_step': 1.0/ reverse_steps,  # Initial guess
        #             # }
        #             options={
        #                 'step_size': 1.0 / reverse_steps,  # Initial guess
        #             }
        #         )
        #         xt = solution[-1]
            
        #     else: