import torch
from torch import nn
import numpy as np
from typing import Any, List
from tqdm import tqdm
from scipy.interpolate import interp1d,PchipInterpolator
import copy

from tqdm import tqdm
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from typing import List, Any

import matplotlib.pyplot as plt



import torch
import torch.nn as nn
import numpy as np
import math

# Model instances are constructed in the example sections at the end of the file.

class ContinuousTimeEncoding(nn.Module):
    """
    Continuous Time Encoding using Gaussian Fourier Features.

    Embeds time values ``t`` into sine/cosine features of random frequencies
    ``W`` (``embed_dim // 2`` of them, drawn as ``randn * scale``):
    ``[sin(2*pi*t*W), cos(2*pi*t*W)]``.

    ``W`` is registered as a buffer so that it is stored in ``state_dict``
    (a reloaded module reproduces the same embedding) and follows
    ``.to(device)`` / ``.double()`` together with the rest of the module.
    """
    def __init__(self, embed_dim: int, scale: float = 1.0):
        """
        Args:
            embed_dim (int): Requested output dimensionality. The output has
                ``2 * (embed_dim // 2)`` features, i.e. one fewer than
                ``embed_dim`` when ``embed_dim`` is odd.
            scale (float): Standard deviation of the random frequencies. Higher
                values create embeddings with higher frequency components.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.scale = scale
        # One random frequency per sin/cos pair. A buffer (not a plain tensor
        # attribute) so it is checkpointed and moved/cast with the module.
        self.register_buffer("W", torch.randn(embed_dim // 2) * self.scale)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            t (torch.Tensor): Time values of shape ``(batch_size,)``, expected
                in [0, 1]. A 0-d tensor yields a single row.

        Returns:
            torch.Tensor: Encoding of shape ``(batch_size, 2 * (embed_dim // 2))``.
        """
        # (batch, 1) * (1, n_freq) -> (batch, n_freq) phase matrix
        t_proj = t.unsqueeze(-1) * self.W.unsqueeze(0) * 2 * np.pi
        # Concatenate sine and cosine embeddings
        return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)

class ContinuousTimeResidualBlock(nn.Module):
    """
    Two-layer GELU MLP conditioned on a precomputed time embedding.

    ``forward(x, t_emb)`` joins ``x`` (``d_model`` features) and ``t_emb``
    (``embed_dim`` features) along the last axis, then applies
    ``GELU(lin1(.))`` followed by ``GELU(lin2(.))``. Only the transformed
    features are returned; the residual connection is added by the caller
    (``BasicContinuousTimeModel`` does ``block(x, t_emb) + x``).

    ``emb`` and ``norm`` are constructed but not used by ``forward``.
    """

    def __init__(self, d_model: int, embed_dim: int):
        super().__init__()
        self.d_model = d_model
        # Not used in forward: the time embedding is passed in by the caller.
        self.emb = ContinuousTimeEncoding(embed_dim=embed_dim, scale=1.0)
        self.lin1 = nn.Linear(d_model + embed_dim, d_model)
        self.lin2 = nn.Linear(d_model, d_model)
        # Not used in forward.
        self.norm = nn.LayerNorm(d_model)
        self.act = nn.GELU()

    def forward(self, x, t_emb):
        joined = torch.cat((x, t_emb), dim=-1)
        hidden = self.act(self.lin1(joined))
        return self.act(self.lin2(hidden))

class ResContinuousTimeResidualBlock(nn.Module):
    """
    Conditional MLP block that sees hidden state, a scalar prediction and time.

    ``forward(x, z, t_emb)`` concatenates ``[x, z, t_emb]`` along the last
    axis. That gives ``d_model + 1 + embed_dim`` input features, where the
    ``+1`` is the single column of ``z``. It then applies ``GELU(lin1(.))``
    and ``GELU(lin2(.))``. The residual connection is added by the caller.

    Both linear layers get a uniform init with the Glorot bound
    ``sqrt(6 / (fan_in + fan_out))`` multiplied by ``c``, and zero biases.
    """

    def __init__(self, d_model: int, embed_dim: int, c: float):
        super().__init__()
        self.d_model = d_model
        # Not used in forward; still built, so random-number consumption at
        # construction is unchanged.
        self.emb = ContinuousTimeEncoding(embed_dim=embed_dim, scale=1.0)
        in_features = d_model + 1 + embed_dim
        self.lin1 = nn.Linear(in_features, d_model)
        self.lin2 = nn.Linear(d_model, d_model)
        # Not used in forward.
        self.norm = nn.LayerNorm(d_model)
        self.act = nn.GELU()

        for layer in (self.lin1, self.lin2):
            self._scaled_glorot_(layer, c)

    @staticmethod
    def _scaled_glorot_(layer: nn.Linear, c: float) -> None:
        """Re-initialise ``layer`` in place: uniform(+-c * Glorot bound) weights, zero bias."""
        fan_in, fan_out = layer.in_features, layer.out_features
        limit = math.sqrt(6 / (fan_in + fan_out)) * c
        nn.init.uniform_(layer.weight, -limit, limit)
        nn.init.zeros_(layer.bias)

    def forward(self, x, z, t_emb):
        features = torch.cat([x, z, t_emb], dim=-1)
        hidden = self.act(self.lin1(features))
        return self.act(self.lin2(hidden))

class BasicContinuousTimeModel(nn.Module):
    """
    Scalar-in / scalar-out network conditioned on a continuous time value.

    Structure:
      * ``lin_in`` lifts the ``(batch, 1)`` input to ``d_model`` features.
      * ``n_layers`` ``ContinuousTimeResidualBlock``s form a shared trunk with
        residual connections. Its first head ``lin_out[0]`` gives a prediction ``z``.
      * ``time_stamps`` are the ``n_res`` interior points of
        ``linspace(0, 1, n_res + 2)``. Going from the largest threshold to the
        smallest, rows with ``t < threshold`` pass through that stage's stack
        of ``ResContinuousTimeResidualBlock``s, which also see the current
        ``z``. They then get a new prediction from that stage's head
        ``lin_out[stage + 1]``.

    ``lin_out`` and ``res_blocks`` are ``nn.ModuleList``s, so their parameters
    are registered. They are returned by ``parameters()`` (and therefore
    trained), moved by ``.to()``, and saved in ``state_dict``.
    """

    def __init__(self, d_model: int = 128, n_layers: int = 2, embed_dim: int = 64, n_res: int = 5):
        super().__init__()

        # n_res interior thresholds of a uniform grid on [0, 1].
        self.time_stamps = list(np.linspace(0,1,n_res+2))[1:-1]

        self.emb = ContinuousTimeEncoding(embed_dim=embed_dim, scale=1.0)

        self.d_model = d_model
        self.n_layers = n_layers
        self.embed_dim = embed_dim
        self.lin_in = nn.Linear(1, d_model)
        # One output head for the trunk plus one per conditional stage.
        self.lin_out = nn.ModuleList(
            [nn.Linear(d_model, 1) for _ in range(len(self.time_stamps) + 1)]
        )
        self.blocks = nn.ModuleList(
            [ContinuousTimeResidualBlock(d_model=d_model, embed_dim=embed_dim) for _ in range(n_layers)]
        )
        # One stack of n_layers conditional blocks per time stamp; the init is
        # scaled down by the total number of stages + layers.
        self.res_blocks = nn.ModuleList([
            nn.ModuleList(
                [ResContinuousTimeResidualBlock(d_model=d_model, embed_dim=embed_dim, c = 1/(len(self.time_stamps)+n_layers)) for _ in range(n_layers)]
            )
            for _ in self.time_stamps
        ])

    def forward(self, x, t):
        """
        Args:
            x: input of shape ``(batch, 1)``.
            t: time values of shape ``(batch,)``. A 0-d tensor is treated as
               a batch of one.

        Returns:
            Prediction of shape ``(batch, 1)``.
        """
        # Without this, a 0-d t makes the nonzero() indices empty and every
        # conditional stage would be silently skipped.
        t = torch.atleast_1d(t)
        x = self.lin_in(x)
        t_emb = self.emb(t)

        # Shared trunk with residual connections.
        for block in self.blocks:
            x = block(x, t_emb) + x

        z = self.lin_out[0](x)

        # Conditional stages, from the largest threshold down.
        for stage, time_stamp in enumerate(reversed(self.time_stamps)):
            mask = (t < time_stamp)
            if not mask.any():
                continue

            indices = torch.nonzero(mask, as_tuple=True)[0]

            x_sel = x[indices]
            t_emb_sel = t_emb[indices]
            z_sel = z[indices]

            # Chain the stage's blocks: each one refines the previous block's output.
            for block in reversed(self.res_blocks[stage]):
                x_sel = block(x_sel, z_sel, t_emb_sel) + x_sel

            z_new = self.lin_out[stage + 1](x_sel)

            # Write the refined rows back (out of place, autograd-friendly).
            x = x.scatter(0, indices.unsqueeze(-1).expand(-1, x.size(1)), x_sel)
            z = z.scatter(0, indices.unsqueeze(-1), z_new)

        return z




###works 

def compute_gradient_trace(f,x,t):
    """
    Return the gradient w.r.t. ``x`` of the trace of the Jacobian of ``f``.

    The function being differentiated is ``g(y) = f(y.view(-1, 1), t)``.
    Its Jacobian trace ``sum_i dg_i/dy_i`` is built one diagonal entry at a
    time with forward-mode AD (``torch.func.jvp`` against basis vectors). The
    trace is then differentiated w.r.t. ``x`` with reverse-mode AD.

    ``torch.autograd.grad`` is used rather than ``.backward()``, so no
    ``.grad`` is accumulated into tensors captured by ``f`` (e.g. the
    parameters of a model).

    Args:
        f: callable ``f(y, t)``. Assumed to return one row per row of ``y``
           (e.g. ``(n, 1)`` for ``y`` of shape ``(n, 1)``).
        x: evaluation point, assumed 1-D. It is cloned and detached, so the
           caller's tensor is not modified.
        t: passed unchanged as the second argument of ``f``.

    Returns:
        Tensor shaped like ``x``. It is all zeros when the trace does not depend on ``x``.
    """
    x = x.clone().detach().requires_grad_()

    def g(y):
        return f(y.view(-1, 1), t)

    trace = torch.zeros(1, dtype=x.dtype, device=x.device)
    for i in range(x.shape[0]):
        basis = torch.zeros_like(x)
        basis[i] = 1.0
        _, tangent = torch.func.jvp(g, (x,), (basis,))
        trace = trace + tangent[i]

    # A trace that is constant in x has zero gradient.
    if not trace.requires_grad:
        return torch.zeros_like(x)
    (grad,) = torch.autograd.grad(trace.sum(), x, allow_unused=True)
    return torch.zeros_like(x) if grad is None else grad




class DDPM(nn.Module):
    """
    Discrete-time DDPM noise schedule with an adaptively resampled ``beta``.

    At construction, ``beta`` is ``linspace(minval, maxval, n_steps)``,
    ``alpha = 1 - beta`` and ``alpha_bar = cumprod(alpha)``. All three are
    float64 buffers; ``resample_beta`` / ``update_alpha`` overwrite them later.

    Models passed to the methods are called as ``model(x, beta_t)`` and their
    output is used as a noise (epsilon) prediction. ``beta_t`` is a float32
    tensor of beta values, usually one per row of ``x``. Score estimates are
    formed as ``-model(x, beta_t) / sqrt(1 - alpha_bar_t)``.

    Schedule adaptation (``update_schedule`` / ``update_schedule_s``) sums a
    per-step discrepancy into ``D`` over batches. Once every step ``t >= 1``
    has more than ``n_l_min`` samples, ``beta`` is moved (with relaxation
    ``tau``) toward the values interpolated at equally spaced levels of the
    cumulative sum of the per-step increments, and the counters are reset.
    """

    def __init__(self, n_steps: int, minval: float = 1e-7, maxval: float = 0.8):
        super().__init__()
        assert 0 < minval < maxval <= 1
        assert n_steps > 0
        self.n_steps = n_steps
        self.minval = minval
        self.maxval = maxval

        self.register_buffer("beta", torch.linspace(minval, maxval, n_steps, dtype=torch.float64))
        self.register_buffer("alpha", (1 - self.beta).to(torch.float64))
        self.register_buffer("alpha_bar", self.alpha.cumprod(0).to(torch.float64))

        # ---- adaptive-schedule state --------------------------------------
        # Copy of the initial schedule (not modified by this class).
        self.time_ref = copy.deepcopy(self.beta)

        self.lambda_increments = torch.zeros(n_steps)  # last per-step D passed to resample_beta
        self.Lambda = 0.0       # EMA (weight tau) of sum(D)

        self.space_error = 0.0  # EMA of max relative deviation of interior D from L_bar

        self.loss_count = 0     # number of diffusion_loss calls

        self.loss_record = {}   # last batch: {'beta': beta_t, 'pdf': per-sample losses}
        # Reserved for importance sampling of t; currently not read.
        self.loss_cdf = None
        self.loss_cdf_inv = None

        self.lambda_cdf = None
        self.lambda_cdf_inv = None

        self.lambdas = []       # history of Lambda (appended by update_schedule_s)

        # Per-step accumulators indexed by step t (targets of scatter_add_).
        self.D = torch.zeros(self.n_steps)
        self.st_xt = torch.zeros(self.n_steps)
        self.stm1_zt = torch.zeros(self.n_steps)
        self.correction = torch.zeros(self.n_steps)

        self.D2 = torch.zeros(self.n_steps)
        self.n_time = torch.zeros(self.n_steps,dtype = int)
        # Step 0 never receives samples (t > 0 is filtered). Seeding its count
        # keeps it through the `n_time > 0` filter, so the per-step arrays stay
        # aligned with beta.
        self.n_time[0] = 1

    def reset_counters(self):
        """Reset the D / D2 sums and the per-step sample counts."""
        self.D = torch.zeros(self.n_steps)
        self.D2 = torch.zeros(self.n_steps)
        self.n_time = torch.zeros(self.n_steps,dtype = int)
        self.n_time[0] = 1

    def diffusion_loss(self, model, inp):
        """
        Weighted epsilon-prediction loss on one batch.

        For each sample, a step ``t`` is drawn uniformly from ``[0, n_steps)``.
        ``inp`` is noised in closed form to ``x_noisy``, and
        ``model(x_noisy, beta_t)`` is regressed onto the noise with the
        per-sample weight
        ``C_t = (beta_t**2 / alpha_t) / (2 * sig2_t)``, where
        ``sig2_t = beta_t * (1 - alpha_bar_t) / (1 - alpha_bar_{t-1})``.

        Returns:
            ``(total_loss, x_noisy, eps_pred, t)``.
        """
        self.loss_count += 1

        device = inp.device
        batch_size = inp.shape[0]

        # Create the noise perturbation
        eps = torch.randn_like(inp, device=device)

        # Uniform step sampling. (An importance-sampling alternative was:
        # t = torch.tensor(find_closest_beta_indices(self.loss_cdf_inv(np.random.rand(batch_size)),self.beta),dtype = int,device=device))
        t = torch.randint(0, self.n_steps, (batch_size,), device=device)

        # Compute the closed form sample x_noisy after t time steps
        a_t = self.alpha_bar[t][:, None]
        x_noisy = torch.sqrt(a_t).float() * inp + torch.sqrt(1 - a_t).float() * eps

        # NOTE(review): for t == 0, index t-1 == -1 refers to the LAST step, so
        # the t == 0 weight uses alpha_bar[n_steps - 1].
        sig2 = (self.beta[t]*(1-self.alpha_bar[t])/(1-self.alpha_bar[t-1]))
        C_score = self.beta[t]**2/self.alpha[t]

        C = C_score/(2*sig2)
        # Give the (batch,) weights trailing singleton dims. Otherwise
        # (batch,) * (batch, 1) broadcasts to (batch, batch), and every sample
        # would be weighted by the batch-average of C instead of its own C_t.
        C = C.reshape(-1, *([1] * (eps.ndim - 1)))

        # Predict the noise added given time t
        eps_pred = model(x_noisy, self.beta[t].float())

        # Element-wise squared error, weighted per sample, then averaged over
        # all non-batch dimensions.
        loss_func = nn.MSELoss(reduction='none')
        individual_losses = (C*loss_func(eps_pred, eps)).mean(dim=list(range(1, eps.ndim)))

        # Detach the losses for logging
        individual_losses_detached = individual_losses.detach()

        self.loss_record = {'beta': self.beta[t].float(),'pdf': individual_losses_detached}

        total_loss = individual_losses.mean()

        return total_loss, x_noisy, eps_pred, t

    def forward_sample(self, inp: torch.Tensor, t: int) -> torch.Tensor:
        """Draw x_t ~ q(x_t | x_0 = inp) for the integer step ``t`` (same t for every row)."""
        device = inp.device
        batch_size = inp.shape[0]

        # create the noise perturbation
        eps = torch.randn_like(inp, device=device)

        # broadcast the scalar step to one index per row
        t = t + torch.zeros(batch_size, dtype = int)

        # compute the closed form sample x_noisy after t time steps
        a_t = self.alpha_bar[t][:, None]
        x_noisy = torch.sqrt(a_t).float() * inp + torch.sqrt(1 - a_t).float() * eps

        return x_noisy

    def sample(self, model: nn.Module, n_samples: int = 128, early_stop: int = 5):
        """
        Ancestral (DDPM) sampling of ``n_samples`` particles of shape ``(n_samples, 1)``.

        Runs the reverse chain from ``t = n_steps - 1`` down to
        ``t = early_stop`` under ``torch.no_grad()``. It returns the
        noise-free mean of the last step, so the final noise draw is discarded.

        NOTE(review): if ``early_stop >= n_steps`` the loop never runs and
        ``x_mean`` is unbound (UnboundLocalError).
        """
        with torch.no_grad():
            device = next(model.parameters()).device

            # start off with an intial random ensemble of particles
            x = torch.randn(n_samples, 1, device=device)

            for t in reversed(range(early_stop,self.n_steps)):
                # apply the same variance to all particles in the ensemble equally.
                a = self.alpha[t].repeat(n_samples)[:, None].float()
                abar = self.alpha_bar[t].repeat(n_samples)[:, None].float()

                # posterior mean from the predicted noise eps_theta
                eps_theta = model(x, torch.tensor([self.beta[t].float() ] * n_samples, dtype=torch.float))
                x_mean = (x - eps_theta * (1 - a) / torch.sqrt(1 - abar)) / torch.sqrt(
                    a
                )
                sigma_t = torch.sqrt(1 - self.alpha[t].float())

                # sample a different realization of noise for each particle and propagate
                z = torch.randn_like(x)
                x = x_mean + sigma_t * z

            return x_mean  # skip the last noise addition

    def sample_s(self, model: nn.Module, n_samples: int = 128, early_stop: int = 5):
        """
        Reverse-SDE style sampler using the score ``-model(x, beta_t) / sqrt(1 - alpha_bar_t)``.

        Update per step: ``x <- x + 0.5*beta_t*x + beta_t*s_t(x) + sqrt(beta_t)*noise``,
        for ``t = n_steps - 1`` down to ``early_stop``. Unlike ``sample``, this
        does not run under ``torch.no_grad()``.
        """
        device = next(model.parameters()).device

        # Initial random ensemble of particles
        x = torch.randn(n_samples, 1, device=device)

        for t in reversed(range(early_stop, self.n_steps)):

            beta_t =  torch.tensor([self.beta[t].float() ] * n_samples).view(n_samples, 1)
            alpha_bar_t =  torch.tensor([self.alpha_bar[t].float() ] * n_samples).view(n_samples, 1)

            # Score estimate s_t at the current state x
            st_x = - model(x, torch.tensor([self.beta[t].float() ] * n_samples, dtype=torch.float))/torch.sqrt(1-alpha_bar_t)

            # Drift term
            drift =  0.5 * beta_t * x + beta_t * st_x

            # Diffusion term
            diffusion = torch.sqrt(beta_t) * torch.randn_like(x)

            # Euler update for x
            x = x + drift + diffusion

        return x

    def resample_beta(self, dists, tau = 0.05):
        """
        Move ``beta`` toward a schedule equidistributing the increments ``dists``.

        ``beta`` is interpolated (PCHIP) as a function of the cumulative sum of
        ``dists`` and evaluated at ``len(dists)`` equally spaced levels. The
        result is blended in as ``tau * new + (1 - tau) * beta``. The
        interpolator needs a strictly increasing cumulative sum, and
        ``dists`` must have one entry per step.
        """
        with torch.no_grad():

            # Step 1: Calculate the cumulative distribution of dists
            cumulative_dists = np.cumsum(dists.double())

            # Step 2: Generate an equally spaced new cumulative distribution
            equal_cumulative_dists = np.linspace(0, cumulative_dists[-1], len(dists), dtype = np.float64)

            # Step 3: Interpolate beta across the new distribution
            beta_np = self.beta.cpu().numpy() if isinstance(self.beta, torch.Tensor) else self.beta

            interpolation_function = PchipInterpolator(cumulative_dists, beta_np)
            new_beta_values = interpolation_function(equal_cumulative_dists)

            # Relaxed update; assigning to a registered buffer name updates the buffer.
            self.beta = tau*torch.tensor(new_beta_values, dtype=torch.float64) + (1-tau)*self.beta

    def update_alpha(self):
        """Recalculate `alpha` and `alpha_bar` based on the current `beta` values in double precision."""
        beta_double = self.beta.double()

        self.alpha = (1 - beta_double).double()
        self.alpha_bar = self.alpha.cumprod(0).double()

    def update_schedule_s(self,model,xt,st_xt,t,tau=0.05,n_l_min=1, sign = 'plus'):
        """
        Transport variant of ``update_schedule``.

        Only samples with ``t > 0`` are used. ``st_xt`` is the epsilon
        prediction at ``(xt, beta_t)`` (as returned by ``diffusion_loss``) and
        is converted to a score ``s_t``. Each ``xt`` is moved to
        ``z_t = xt + 0.5*beta_t*xt + 0.5*beta_t*s_t``, and the step ``t-1``
        score there, ``s_{t-1}(z_t)``, is compared with ``s_t`` (optionally
        plus/minus a divergence-based ``correction``, per ``sign`` in
        {'plus', 'minus', 'zero'}). The accumulated quantity is
        ``v = (... )**2 * (1 - alpha_bar_{t-1})``.

        Once every step ``t >= 1`` has more than ``n_l_min`` samples, the
        per-step means are formed, ``Lambda`` / ``space_error`` are
        EMA-updated, ``beta`` is resampled, and the counters are reset.
        """
        xt = xt[t>0]
        st_xt = st_xt[t>0,:]
        t = t[t>0]

        n_samples = t.shape[0]
        if n_samples == 0:
            # Every sampled step was t == 0: nothing to accumulate, and
            # torch.stack below would fail on an empty list.
            return

        beta_t =  self.beta[t].float().view(n_samples, 1)
        alpha_bar_t =  self.alpha_bar[t].float().view(n_samples, 1)

        # epsilon prediction -> score estimate at step t
        st_xt =  -st_xt/torch.sqrt(1-alpha_bar_t)

        xt = xt.clone().detach().requires_grad_()

        if sign != 'zero':
            # Per sample: grad_x tr(d model / dx) via compute_gradient_trace,
            # rescaled like the score, times -0.5 * beta_t.
            correction = -0.5*beta_t*torch.stack([-compute_gradient_trace(model,xt[i],self.beta[t].float()[i])/torch.sqrt(1-alpha_bar_t[i]) for i in range(len(xt))])

        with torch.no_grad():

            zt =  xt + 0.5 * beta_t * xt + 0.5 * beta_t * st_xt

            # score estimate at step t-1, evaluated at z_t
            stm1_zt = -model(zt,self.beta[t-1].float())/torch.sqrt(1-self.alpha_bar[t-1].float().reshape(-1,1))

            if sign == 'plus':
                v = (st_xt + correction - stm1_zt)**2
                self.correction.scatter_add_(0,t,correction.reshape(-1))

            elif sign == 'minus':
                v = (st_xt - correction - stm1_zt)**2
                self.correction.scatter_add_(0,t,correction.reshape(-1))

            else:
                v = (st_xt - stm1_zt)**2

            C = (1-self.alpha_bar[t-1].float().reshape(-1,1))
            v = v*C.float().reshape(-1,1)

            self.st_xt.scatter_add_(0,t,st_xt.reshape(-1))
            self.stm1_zt.scatter_add_(0,t,stm1_zt.reshape(-1))

            self.D.scatter_add_(0,t,v.reshape(-1))
            #self.D2.scatter_add_(0,t,(v**2).reshape(-1))
            self.n_time.scatter_add_(0,t,torch.ones_like(t))

            if torch.all(self.n_time[1:]>n_l_min) :
                # Enough samples at every step: turn sums into per-step means.
                self.D = self.D[self.n_time>0]
                self.st_xt = self.st_xt[self.n_time>0]
                self.stm1_zt = self.stm1_zt[self.n_time>0]
                self.correction = self.correction[self.n_time>0]

                self.n_time = self.n_time[self.n_time>0]

                self.st_xt = self.st_xt/self.n_time
                self.stm1_zt = self.stm1_zt/self.n_time
                self.correction = self.correction/self.n_time

                self.D = torch.sqrt(self.D/self.n_time)

                self.lambda_increments = self.D

                self.Lambda = (1-tau)*self.Lambda + tau*torch.sum(self.D)

                self.lambdas.append(self.Lambda)

                L_bar = torch.sum(self.D)/(len(self.D)-1)

                # max relative deviation of interior steps from the mean increment
                self.space_error = (1-tau)*self.space_error + tau*torch.max(torch.abs(self.D[1:-1] - L_bar)/L_bar)

                self.resample_beta(self.lambda_increments,tau = tau)
                self.update_alpha()

                # NOTE(review): st_xt / stm1_zt / correction are averaged above
                # but not reset here, so later batches add onto the averages.
                self.reset_counters()

    def update_schedule(self,model,xt,st_xt,t,tau=0.05,n_l_min=1):
        """
        Accumulate a per-step score discrepancy and resample ``beta`` when ready.

        For each sample with ``t > 0``: ``st_xt`` is the epsilon prediction at
        ``(xt, beta_t)`` (as returned by ``diffusion_loss``) and is converted
        to a score ``s_t``. It is compared with the model's step ``t-1``
        score at the same ``xt``. Then
        ``v = (s_{t-1} - s_t)**2 * (1 - alpha_bar_{t-1})`` is summed into
        ``D[t]`` and ``n_time[t]`` is incremented.

        Once every step ``t >= 1`` has more than ``n_l_min`` samples:
        ``D <- sqrt(D / n_time)``; ``Lambda`` and ``space_error`` are
        EMA-updated with weight ``tau``; ``beta`` is resampled and
        ``alpha``/``alpha_bar`` are recomputed; the counters are reset.
        """
        xt = xt[t>0]
        st_xt = st_xt[t>0,:]
        t = t[t>0]

        if t.shape[0] == 0:
            # Nothing to accumulate (every sampled step was t == 0).
            return

        stm1_xt = - model(xt,self.beta[t-1].float())/torch.sqrt(1-self.alpha_bar[t-1].float().reshape(-1,1))

        st_xt = - st_xt/torch.sqrt(1-self.alpha_bar[t].float().reshape(-1,1))

        v = (stm1_xt -st_xt)**2

        C = (1-self.alpha_bar[t-1].float().reshape(-1,1))

        v = v*C.float().reshape(-1,1)

        self.D.scatter_add_(0,t,v.reshape(-1))

        self.n_time.scatter_add_(0,t,torch.ones_like(t))

        if torch.all(self.n_time[1:]>n_l_min) :
            self.D = self.D[self.n_time>0]
            self.n_time = self.n_time[self.n_time>0]

            self.D = torch.sqrt(self.D/self.n_time)

            self.lambda_increments = self.D

            self.Lambda = (1-tau)*self.Lambda + tau*torch.sum(self.D)

            L_bar = torch.sum(self.D)/(len(self.D)-1)

            # max relative deviation of interior steps from the mean increment
            self.space_error = (1-tau)*self.space_error + tau*torch.max(torch.abs(self.D[1:-1] - L_bar)/L_bar)

            self.resample_beta(self.lambda_increments,tau = tau)
            self.update_alpha()

            self.reset_counters()

    def compute_lambda_increments(self,model,x):
        """
        Per-step increments estimated directly from data ``x``.

        For each step ``i >= 1``, ``x`` is noised to step ``i``, and the model
        outputs at ``beta_{i-1}`` and ``beta_i`` (each divided by the
        corresponding ``sqrt(1 - alpha_bar)``) are compared. The result is
        ``D[i] = mean(|diff| * sqrt(C_i))``, with ``C_i`` the same weight as in
        ``diffusion_loss``. ``D[0] = 0``.
        """
        D = torch.zeros(len(self.beta))

        for i,beta in enumerate(self.beta):

            if i > 0 :

                forward_samples = torch.sqrt(self.alpha_bar[i]).float()*x
                forward_samples += torch.sqrt(1-self.alpha_bar[i]).float()*torch.randn_like(forward_samples)

                scores_jm1 = model(forward_samples, self.beta[i-1].float()+torch.zeros(forward_samples.shape[0]) )/torch.sqrt(1-self.alpha_bar[i-1]).reshape(-1,1).float()
                scores_j = model(forward_samples,  self.beta[i].float()+torch.zeros(forward_samples.shape[0]))/torch.sqrt(1-self.alpha_bar[i]).reshape(-1,1).float()

                sig2 = (self.beta[i]*(1-self.alpha_bar[i])/(1-self.alpha_bar[i-1]))
                C_score = self.beta[i]**2/self.alpha[i]
                C = C_score/(2*sig2)

                score_diff = (scores_j - scores_jm1)
                D[i] = torch.mean(torch.sqrt(score_diff**2*C.float().reshape(-1,1)))

            else:
                D[i] = 0

        return D


from tqdm import tqdm
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from typing import List, Any

def train_model(model, ddpm, X, n_epochs=1000, batch_size=128, seed=42, learn_schedule=False, learn_score=True, n_l_min=1, tau=0.05, stop=None, max_steps=1000, transport=False, lr=1e-3, sign='plus'):
    """
    Train ``model`` with ``ddpm.diffusion_loss`` and optionally adapt the schedule.

    Args:
        model: network called by ``ddpm`` as ``model(x, beta_t)``.
        ddpm: schedule object providing ``diffusion_loss``, ``update_schedule``,
            ``update_schedule_s``, ``Lambda`` and ``space_error``.
        X: training tensor; batched with shuffling.
        n_epochs: ignored. The epoch count is recomputed as
            ``max_steps // len(dataloader)``.
        batch_size, seed, lr: data batching, ``torch.manual_seed`` value and
            Adam learning rate.
        learn_schedule: if True, update the schedule after every batch, via
            ``update_schedule_s`` when ``transport`` is True (passing ``sign``),
            otherwise via ``update_schedule`` under ``torch.no_grad()``.
        learn_score: if False, the loss is computed but no optimizer step is taken.
        n_l_min, tau: forwarded to the schedule update.
        stop: if given, stop as soon as ``step > n_l_min`` and
            ``ddpm.space_error < stop``.
        max_steps: progress-bar total; also determines the epoch count.

    Returns:
        1 if stopped early by the ``stop`` criterion, else 0.
    """
    torch.manual_seed(seed)

    loader = DataLoader(TensorDataset(X), batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    losses: List[float] = []
    ema_loss = None
    step = 0

    # The n_epochs argument is overridden here.
    n_epochs = int(max_steps // len(loader))

    with tqdm(total=max_steps) as progress:
        for _ in range(n_epochs):
            # Re-iterating the same shuffling loader draws a fresh permutation each epoch.
            for (batch,) in loader:
                batch = batch.to(torch.float32)

                optimizer.zero_grad()
                loss, xt, st_xt, t = ddpm.diffusion_loss(model, batch)

                if learn_score:
                    loss.backward()
                    optimizer.step()

                if learn_schedule:
                    if transport:
                        ddpm.update_schedule_s(model, xt, st_xt, t, tau=tau, n_l_min=n_l_min, sign=sign)
                    else:
                        with torch.no_grad():
                            ddpm.update_schedule(model, xt, st_xt, t, tau=tau, n_l_min=n_l_min)

                progress.update(1)
                current = loss.item()
                losses.append(current)
                ema_loss = current if ema_loss is None else 0.95 * ema_loss + 0.05 * current

                if step % 10 == 0:
                    progress.set_description(f"Iter: {step}. Avg Loss: {ema_loss:.4f}. Lambda: {ddpm.Lambda:.4f}. Space Error: {ddpm.space_error:.4f}")

                step += 1

                if stop is not None and step > n_l_min and ddpm.space_error < stop:
                    return 1

    return 0



def generate_bimodal_gaussian_dataset(N, sigma=0.001, mu = 5):
    """
    Return a shuffled ``(2 * (N // 2), 1)`` tensor drawn from two Gaussians.

    ``N // 2`` samples come from N(+mu, sigma**2) and ``N // 2`` from
    N(-mu, sigma**2), so ``N`` is the total sample count (an odd ``N`` loses
    one sample).
    """
    half = N // 2
    upper = mu + sigma * torch.randn(half, 1)
    lower = sigma * torch.randn(half, 1) - mu
    samples = torch.cat((upper, lower))
    order = torch.randperm(samples.shape[0])
    return samples[order]

# Example usage: two-mode Gaussian mixture.
N = 512  # Total number of samples (N//2 drawn around each mode)
delta = 0.1  # standard deviation of each mode
mu = 6  # modes are centred at +mu and -mu
X = generate_bimodal_gaussian_dataset(N, delta, mu)
plt.hist(X.flatten(), bins=50)
plt.show()

# Fresh draw with the same settings. This also advances the global torch RNG,
# which changes the random initialisation of the model constructed below.
delta = 0.1
X = generate_bimodal_gaussian_dataset(N, delta, mu)


# Define the model
model_bimodal = BasicContinuousTimeModel(d_model = 128, n_layers = 1, embed_dim = 12, n_res = 15) # alternatives (not defined in this file): SimpleContinuousTimeResNet(1,10,1), BasicDiscreteTimeModel(d_model=128, n_layers=5)

# Define the DDPM schedule
n_steps = 50
min_val = 1e-5

ddpm_bimodal = DDPM(n_steps=n_steps, minval = min_val, maxval = 0.8)


def generate_cantor_set(n, start=-0.5, end=0.5, current_iter=0):
    """
    Return the midpoints of the segments of the n-th Cantor-set iteration.

    Starting from ``[start, end]`` (treated as iteration ``current_iter``),
    each step replaces every segment by its left and right thirds. After
    ``n - current_iter`` steps, the midpoint of every remaining segment is
    returned, ordered left to right (``2 ** (n - current_iter)`` values).

    Raises:
        ValueError: if ``n - current_iter`` is negative or not a whole number.
            Such input would otherwise never reach the final iteration.
    """
    depth = n - current_iter
    if depth < 0 or depth != int(depth):
        raise ValueError(
            f"n - current_iter must be a non-negative whole number, got {n} - {current_iter}"
        )

    segments = [(start, end)]
    for _ in range(int(depth)):
        refined = []
        for lo, hi in segments:
            third = (hi - lo) / 3
            refined.append((lo, lo + third))
            refined.append((hi - third, hi))
        segments = refined
    return [(lo + hi) / 2 for lo, hi in segments]

def generate_cantor_gaussian_dataset(n, N, sigma=0.001):
    """
    Build a shuffled mixture of Gaussians centred on Cantor-set midpoints.

    Parameters:
    - n: Cantor-set iteration passed to ``generate_cantor_set`` (its midpoints
      are the centres).
    - N: number of samples drawn around each centre.
    - sigma: standard deviation of every Gaussian.

    Returns:
    - A ``(num_centres * N, 1)`` tensor with rows in random order.
    """
    clusters = []
    for centre in generate_cantor_set(n):
        clusters.append(sigma * torch.randn(N, 1) + centre)
    data = torch.cat(clusters)
    permutation = torch.randperm(data.size(0))
    return data[permutation]

# Example usage: Gaussian mixture on a Cantor set.
n = 4 # Iteration of the Cantor set (2**n Gaussian centres)
N = 1000  # Total sample budget: each of the 2**n centres gets N//2**n samples (total may be slightly below N)
delta = 0.001  # standard deviation of each Gaussian
X = generate_cantor_gaussian_dataset(n, N//2**n, delta)
plt.hist(X.flatten(), bins=200)
plt.show()

# Define the model
model_cantor = BasicContinuousTimeModel(d_model = 128, n_layers = 5, embed_dim = 12, n_res = 1) # alternatives (not defined in this file): SimpleContinuousTimeResNet(1,3,1), BasicDiscreteTimeModel(d_model=128, n_layers=5)

# Define the DDPM schedule
n_steps = 50
min_val = 1e-5
ddpm_cantor = DDPM(n_steps=n_steps, minval = min_val, maxval = 0.8)

# Histogram of X after the step-0 forward noising q(x_0 | data).
plt.hist(ddpm_cantor.forward_sample(X,0).detach().numpy().flatten(), alpha = 0.2, density = True, bins = 200)




            
