import math
import wandb
import torch
import matplotlib.pyplot as plt

from .base import BaseVarianceSchedule


class ScalarCosineVPVarianceSchedule(BaseVarianceSchedule):
    """Scalar cosine schedule with ``alpha(t)**2 + sigma_sq(t) == 1``.

    With ``u(t) = (pi / 2) * (t + s) / (1 + s)`` the schedule is::

        alpha(t)    = |cos(u(t))| / cos(u(0))
        sigma_sq(t) = 1 - (cos(u(t)) / cos(u(0)))**2

    Every public method first clamps ``t`` to ``[1e-5, 1 - 1e-5]``, so the
    exact endpoint values are never evaluated.

    Args:
        s: Offset added to ``t`` before rescaling; shifts the cosine away
            from its endpoints.
        log_assets: When true, ``log_plots()`` is called at the end of
            construction (it goes through the module-level ``log_plot``).
    """

    def __init__(self, s: float, log_assets: bool):
        super().__init__()
        # Offset parameter to prevent instabilities near t = 0 or t = 1
        self.s = s
        self.pi_over_2 = math.pi / 2

        # cos(u(0)): normalisation constant that makes alpha(0) equal to 1.
        self.cos_0 = math.cos(self.pi_over_2 * self.s / (1 + self.s))

        if log_assets:
            self.log_plots()

    def _bar_t(self, t):
        """Return the shifted and rescaled time ``(t + s) / (1 + s)``."""
        return (t + self.s) / (1 + self.s)

    def _safe_t(self, t):
        """Clamp ``t`` to ``[1e-5, 1 - 1e-5]`` to stay off the boundaries."""
        return torch.clamp(t, min=1e-5, max=1.0 - 1e-5)

    def _angle(self, t):
        """Return ``u(t) = (pi / 2) * bar_t`` evaluated at the clamped time."""
        return self.pi_over_2 * self._bar_t(self._safe_t(t))

    def _angle_rate(self):
        """Return ``du/dt = (pi / 2) / (1 + s)`` (a Python float)."""
        return self.pi_over_2 / (1 + self.s)

    def sigma_sq(self, t):
        """Return the noise variance ``1 - (cos(u(t)) / cos(u(0)))**2``."""
        cos_val = torch.cos(self._angle(t))
        return 1.0 - (cos_val / self.cos_0).pow(2)

    def d_sigma_sq(self, t):
        """Return the time derivative of ``sigma_sq``.

        Computed as ``2 * cos(u) * sin(u) * du/dt / cos(u(0))**2``.
        """
        u = self._angle(t)
        cos_val = torch.cos(u)
        sin_val = torch.sin(u)
        scale = self._angle_rate()

        return 2.0 * cos_val * sin_val * scale / (self.cos_0**2)

    def d_sigma_sq_t_over_sigma_sq_t(self, t):
        """Return ``d_sigma_sq(t) / sigma_sq(t)``."""
        # Use clamping to avoid division by near-zero
        sigma_sq = self.sigma_sq(t).clamp(min=1e-10)
        d_sigma = self.d_sigma_sq(t)
        return d_sigma / sigma_sq

    def alpha(self, t):
        """Return the signal scale ``|cos(u(t))| / cos(u(0))``."""
        # Equivalent to sqrt(1 - sigma_sq(t)) but computed stably
        cos_val = torch.cos(self._angle(t))
        return cos_val.abs() / self.cos_0

    def d_alpha(self, t):
        """Return the time derivative of ``alpha``."""
        u = self._angle(t)
        sin_val = torch.sin(u)
        scale = self._angle_rate()
        # d|cos(u)|/du = -sign(cos(u)) * sin(u)
        sign = torch.sign(torch.cos(u))
        return -sign * sin_val * scale / self.cos_0

    def d_alpha_t_over_alpha_t(self, t):
        """Return ``d_alpha(t) / alpha(t)``."""
        # Use clamping to avoid division by near-zero
        alpha = self.alpha(t).clamp(min=1e-10)
        d_alpha = self.d_alpha(t)
        return d_alpha / alpha

    def sigma(self, t):
        """Return the noise standard deviation ``sqrt(sigma_sq(t))``."""
        # Round-off in cos(u) / cos(u(0)) can push sigma_sq marginally below
        # zero near t = 0; clamp so the square root never produces NaN.
        return self.sigma_sq(t).clamp(min=0.0).sqrt()

    def dsigma(self, t):
        """Return the time derivative of ``sigma``.

        By the chain rule ``d sigma / dt = d_sigma_sq / (2 * sigma)``. The
        variance in the denominator is clamped to ``1e-10``, the same floor
        used by ``d_sigma_sq_t_over_sigma_sq_t``.
        """
        sigma = self.sigma_sq(t).clamp(min=1e-10).sqrt()
        return self.d_sigma_sq(t) / (2.0 * sigma)

    def log_plots(self):
        """Log plots of the schedule and its derivatives on 101 points in [0, 1]."""
        t = torch.linspace(0, 1, 101)
        log_plot([t, self.sigma_sq(t)], "sigma_sq")
        log_plot([t, self.sigma(t)], "sigma")
        log_plot([t, self.d_sigma_sq(t)], "d_sigma_sq")
        # The derivative of sigma is d_sigma_sq / (2 * sigma), not
        # sqrt(d_sigma_sq).
        log_plot([t, self.dsigma(t)], "d_sigma")
        log_plot([t, self.alpha(t)], "alpha")
        log_plot([t, self.d_alpha(t)], "d_alpha")
        log_plot([t, self.d_alpha_t_over_alpha_t(t)], "d_alpha_t_over_alpha_t")


def expand(input, target):
    """Append trailing singleton axes to ``input`` until it has ``target.ndim`` axes.

    When ``input`` already has at least as many axes as ``target``, no axes
    are added and ``input[...]`` is returned.
    """
    missing_axes = target.ndim - input.ndim
    # A non-positive count yields an empty list, leaving just the Ellipsis.
    trailing = [None] * missing_axes
    index = (Ellipsis, *trailing)
    return input[index]


def log_plot(data, name):
    """Plot ``data`` with matplotlib and log the figure to wandb.

    Args:
        data: Sequence of objects exposing ``.numpy(force=True)`` (as
            ``torch.Tensor`` does). After conversion they are passed
            positionally to ``Axes.plot``, e.g. ``[x, y]``.
        name: Suffix of the wandb key; the image is logged as
            ``misc/<name>``.
    """
    arrays = [e.numpy(force=True) for e in data]
    fig, ax = plt.subplots()
    try:
        ax.plot(*arrays)
        ax.grid()
        wandb.log({f"misc/{name}": wandb.Image(fig)})
    finally:
        # Close the figure even when plotting or logging raises; pyplot
        # keeps every open figure registered until it is closed.
        plt.close(fig)
