import torch
from scipy.special import beta as beta_func, betaln


def beta_pdf(x, alpha, beta):
    """Evaluate the Beta(alpha, beta) probability density at ``x``.

    Computes ``x**(alpha - 1) * (1 - x)**(beta - 1) / B(alpha, beta)``,
    where ``B`` is SciPy's beta function.
    """
    unnormalized = x ** (alpha - 1) * (1 - x) ** (beta - 1)
    normalizer = beta_func(alpha, beta)
    return unnormalized / normalizer

class TimeDistorter:
    """Warp time values in ``[0, 1]`` with a distortion selected by name.

    Args:
        train_distortion: Name of the distortion applied by ``train_ft``.
        sample_distortion: Stored on the instance; not read by the methods
            in this section.
        mu: Stored on the instance; not read by the methods in this section.
        sigma: Stored on the instance; not read by the methods in this
            section.
        s: Scale factor used by the ``"mode"`` distortion.
        alpha: Stored on the instance; not read by the methods in this
            section.
        beta: Stored on the instance; not read by the methods in this
            section.

    Attributes:
        f_inv: Callable used by the ``"adaptive"`` distortion. It starts as
            ``None`` and is expected to be assigned elsewhere; it is called
            with a NumPy array and must return something ``torch.tensor``
            accepts.
    """

    # Exponent ``p`` of the ``2 * t - t ** p`` family, keyed by distortion
    # name. "polydec" keeps the integer exponent it was originally written with.
    _POLYDEC_EXPONENTS = {
        "polydec": 2,
        "polydec_1p4": 1.4,
        "polydec_1p6": 1.6,
        "polydec_1p8": 1.8,
        "polydec_1p9": 1.9,
        "polydec_2p1": 2.1,
        "polydec_2p2": 2.2,
        "polydec_2p3": 2.3,
        "polydec_2p4": 2.4,
        "polydec_2p6": 2.6,
        "polydec_2p8": 2.8,
    }

    def __init__(
        self,
        train_distortion,
        sample_distortion,
        mu=0,
        sigma=1,
        s=-0.54,
        alpha=1,
        beta=1,
    ):
        self.train_distortion = train_distortion
        self.sample_distortion = sample_distortion
        self.mu = mu
        self.sigma = sigma
        self.s = s
        self.alpha = alpha
        self.beta = beta
        print(
            f"TimeDistorter: train_distortion={train_distortion}, "
            f"sample_distortion={sample_distortion}, mu={mu}, sigma={sigma}, s={s}"
        )
        # No adaptive mapping until one is assigned; "adaptive" is then a no-op.
        self.f_inv = None

    def train_ft(self, batch_size, device):
        """Draw uniform times and warp them with the training distortion.

        Args:
            batch_size: Number of time values to draw.
            device: Device on which the random tensor is created.

        Returns:
            Tensor of shape ``(batch_size, 1)`` holding the distorted times.
        """
        t_uniform = torch.rand((batch_size, 1), device=device)
        return self.apply_distortion(t_uniform, self.train_distortion)

    def apply_distortion(self, t, distortion_type):
        """Apply the distortion named ``distortion_type`` to ``t``.

        Supported names:
            * ``"identity"``: ``t``
            * ``"mode"``: ``1 - t - s * (cos(pi/2 * t)**2 - 1 + t)``, clamped
              to ``[0, 1]`` (evaluates to 1 at ``t = 0`` and 0 at ``t = 1``)
            * ``"cos"``: ``(1 - cos(pi * t)) / 2``
            * ``"revcos"``: ``2 * t - (1 - cos(pi * t)) / 2``
            * ``"polyinc"``: ``t ** 2``
            * ``"polydec"`` and ``"polydec_<a>p<b>"``: ``2 * t - t ** p`` for
              the exponents listed in ``_POLYDEC_EXPONENTS``
            * ``"adaptive"``: ``f_inv(t)`` clamped to ``[0, 1]``, or ``t``
              itself while ``f_inv`` is ``None``

        Args:
            t: Tensor of time values, every element within ``[0, 1]``.
            distortion_type: Name of the distortion to apply.

        Returns:
            Tensor of distorted times with the same dtype and device as ``t``.

        Raises:
            AssertionError: If any element of ``t`` lies outside ``[0, 1]``.
            ValueError: If ``distortion_type`` is ``"beta"`` (recognised but
                rejected) or is not one of the supported names.
        """
        assert torch.all((t >= 0) & (t <= 1)), "t must be in the range [0, 1]"

        if distortion_type == "identity":
            ft = t
        elif distortion_type == "mode":
            ft = 1 - t - self.s * (torch.cos(torch.pi / 2 * t) ** 2 - 1 + t)
            ft = torch.clamp(ft, 0.0, 1.0)
        elif distortion_type == "cos":
            ft = (1 - torch.cos(t * torch.pi)) / 2
        elif distortion_type == "revcos":
            ft = 2 * t - (1 - torch.cos(t * torch.pi)) / 2
        elif distortion_type == "polyinc":
            ft = t**2
        elif (
            # The str guard keeps unhashable inputs on the ValueError path below.
            isinstance(distortion_type, str)
            and distortion_type in self._POLYDEC_EXPONENTS
        ):
            ft = 2 * t - t ** self._POLYDEC_EXPONENTS[distortion_type]
        elif distortion_type == "beta":
            # "beta" is a known name but has no implementation here; it is
            # rejected with the same error as an unknown name.
            raise ValueError(f"Unknown distortion type: {distortion_type}")
        elif distortion_type == "adaptive":
            if self.f_inv is None:
                return t
            warped = self.f_inv(t.cpu().detach().numpy())
            # Build the result with the input's dtype and device so this
            # branch matches the others instead of inheriting the NumPy
            # dtype (commonly float64) returned by f_inv.
            ft = torch.tensor(warped, dtype=t.dtype, device=t.device)
            ft = ft.clamp(0.0, 1.0)
        else:
            raise ValueError(f"Unknown distortion type: {distortion_type}")

        return ft
