import torch


class SDE_Brownian_manifolds:
    """SDE with a linearly interpolated diffusion coefficient and a scaled drift.

    The diffusion coefficient is
    ``g(t) = sigma_min + t / T * (sigma_max - sigma_min)`` and the drift is
    ``g(t)**2 * func_b(x)``. ``func_b`` returns zeros by default, so the
    forward drift is zero unless ``func_b`` is replaced.

    Args:
        sigma_min: Diffusion coefficient at ``t = 0``.
        sigma_max: Diffusion coefficient at ``t = T``.
        N: Number of time steps; the step size is ``dt = T / N``.
        T: Terminal time.
    """

    def __init__(self, sigma_min, sigma_max, N, T):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.N = N
        self.T = T
        self.dt = self.T / self.N
        self.func_b = lambda x: 0 * x

        # Diffusion coefficient at both ends of [0, T] (1-element tensors).
        self.g_0 = self.sde(None, torch.tensor([0.]))[1]
        self.g_T = self.sde(None, torch.tensor([self.T]))[1]

    @staticmethod
    def _per_sample(values, ref):
        """Reshape ``values`` to ``(-1, 1, ..., 1)`` matching ``ref``'s rank.

        This lets one value per leading-axis entry of ``ref`` broadcast over
        all remaining axes, whatever the rank of ``ref``.
        """
        return values.reshape((-1,) + (1,) * (ref.dim() - 1))

    def sde(self, x, t):
        """Return ``(drift, diffusion)`` of the forward SDE.

        Args:
            x: State tensor (presumably batch-first — confirm with callers),
                or ``None`` to query only the diffusion coefficient.
            t: Time tensor, either one entry per leading-axis entry of ``x``
                or a single value shared by all of them.

        Returns:
            drift: ``g(t)**2 * func_b(x)`` with ``g(t)`` broadcast over every
                non-leading axis of ``x``; the float ``0.`` when ``x`` is None.
            diffusion: ``g(t)``; converted with ``.to(x)`` when ``x`` is given.
        """
        diffusion = self.sigma_min + t / self.T * (self.sigma_max - self.sigma_min)

        if x is None:
            return 0., diffusion

        temp = self.func_b(x).to(x)
        # Move g(t) onto the state's device (dtype left as-is so type promotion
        # is unchanged), then shape it to scale each sample for any rank.
        g = self._per_sample(diffusion.to(temp.device), temp)
        drift = g ** 2 * temp

        return drift, diffusion.to(x)

    def reverse(self, score_fn):
        """Build the reverse-time SDE corresponding to this forward SDE.

        Args:
            score_fn: Callable ``score_fn(x, t)`` returning a tensor
                broadcastable with ``x`` (presumably a score estimate —
                confirm with caller).

        Returns:
            An instance of a subclass of this class whose ``sde`` returns
            ``(forward_drift - g(t)**2 * score, g(t))``. It carries over this
            object's attributes (``N``, ``T``, ``dt``, ``g_0``, ``g_T``, ...).
        """
        N = self.N
        T = self.T
        sde_fn = self.sde
        # Snapshot the forward attributes so the reverse object exposes them too.
        forward_attrs = dict(vars(self))

        # Build the class for reverse-time SDE.
        class RSDE(self.__class__):
            def __init__(self):
                # The parent __init__ is skipped on purpose: it calls self.sde
                # with x=None, which here would invoke score_fn(None, t).
                self.__dict__.update(forward_attrs)
                self.N = N
                self.T = T

            def sde(self, x, t):
                """Create the drift and diffusion functions for the reverse SDE."""
                drift, diffusion = sde_fn(x, t)
                score = score_fn(x, t)
                # Same per-sample broadcast as the forward drift, for any rank
                # of score and for a scalar (0-d) diffusion.
                drift = drift - self._per_sample(diffusion, score) ** 2 * score
                return drift, diffusion

        return RSDE()


