import torch


class SDE_Brownian_manifolds:
    """Forward SDE with linearly interpolated noise and tau schedules.

    The diffusion coefficient is
    ``sigma_t = sigma_min + (t / T) * (sigma_max - sigma_min)`` and the drift
    returned by :meth:`sde` is ``sigma_t ** 2 * b(x)``, where ``b`` is chosen
    by ``drift_mode`` (``'zero'``: ``b(x) = 0 * x``; ``'linear'``:
    ``b(x) = -x``).

    Attributes:
        sigma_min, sigma_max: End points of the diffusion schedule.
        tau_min, tau_max: End points of the tau schedule.
        N: Number of discretisation steps.
        T: Time horizon.
        dt: Step size ``T / N``.
        func_b: Callable implementing ``b``.
        sampler: Stored as given; not used inside this class.
        g_0, g_T: Diffusion coefficient at ``t = 0`` and ``t = T``, each a
            tensor of shape ``[1]``.
    """

    def __init__(self, sigma_min, sigma_max, tau_min, tau_max, N, T, sampler=None, drift_mode='zero'):
        """Store the schedule parameters and select the drift function.

        Raises:
            ValueError: If ``drift_mode`` is neither ``'zero'`` nor
                ``'linear'``.
        """
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.tau_min = tau_min
        self.tau_max = tau_max
        self.N = N
        self.T = T
        self.dt = self.T / self.N
        if drift_mode == 'zero':
            self.func_b = lambda x: 0 * x
        elif drift_mode == 'linear':
            self.func_b = lambda x: -x
        else:
            # Fail fast: without this, func_b would stay undefined and the
            # error would only surface later as an AttributeError.
            raise ValueError(
                f"Unknown drift_mode {drift_mode!r}; expected 'zero' or 'linear'."
            )

        self.sampler = sampler

        # Diffusion coefficient at both ends of the time interval.
        self.g_0 = self.sde(None, torch.tensor([0.]))[1]
        self.g_T = self.sde(None, torch.tensor([self.T]))[1]

    def _interpolate(self, lo, hi, t):
        """Return the linear interpolation ``lo + (t / T) * (hi - lo)``."""
        return lo + t / self.T * (hi - lo)

    @staticmethod
    def _per_sample(coef, ref):
        """Reshape per-sample ``coef`` so that it broadcasts against ``ref``.

        ``coef`` is flattened along the leading axis and padded with singleton
        axes. Rank-2 ``ref`` gives ``[-1, 1]`` and rank-3 gives ``[-1, 1, 1]``;
        higher ranks get one singleton axis per trailing dimension of ``ref``.
        Ranks below 2 use ``[-1, 1, 1]``.
        """
        trailing = 1 if ref.dim() == 2 else max(ref.dim() - 1, 2)
        return coef.reshape([-1] + [1] * trailing)

    def get_diffusion(self, t):
        """Return the diffusion coefficient ``sigma_t`` at time ``t``."""
        return self._interpolate(self.sigma_min, self.sigma_max, t)

    def get_sigma(self, t):
        """Return ``sigma_t`` at time ``t`` (same schedule as ``get_diffusion``)."""
        return self._interpolate(self.sigma_min, self.sigma_max, t)

    def get_tau_scheduler(self, t):
        """Return ``tau_t``, interpolated linearly from ``tau_min`` to ``tau_max``."""
        return self._interpolate(self.tau_min, self.tau_max, t)

    def drift_b(self, x):  # b = \nabla U(x)
        """Evaluate ``b(x)`` and cast the result to the dtype/device of ``x``."""
        return self.func_b(x).to(x)

    def sde(self, x, diff_t):
        """Return the forward ``(drift, diffusion)`` pair at times ``diff_t``.

        Args:
            x: State tensor, or ``None`` to request only the diffusion.
                Presumably shaped ``[bsz, dim]`` or ``[bsz, d1, d2]``.
            diff_t: Tensor of times, presumably shaped ``[bsz]``.

        Returns:
            ``drift`` is ``sigma_t ** 2 * b(x)`` with ``sigma_t`` broadcast per
            sample, or the float ``0.`` when ``x`` is ``None``. ``diffusion``
            is ``sigma_t = sigma_min + (diff_t / T) * (sigma_max - sigma_min)``,
            cast to the dtype/device of ``x`` when ``x`` is given.
        """
        # NOTE: a CHMC-specific rescaling, sqrt(2 * sigma_t), was sketched
        # here in the past but is not applied.
        diffusion = self._interpolate(self.sigma_min, self.sigma_max, diff_t)

        if x is None:
            # No state to cast against: return the schedule value as computed.
            return 0., diffusion

        b = self.func_b(x).to(x)
        # Align devices before multiplying so a time tensor living on another
        # device than x does not break the product.
        scale = self._per_sample(diffusion.to(x.device), b) ** 2
        return scale * b, diffusion.to(x)

    def reverse(self, score_fn, underdamped=False):
        """Build the reverse-time SDE driven by ``score_fn``.

        Args:
            score_fn: Score model. Called as ``score_fn(x, score_t)`` in the
                default case and as
                ``score_fn(torch.cat([x, v], dim=-1), label=score_t)`` when
                ``underdamped`` is true.
            underdamped: Select the variant whose ``sde`` also takes a
                velocity ``v``.

        Returns:
            An instance of a subclass of this object's class that carries a
            copy of this object's attributes. Its ``sde`` returns
            ``(forward_drift - sigma_t ** 2 * score, sigma_t)``.
        """
        forward = self
        sde_fn = self.sde
        per_sample = self._per_sample

        def reverse_drift(drift, diffusion, score):
            # Subtract the score term, scaled per sample by sigma_t^2.
            return drift - per_sample(diffusion, score) ** 2 * score

        # Build the class for reverse-time SDE.
        if underdamped:
            class RSDE(self.__class__):
                def __init__(self):
                    # Copy the forward state (N, T, schedules, func_b, ...)
                    # so that inherited helpers work on the reverse object.
                    self.__dict__.update(forward.__dict__)

                def drift_score(self, x, v, score_t):
                    """Evaluate the score on the concatenated ``[x, v]``."""
                    return score_fn(torch.cat([x, v], dim=-1), label=score_t)

                def sde(self, x, v, score_t, diff_t):
                    """Return the reverse ``(drift, diffusion)`` pair."""
                    drift, diffusion = sde_fn(x, diff_t)
                    score = self.drift_score(x, v, score_t)
                    return reverse_drift(drift, diffusion, score), diffusion
        else:
            class RSDE(self.__class__):
                def __init__(self):
                    # Copy the forward state (N, T, schedules, func_b, ...)
                    # so that inherited helpers work on the reverse object.
                    self.__dict__.update(forward.__dict__)

                def sde(self, x, score_t, diff_t):
                    """Return the reverse ``(drift, diffusion)`` pair."""
                    drift, diffusion = sde_fn(x, diff_t)
                    score = score_fn(x, score_t)
                    return reverse_drift(drift, diffusion, score), diffusion

        return RSDE()

