import numpy as np
from torch import nn
import torch


def clip_noise_schedule(alphas2, clip_value=0.001):
    """
    Rebuild a cumulative alpha^2 schedule after bounding its per-step ratios.

    The ratio between each entry and its predecessor, alpha2[t] / alpha2[t-1],
    is clipped to the range [clip_value, 1.0]; the schedule is then
    re-accumulated as the cumulative product of the clipped ratios. This may
    help improve stability during sampling.

    Note that no leading 1 is prepended to the input, so the returned array
    has one entry fewer than ``alphas2`` along its first axis.
    """
    step_ratios = alphas2[1:] / alphas2[:-1]
    return np.cumprod(
        np.clip(step_ratios, a_min=clip_value, a_max=1.0),
        axis=0,
    )


def polynomial_decay(timesteps: int, s=1e-5, clip_value=0.001, power=2.0):
    """
    Polynomial noise schedule built from the curve 1 - x^power.
    from https://arxiv.org/abs/2203.17003

    The curve is evaluated on ``timesteps + 1`` evenly spaced points of
    [0, 1], squared, passed through ``clip_noise_schedule`` and finally
    rescaled affinely by ``(1 - 2 * s) * value + s``.
    """
    grid = np.linspace(0, timesteps, timesteps + 1, dtype=np.float64)
    squared_curve = (1 - np.power(grid / timesteps, power)) ** 2

    clipped = clip_noise_schedule(squared_curve, clip_value=clip_value)

    # Affine rescaling with slope (1 - 2s) and offset s.
    scale = 1 - 2 * s
    return scale * clipped + s


def cosine_decay(timesteps, s: float = 0.008, clip_value: float = 0.999):
    """
    Cosine noise schedule with clipped per-step betas.
    from https://arxiv.org/abs/2102.09672

    Evaluates the squared-cosine curve on ``timesteps + 1`` points, normalises
    it by its first value, converts it to per-step betas clipped to
    [0, clip_value], and returns the cumulative product of ``1 - beta``
    (an array with ``timesteps`` entries).
    """
    grid = np.linspace(0, timesteps, timesteps + 1, dtype=np.float64)
    curve = np.cos((grid / timesteps + s) / (1 + s) * np.pi * 0.5) ** 2
    cumulative = curve / curve[0]

    # Clipping the per-step betas keeps the schedule numerically stable.
    step_betas = np.clip(
        1 - (cumulative[1:] / cumulative[:-1]),
        a_min=0.0,
        a_max=clip_value,
    )

    # Re-accumulate the schedule from the clipped steps.
    return np.cumprod(1 - step_betas, axis=0)


class NoiseSchedule(nn.Module):
    """
    Base class for noise schedules.

    Stores a table of cumulative alpha values (``alphas_bar``) and pre-computes
    the derived per-step quantities as float64 NumPy arrays. Calling the module
    looks these tables up for a batch of time values and returns the entries as
    float32 tensors on the device of the input.

    Args:
        T: number of timesteps; ``forward`` multiplies its input by ``T`` to
            obtain a table index.
        alphas_bar: array-like of cumulative alpha values with at least two
            entries along its first axis. Required, despite the ``None``
            default kept for signature compatibility.
        s: stored on the instance; not used by this class itself.
        clip_value: stored on the instance; not used by this class itself.
        use_snr: if True, additionally compute ``gamma`` =
            log(1 - alphas_bar) - log(alphas_bar) and rebuild ``alphas_bar`` /
            ``betas_bar`` from it through a sigmoid.
        variance_type: ``"lower_bound"`` or ``"upper_bound"``; selects how the
            reverse-process variance table ``sigmas`` is filled.

    Raises:
        ValueError: if ``alphas_bar`` is missing or too short, or if
            ``variance_type`` is not one of the two supported values.
    """

    def __init__(
        self,
        T: int,
        alphas_bar=None,
        s: float = None,
        clip_value: float = 0.999,
        use_snr: bool = False,
        variance_type: str = "lower_bound",
    ):
        super().__init__()
        # Fail early with a clear message instead of an IndexError deep inside
        # pre_compute_statistics (np.array(None) would yield a 0-d NaN array).
        if alphas_bar is None:
            raise ValueError("alphas_bar must be provided")
        self.T = T  # timesteps
        self.alphas_bar = np.array(alphas_bar, dtype=np.float64)
        if self.alphas_bar.ndim == 0 or self.alphas_bar.shape[0] < 2:
            raise ValueError(
                "alphas_bar must contain at least two entries along its first axis"
            )
        self.s = s
        self.clip_value = clip_value
        self.use_snr = use_snr
        self.variance_type = variance_type
        self.pre_compute_statistics()

    @staticmethod
    def _lookup(table, index, device):
        """Return ``table[index]`` as a float32 tensor placed on ``device``."""
        # Cast to float32 on the host *before* moving: some backends (e.g. MPS)
        # cannot hold float64 tensors, and this also halves the bytes copied.
        return torch.from_numpy(table[index]).float().to(device=device)

    def forward(self, t: torch.Tensor, stage: str = "fit"):
        """
        Look up the schedule statistics for the given times.

        Args:
            t: tensor of times; it is multiplied by ``T``, rounded and clamped
                to ``[0, T - 1]`` to form table indices (presumably normalised
                to [0, 1] by the caller). A 0-d tensor is treated as a batch of
                one.
            stage: ``"fit"`` returns the cumulative statistics (plus ``gamma``
                when ``use_snr`` is set); any other value returns the per-step
                statistics and the reverse-process variances.

        Returns:
            dict mapping statistic names to float32 tensors on ``t.device``.
        """
        device = t.device
        if t.dim() == 0:
            t = t.reshape(1)
        index = torch.round(t * self.T).long().cpu().numpy()
        index = np.clip(index, 0, self.T - 1)

        if stage == "fit":
            params = {
                "alpha_bar": self._lookup(self.alphas_bar, index, device),
                "beta_bar": self._lookup(self.betas_bar, index, device),
                "sqrt_alpha_bar": self._lookup(self.sqrt_alphas_bar, index, device),
                "sqrt_beta_bar": self._lookup(self.sqrt_betas_bar, index, device),
            }
            if self.use_snr:
                params["gamma"] = self._lookup(self.gamma, index, device)
            return params

        params = {
            "alpha_full": self._lookup(self.alphas_full, index, device),
            "beta_full": self._lookup(self.betas_full, index, device),
            "sqrt_alpha_full": self._lookup(self.sqrt_alphas_full, index, device),
            "sqrt_beta_full": self._lookup(self.sqrt_betas_full, index, device),
            "beta_full_square": self._lookup(self.betas_full_square, index, device),
            "sigma": self._lookup(self.sigmas, index, device),
            "sqrt_sigma": self._lookup(self.sqrt_sigmas, index, device),
        }
        # The non-"full" tables are one entry shorter (ratios of consecutive
        # alphas_bar values), so their index is clamped one step lower.
        index = np.minimum(index, self.T - 2)
        params.update(
            {
                "alpha": self._lookup(self.alphas, index, device),
                "beta": self._lookup(self.betas, index, device),
                "sqrt_alpha": self._lookup(self.sqrt_alphas, index, device),
                "sqrt_beta": self._lookup(self.sqrt_betas, index, device),
                "beta_square": self._lookup(self.betas_square, index, device),
            }
        )
        return params

    def pre_compute_statistics(self):
        """
        Derive every lookup table from ``self.alphas_bar``.

        Raises:
            ValueError: if ``self.variance_type`` is neither ``"lower_bound"``
                nor ``"upper_bound"``.
        """

        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

        if self.use_snr:
            log_alphas_bar = np.log(self.alphas_bar)
            log_betas_bar = np.log(1 - self.alphas_bar)
            self.gamma = log_betas_bar - log_alphas_bar
            self.alphas_bar = sigmoid(-self.gamma)
            self.betas_bar = sigmoid(self.gamma)
        else:
            self.betas_bar = 1 - self.alphas_bar
        self.sqrt_alphas_bar = np.sqrt(self.alphas_bar)
        # Note: sqrt(1 - alphas_bar), which differs from 1 - sqrt(alphas_bar).
        self.sqrt_betas_bar = np.sqrt(self.betas_bar)

        # Per-step alphas are ratios of consecutive cumulative values; the
        # "full" variant prepends alphas_bar[0] so it has one entry per step.
        self.alphas = self.alphas_bar[1:] / self.alphas_bar[:-1]
        self.alphas_full = np.concatenate([self.alphas_bar[:1], self.alphas])
        self.betas = 1.0 - self.alphas
        self.betas_full = 1.0 - self.alphas_full
        self.betas_square = self.betas**2
        self.betas_full_square = self.betas_full**2
        self.sqrt_betas = np.sqrt(self.betas)
        self.sqrt_betas_full = np.sqrt(self.betas_full)
        self.sqrt_alphas = np.sqrt(self.alphas)
        self.sqrt_alphas_full = np.sqrt(self.alphas_full)

        # pre-compute variance (sigma) of the reverse process if not learned
        self.sigmas = self.betas * (self.betas_bar[:-1] / self.betas_bar[1:])
        if self.variance_type == "lower_bound":
            # The variance for the very first step would be 0 by definition,
            # so the value of the following step is reused in its place.
            self.sigmas = np.append(self.sigmas[0], self.sigmas)
        elif self.variance_type == "upper_bound":
            # The first value is always replaced by the true posterior variance
            # to obtain a better likelihood of L_0,
            # see https://arxiv.org/abs/2102.09672
            self.sigmas = np.append(self.sigmas[0], self.betas_full[1:])
        else:
            raise ValueError(
                "variance_type must be either 'lower_bound' or 'upper_bound' "
            )
        self.sqrt_sigmas = np.sqrt(self.sigmas)


class CosineSchedule(NoiseSchedule):
    """
    Noise schedule whose cumulative alphas come from ``cosine_decay``.

    All constructor arguments are forwarded unchanged to ``NoiseSchedule``;
    ``T``, ``s`` and ``clip_value`` are additionally used to build the table.
    """

    def __init__(
        self,
        T: int = 1000,
        s: float = 0.008,
        clip_value: float = 0.999,
        use_snr: bool = False,
        variance_type: str = "lower_bound",
    ):
        # Build the cumulative-alpha table first, then hand it to the base class.
        schedule = cosine_decay(T, s=s, clip_value=clip_value)
        super().__init__(
            T,
            schedule,
            s=s,
            clip_value=clip_value,
            use_snr=use_snr,
            variance_type=variance_type,
        )


class PolynomialSchedule(NoiseSchedule):
    """
    Noise schedule whose cumulative alphas come from ``polynomial_decay``.

    All constructor arguments are forwarded unchanged to ``NoiseSchedule``;
    ``T``, ``s`` and ``clip_value`` are additionally used to build the table
    (``polynomial_decay`` is called with its default ``power``).
    """

    def __init__(
        self,
        T: int = 1000,
        s: float = 1e-5,
        clip_value: float = 0.001,
        use_snr: bool = False,
        variance_type: str = "lower_bound",
    ):
        # Build the cumulative-alpha table first, then hand it to the base class.
        schedule = polynomial_decay(T, s=s, clip_value=clip_value)
        super().__init__(
            T,
            schedule,
            s=s,
            clip_value=clip_value,
            use_snr=use_snr,
            variance_type=variance_type,
        )
