import torch

from gen_neg_toy.utils import expand_tensor_dims_as, get_register_fn

# Registry of preconditioners, keyed by the ``name`` passed to
# ``register_precond``; read by ``get_precond`` below.
_PRECONDITIONERS = {}
# Decorator factory used as ``@register_precond(name="...")`` on the classes in
# this module. Presumably it stores the decorated class in ``_PRECONDITIONERS``
# under ``name`` — the behaviour lives in gen_neg_toy.utils.get_register_fn;
# confirm there.
register_precond = get_register_fn(_PRECONDITIONERS)


class PrecondBase(torch.nn.Module):
    """Base class for denoiser preconditioning wrappers.

    Subclasses implement :meth:`get_params`, which maps a noise level to the
    coefficients ``c_in``, ``c_out``, ``c_skip`` and the network conditioning
    input ``c_noise``. :meth:`forward` combines them with a raw network ``F``::

        D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)

    Args:
        sigma_min: Minimum supported noise level (stored; not enforced by this
            class).
        sigma_max: Maximum supported noise level (stored; not enforced by this
            class).
    """

    def __init__(self, sigma_min, sigma_max):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def get_params(self, sigma):
        """Return the preconditioning coefficients for noise level ``sigma``.

        Must return a mapping with keys ``"c_skip"``, ``"c_out"``, ``"c_in"``
        and ``"c_noise"``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_params()"
        )

    def forward(self, model_forward, x, sigma, **model_kwargs):
        """Evaluate the preconditioned denoiser at ``x`` and noise level ``sigma``.

        Args:
            model_forward: Raw network, called as
                ``model_forward(c_in * x, c_noise, **model_kwargs)``.
            x: Noisy input tensor; cast to float32.
            sigma: Noise level(s), as a tensor or a Python scalar; converted to
                a float32 tensor on ``x.device``. Presumably a scalar or one
                value per batch element — TODO confirm against
                ``expand_tensor_dims_as``.
            **model_kwargs: Forwarded unchanged to ``model_forward``.

        Returns:
            ``c_skip * x + c_out * F``, computed in float32.
        """
        x = x.to(torch.float32)
        # as_tensor accepts Python scalars as well as tensors (for a tensor it
        # behaves like .to(dtype, device), preserving autograd). Placing sigma
        # on x's device keeps the derived coefficients broadcastable against x.
        sigma = torch.as_tensor(sigma, dtype=torch.float32, device=x.device)
        params = self.get_params(sigma)
        # Reshape per-sample coefficients so they broadcast over x's dims.
        c_in = expand_tensor_dims_as(params["c_in"], x)
        c_out = expand_tensor_dims_as(params["c_out"], x)
        c_skip = expand_tensor_dims_as(params["c_skip"], x)
        # c_noise is handed to the network as-is (no reshaping).
        c_noise = params["c_noise"]

        F_x = model_forward((c_in * x), c_noise, **model_kwargs)
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x


@register_precond(name="edm")
class EDMPrecond(PrecondBase):
    """Preconditioning from the EDM formulation (Karras et al., 2022).

    With ``s = sigma`` and ``d = sigma_data``::

        c_skip  = d**2 / (s**2 + d**2)
        c_out   = s * d / sqrt(s**2 + d**2)
        c_in    = 1 / sqrt(s**2 + d**2)
        c_noise = log(s) / 4

    ``c_noise`` is ``-inf`` at ``sigma == 0``, so callers should pass strictly
    positive noise levels even though ``sigma_min`` defaults to 0.

    Args:
        sigma_data: Expected standard deviation of the training data.
        sigma_min: Minimum supported noise level.
        sigma_max: Maximum supported noise level.
    """

    def __init__(self, sigma_data=0.5, sigma_min=0, sigma_max=float('inf')):
        super().__init__(sigma_min=sigma_min, sigma_max=sigma_max)
        self.sigma_data = sigma_data

    def get_params(self, sigma):
        data_var = self.sigma_data ** 2
        total_var = sigma ** 2 + data_var
        total_std = total_var.sqrt()
        return dict(
            c_skip=data_var / total_var,
            c_out=sigma * self.sigma_data / total_std,
            c_in=1 / total_std,
            c_noise=sigma.log() / 4,
        )


@register_precond(name="edm_2")
class EDM2Precond(PrecondBase):
    """EDM-style preconditioning with an unscaled output coefficient.

    ``c_skip``, ``c_in`` and ``c_noise`` follow the EDM formulas, while
    ``c_out`` is ``sigma`` itself rather than
    ``sigma * sigma_data / sqrt(sigma**2 + sigma_data**2)``.

    NOTE(review): these coefficients are identical to ``EDMSimplePrecond``
    ("edm_simple") in this module. The name suggests EDM2 (Karras et al.,
    2024), whose reference code, as far as I know, keeps the EDM ``c_out`` —
    confirm that ``c_out = sigma`` is intended here.

    Args:
        sigma_data: Expected standard deviation of the training data.
        sigma_min: Minimum supported noise level.
        sigma_max: Maximum supported noise level.
    """

    def __init__(self, sigma_data=0.5, sigma_min=0, sigma_max=float('inf')):
        super().__init__(sigma_min=sigma_min, sigma_max=sigma_max)
        self.sigma_data = sigma_data

    def get_params(self, sigma):
        var_data = self.sigma_data ** 2
        var_total = sigma ** 2 + var_data
        return {
            "c_skip": var_data / var_total,
            "c_out": sigma,
            "c_in": 1 / var_total.sqrt(),
            # -inf at sigma == 0; pass strictly positive noise levels.
            "c_noise": sigma.log() / 4,
        }


@register_precond(name="edm_simple")
class EDMSimplePrecond(PrecondBase):
    """Simplified EDM preconditioning where the network output is scaled by sigma.

    Coefficients (``s = sigma``, ``d = sigma_data``)::

        c_skip  = d**2 / (s**2 + d**2)
        c_out   = s
        c_in    = 1 / sqrt(s**2 + d**2)
        c_noise = log(s) / 4

    Args:
        sigma_data: Expected standard deviation of the training data.
        sigma_min: Minimum supported noise level.
        sigma_max: Maximum supported noise level.
    """

    def __init__(
        self,
        sigma_data=0.5,
        sigma_min=0,
        sigma_max=float('inf'),
    ):
        super().__init__(sigma_min, sigma_max)
        self.sigma_data = sigma_data

    def get_params(self, sigma):
        sd_sq = self.sigma_data ** 2
        denom = sigma ** 2 + sd_sq
        coeffs = dict(c_skip=sd_sq / denom)
        coeffs["c_out"] = sigma
        coeffs["c_in"] = 1 / denom.sqrt()
        # log(0) = -inf, so sigma must be strictly positive here.
        coeffs["c_noise"] = sigma.log() / 4
        return coeffs


@register_precond(name="vp")
class VPPrecond(PrecondBase):
    """Variance-preserving (VP) preconditioning.

    Noise schedule, for ``t`` in ``[epsilon_t, 1]``::

        sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1)

    The supported noise range is ``[sigma(epsilon_t), sigma(1)]``. The network
    is conditioned on the DDPM-style timestep ``(M - 1) * t``, where ``t`` is
    recovered from sigma via :meth:`sigma_inv`.

    Args:
        beta_d: Extent of the noise level schedule.
        beta_min: Initial slope of the noise level schedule.
        M: Original number of timesteps in the DDPM formulation.
        epsilon_t: Minimum t-value used during training.
    """

    def __init__(
        self,
        beta_d=19.9,
        beta_min=0.1,
        M=1000,
        epsilon_t=1e-5,
    ):
        # sigma_min/sigma_max depend on the schedule, but self.sigma() cannot
        # be used here: it reads self.beta_d / self.beta_min, which do not exist
        # until nn.Module.__init__ has run and the assignments below execute
        # (nn.Module.__getattr__ would raise AttributeError). Evaluate the
        # schedule statelessly from the constructor arguments instead.
        sigma_min = float(self._vp_sigma(epsilon_t, beta_d, beta_min))
        sigma_max = float(self._vp_sigma(1, beta_d, beta_min))
        super().__init__(sigma_min, sigma_max)
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.M = M
        self.epsilon_t = epsilon_t

    def get_params(self, sigma):
        """Return VP coefficients.

        With ``c_skip = 1`` and ``c_out = -sigma`` the denoiser is
        ``D = x - sigma * F``, so ``F`` plays the role of a noise (epsilon)
        predictor.
        """
        c_skip = 1
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + 1).sqrt()
        c_noise = (self.M - 1) * self.sigma_inv(sigma)
        return dict(c_skip=c_skip, c_out=c_out, c_in=c_in, c_noise=c_noise)

    @staticmethod
    def _vp_sigma(t, beta_d, beta_min):
        """Evaluate the VP schedule sigma(t) for explicit schedule parameters."""
        t = torch.as_tensor(t)
        return ((0.5 * beta_d * (t ** 2) + beta_min * t).exp() - 1).sqrt()

    def sigma(self, t):
        """Map time ``t`` to its noise level under this instance's schedule."""
        return self._vp_sigma(t, self.beta_d, self.beta_min)

    def sigma_inv(self, sigma):
        """Invert :meth:`sigma`: solve ``0.5*beta_d*t**2 + beta_min*t = log(1 + sigma**2)`` for ``t``."""
        sigma = torch.as_tensor(sigma)
        return ((self.beta_min ** 2 + 2 * self.beta_d * (1 + sigma ** 2).log()).sqrt() - self.beta_min) / self.beta_d


@register_precond(name="none")
class NoPrecond(PrecondBase):
    """Identity preconditioning with constant coefficients.

    ``get_params`` returns ``c_skip=0``, ``c_out=1``, ``c_in=1``, ``c_noise=1``
    for every sigma. The denoiser output is therefore the raw network output on
    the unscaled input, assuming ``expand_tensor_dims_as`` passes these Python
    scalars through in a broadcastable form — TODO confirm.

    NOTE(review): ``c_noise`` is the constant ``1``, so the network receives no
    information about the noise level. Confirm this is intended rather than,
    e.g., ``c_noise = sigma``.

    Args:
        sigma_min: Minimum supported noise level.
        sigma_max: Maximum supported noise level.
    """

    def __init__(self, sigma_min=0, sigma_max=float('inf')):
        super().__init__(sigma_min=sigma_min, sigma_max=sigma_max)

    def get_params(self, sigma):
        # sigma is not used: every coefficient is a fixed constant.
        return {"c_skip": 0, "c_out": 1, "c_in": 1, "c_noise": 1}


def get_precond(name):
    """Look up a registered preconditioner by its registry name.

    Args:
        name: The ``name`` passed to ``register_precond``, e.g. ``"edm"``.

    Returns:
        The object registered under ``name`` (presumably the decorated
        preconditioner class, not an instance).

    Raises:
        ValueError: If nothing is registered under ``name``. The message lists
            the registered names so typos are easy to spot.
    """
    try:
        return _PRECONDITIONERS[name]
    except KeyError:
        available = ", ".join(sorted(_PRECONDITIONERS))
        raise ValueError(
            f"Unknown precond: {name} (available: {available})"
        ) from None