"""Heavily borrowed from https://github.com/yang-song/score_sde_pytorch"""
"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""
import abc
import torch
import numpy as np
from .utils import expand_tensor_dims_as


class SDE(abc.ABC):
    """Forward diffusion process described by a noise level and a signal scaling.

    Everything is expressed as a function of a time variable ``t``:
    ``sigma(t)`` is the noise standard deviation and ``scale(t)`` the factor
    applied to the clean signal. The defaults implement the identity schedule
    ``sigma(t) = t`` with no scaling (``scale(t) = 1``). Subclasses override
    whichever of these they need and must implement :meth:`get_sigma_steps`.
    """

    def __init__(self, sigma_min, sigma_max):
        # Bounds of the noise-level range that get_sigma_steps spans.
        self.sigma_min, self.sigma_max = sigma_min, sigma_max

    def sigma(self, t):
        """Return the noise standard deviation at time ``t`` (identity by default)."""
        return t

    def sigma_deriv(self, t):
        """Return d sigma / dt at ``t``; the identity schedule has slope 1."""
        return 1

    def sigma_inv(self, sigma):
        """Map a noise level back to its time ``t`` (identity by default)."""
        return sigma

    def scale(self, t):
        """Return the signal scaling at time ``t``; 1 means no scaling."""
        return 1

    def scale_deriv(self, t):
        """Return d scale / dt at ``t``; 0 because the default scale is constant."""
        return 0

    @abc.abstractmethod
    def get_sigma_steps(self, num_steps, device):
        """Return ``num_steps`` noise levels spanning ``sigma_min`` .. ``sigma_max``.

        Spacing and ordering are chosen by the subclass. ``device`` is
        presumably a torch device on which the result is allocated.
        """
        raise NotImplementedError


class VPSDE(SDE):
    """Variance-preserving (VP) forward process in sigma/scale form.

    ``sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1)`` and
    ``scale(t) = 1 / sqrt(1 + sigma(t)**2)``. The coefficients ``beta_d`` and
    ``beta_min`` are solved in ``__init__`` so that ``sigma(epsilon_s) == sigma_min``
    and ``sigma(1) == sigma_max``.
    """

    # Coefficients of the standard VP schedule; used only to derive the default
    # sigma_min / sigma_max when the caller does not supply them.
    _DEFAULT_BETA_D = 19.9
    _DEFAULT_BETA_MIN = 0.1

    def __init__(self, sigma_min=None, sigma_max=None, epsilon_s=1e-3):
        """Construct a VP SDE.

        Args:
          sigma_min: noise level at ``t = epsilon_s``. Defaults to the standard
            schedule (beta_d=19.9, beta_min=0.1) evaluated at ``epsilon_s``.
          sigma_max: noise level at ``t = 1``. Defaults to the standard
            schedule evaluated at ``t = 1``.
          epsilon_s: smallest time used by :meth:`get_sigma_steps`; must lie
            strictly between 0 and 1.

        Raises:
          ValueError: if ``epsilon_s`` is not in the open interval (0, 1).
        """
        if not 0 < epsilon_s < 1:
            raise ValueError(f"epsilon_s must lie in (0, 1), got {epsilon_s}.")
        if sigma_min is None:
            sigma_min = self._vp_sigma(epsilon_s, self._DEFAULT_BETA_D, self._DEFAULT_BETA_MIN)
        if sigma_max is None:
            sigma_max = self._vp_sigma(1, self._DEFAULT_BETA_D, self._DEFAULT_BETA_MIN)
        super().__init__(sigma_min, sigma_max)
        self.epsilon_s = epsilon_s
        # Solve log(sigma**2 + 1) = 0.5 * beta_d * t**2 + beta_min * t at
        # t = epsilon_s and t = 1 so the schedule passes through sigma_min and
        # sigma_max. With the default sigmas this recovers beta_d = 19.9 and
        # beta_min = 0.1 (up to float rounding).
        log_min = float(np.log(sigma_min ** 2 + 1))
        log_max = float(np.log(sigma_max ** 2 + 1))
        self.beta_d = 2 * (log_min / epsilon_s - log_max) / (epsilon_s - 1)
        self.beta_min = log_max - 0.5 * self.beta_d

    @staticmethod
    def _vp_sigma(t, beta_d, beta_min):
        """sigma(t) of a VP schedule with the given coefficients (floats or tensors)."""
        return (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5

    def sigma(self, t):
        """Noise standard deviation at time ``t``."""
        return self._vp_sigma(t, self.beta_d, self.beta_min)

    def sigma_deriv(self, t):
        """d sigma / dt = 0.5 * (beta_min + beta_d * t) * (sigma + 1 / sigma)."""
        s = self.sigma(t)
        return 0.5 * (self.beta_min + self.beta_d * t) * (s + 1 / s)

    def sigma_inv(self, sigma):
        """Time ``t`` at which ``sigma(t) == sigma`` (positive root of the quadratic in t).

        ``sigma`` must be a tensor: this uses ``.log()`` and ``.sqrt()``.
        """
        return ((self.beta_min ** 2 + 2 * self.beta_d * (sigma ** 2 + 1).log()).sqrt() - self.beta_min) / self.beta_d

    def scale(self, t):
        """Signal scaling 1 / sqrt(1 + sigma(t)**2); ``sigma(t)`` must be a tensor."""
        return 1 / (1 + self.sigma(t) ** 2).sqrt()

    def scale_deriv(self, t):
        """d scale / dt = -sigma * sigma' * scale**3."""
        return -self.sigma(t) * self.sigma_deriv(t) * (self.scale(t) ** 3)

    def get_sigma_steps(self, num_steps, device):
        """Return ``num_steps`` noise levels for times spaced linearly from 1 down to ``epsilon_s``."""
        step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
        # max(..., 1) avoids 0/0 (NaN) when num_steps == 1; the lone step is then t = 1.
        orig_t_steps = 1 + step_indices / max(num_steps - 1, 1) * (self.epsilon_s - 1)
        return self.sigma(orig_t_steps)


class VESDE(SDE):
    """Variance-exploding (VE) forward process: ``sigma(t) = sqrt(t)`` with no scaling."""

    def __init__(self, sigma_min=0.02, sigma_max=100):
        """Construct a VE SDE whose sampling schedule spans [sigma_min, sigma_max]."""
        super().__init__(sigma_min, sigma_max)

    def sigma(self, t):
        """Noise level at time ``t``; expects a tensor (uses ``.sqrt()``)."""
        return t.sqrt()

    def sigma_deriv(self, t):
        """d sigma / dt = 1 / (2 * sqrt(t)); grows without bound as ``t`` -> 0."""
        return 0.5 / t.sqrt()

    def sigma_inv(self, sigma):
        """Time at which the noise level equals ``sigma``: ``t = sigma**2``."""
        return sigma ** 2

    def get_sigma_steps(self, num_steps, device):
        """Return ``num_steps`` noise levels from ``sigma_max`` down to ``sigma_min``.

        Times are spaced geometrically between ``sigma_max**2`` and
        ``sigma_min**2``, so the returned sigmas are geometric as well.
        """
        step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
        # max(..., 1) avoids 0/0 (NaN) when num_steps == 1; the lone step is then sigma_max.
        exponent = step_indices / max(num_steps - 1, 1)
        orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** exponent)
        return self.sigma(orig_t_steps)


class iDDPMSDE(SDE):
    """Identity-schedule process whose sampling steps come from the discrete iDDPM schedule.

    The discrete schedule uses ``alpha_bar(j) = sin(0.5 * pi * j / M / (C_2 + 1)) ** 2``
    for ``j = 0 .. M``; :meth:`get_sigma_steps` converts it into noise levels.
    """

    def __init__(self, sigma_min=0.002, sigma_max=81, M=1000, C_1=0.001, C_2=0.008):
        """Construct an iDDPM SDE.

        Args:
          sigma_min: smallest discrete noise level kept by :meth:`get_sigma_steps`.
          sigma_max: largest discrete noise level kept by :meth:`get_sigma_steps`.
          M: number of steps in the underlying discrete schedule.
          C_1: lower clip applied to the ratio ``alpha_bar(j - 1) / alpha_bar(j)``.
          C_2: offset inside ``alpha_bar``.
        """
        super().__init__(sigma_min, sigma_max)
        self.M = M
        self.C_1 = C_1
        self.C_2 = C_2

    def get_sigma_steps(self, num_steps, device):
        """Return ``num_steps`` levels picked evenly (by index) from the discrete schedule.

        Raises:
          ValueError: if no discrete level lies within [sigma_min, sigma_max].
        """
        step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
        u = torch.zeros(self.M + 1, dtype=torch.float64, device=device)

        def alpha_bar(j):
            return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2

        # Backward recursion from u[M] = 0:
        #   u[j-1]**2 + 1 = (u[j]**2 + 1) / clip(alpha_bar(j-1) / alpha_bar(j), min=C_1)
        for j in torch.arange(self.M, 0, -1, device=device):  # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=self.C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)]
        if len(u_filtered) == 0:
            raise ValueError(
                f"No iDDPM noise level lies in [{self.sigma_min}, {self.sigma_max}]; "
                "adjust sigma_min/sigma_max or M."
            )
        # max(..., 1) avoids 0/0 (NaN index) when num_steps == 1; the lone step is u_filtered[0].
        positions = (len(u_filtered) - 1) / max(num_steps - 1, 1) * step_indices
        return u_filtered[positions.round().to(torch.int64)]


class EDMSDE(SDE):
    """Identity-schedule process (``sigma(t) = t``, no scaling) with rho-spaced sampling steps."""

    def __init__(self, sigma_min=0.002, sigma_max=80, rho=7):
        """Construct an EDM SDE.

        Args:
          sigma_min: smallest noise level of the sampling schedule.
          sigma_max: largest noise level of the sampling schedule.
          rho: the schedule is linear in ``sigma ** (1 / rho)``, so ``rho > 1``
            places more steps at small sigma.
        """
        super().__init__(sigma_min, sigma_max)
        self.rho = rho

    def get_sigma_steps(self, num_steps, device):
        """Return ``num_steps`` noise levels from ``sigma_max`` down to ``sigma_min``."""
        step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        # max(..., 1) avoids 0/0 (NaN) when num_steps == 1; the lone step is then sigma_max.
        fraction = step_indices / max(num_steps - 1, 1)
        return (max_inv_rho + fraction * (min_inv_rho - max_inv_rho)) ** self.rho