import math
import torch
import numpy as np
import scipy.optimize
from typing import Union, Tuple


def extract(arr_func, x, t: Union[torch.Tensor, int], ndim=None, dtype=None, device=None):
    """
    Look up per-step coefficients at ``t`` and shape them for broadcasting.

    :param arr_func: either a 1-D tensor indexed by ``t`` or a callable that
        is evaluated as ``arr_func(t)`` and returns a tensor or tuple of tensors
    :param x: reference tensor supplying the default ``ndim``/``dtype``/``device``
    :param t: step index (int) or tensor of steps; tensors are cast to int64
    :param ndim: number of output dimensions (falls back to ``x.ndim``)
    :param dtype: output dtype (falls back to ``x.dtype``)
    :param device: output device (falls back to ``x.device``)
    :return: tensor of shape ``(len(t), 1, ..., 1)``; a list of such tensors
        when the callable returns a tuple
    """
    # `x` is only consulted for the settings that were not given explicitly
    ndim = ndim or x.ndim
    dtype = dtype or x.dtype
    device = device or x.device

    if isinstance(t, int):
        t = torch.as_tensor([t, ], dtype=torch.int64, device=device)
    elif isinstance(t, torch.Tensor):
        t = t.to(dtype=torch.int64)

    # leading axis follows t, remaining axes are singleton
    out_shape = [t.shape[0], ] + [1] * (ndim - 1)

    def _finalize(values):
        return values.to(dtype=dtype, device=device).reshape(out_shape)

    if not callable(arr_func):
        assert t.ndim == 1 and arr_func.ndim == 1 and t.dtype == torch.int64
        return arr_func.to(dtype=dtype, device=device).gather(0, t).reshape(out_shape)

    outs = arr_func(t)
    if isinstance(outs, tuple):
        return [_finalize(out) for out in outs]
    return _finalize(outs)


def flat_mean(x):
    """Average ``x`` over every dimension except the first one."""
    trailing_axes = [axis for axis in range(1, x.ndim)]
    return torch.mean(x, dim=trailing_axes)


def rand_zero(x, prob):
    """
    Randomly zero out whole slices of ``x`` along its first dimension.

    One uniform sample is drawn per index of the first dimension; the slice is
    kept where the sample is below ``prob`` and replaced by zeros otherwise.
    """
    # one draw per leading index, broadcast over the remaining dimensions
    draw_shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    keep = torch.rand(draw_shape, device=x.device) < prob
    return torch.where(keep, x, torch.zeros_like(x))


@torch.jit.script
def poisson_kl(rate_1, rate_2, eps: float = 1e-12):
    """
    Bregman divergence induced by (generalized) negative entropy on non-negative orthant

    Computed element-wise as rate_2 - rate_1 + rate_1 * (log(rate_1) - log(rate_2)),
    with both rates clamped from below at ``eps`` before taking the logarithm.
    """
    log_ratio = torch.log(rate_1.clamp(min=eps)) - torch.log(rate_2.clamp(min=eps))
    return (rate_1 - rate_2).neg() + rate_1 * log_ratio


@torch.jit.script
def poisson_loglik(x, rate, start: int = 0, eps: float = 1e-12):
    """
    Element-wise Poisson log-probability of the shifted count ``x - start``.

    Evaluates k * log(rate) - rate - lgamma(k + 1) with k = x - start; the rate
    is clamped from below at ``eps`` inside the logarithm only.
    """
    count = x - start
    log_rate = rate.clamp(min=eps).log()
    return torch.lgamma(count + 1).neg() + count * log_rate - rate


@torch.jit.script
def approx_std_normal_cdf(x):
    """
    Tanh-based approximation of the standard normal CDF:
    0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).

    Reference:
    Page, E. "Approximations to the Cumulative Normal Function and Its Inverse for Use on a Pocket Calculator."
     Applied Statistics 26.1 (1977): 75-76.
    """
    cubic = x + 0.044715 * torch.pow(x, 3)
    squashed = torch.tanh(math.sqrt(2. / math.pi) * cubic)
    return 0.5 * (1. + squashed)


@torch.jit.script
def discretized_gaussian_loglik(
        x, means, log_scale, precision: float = 1./255,
        cutoff: Union[float, Tuple[float, float]] = (-0.999, 0.999), tol: float = 1e-12):
    """
    Log-probability of ``x`` under a Gaussian discretized into bins.

    The mass of each bin is the difference of the approximate normal CDF
    evaluated at ``x + precision`` and ``x - precision`` (standardized with
    ``means`` and ``exp(log_scale)``). Values above ``cutoff[1]`` use an upper
    CDF of 1 and values below ``cutoff[0]`` use a lower CDF of 0, so the edge
    bins absorb the tails. A scalar ``cutoff`` c is expanded to ``(-c, c)``.
    The bin mass is floored at ``tol`` before taking the logarithm.
    """
    if isinstance(cutoff, float):
        cutoff = (-cutoff, cutoff)
    # Assumes data is integers [0, 255] rescaled to [-1, 1]
    residual = x - means
    inv_std = torch.exp(-log_scale)
    # constants used for the open-ended edge bins
    one = torch.as_tensor(1, dtype=torch.float32, device=x.device)
    zero = torch.as_tensor(0, dtype=torch.float32, device=x.device)
    cdf_hi = torch.where(
        x > cutoff[1], one, approx_std_normal_cdf(inv_std * (residual + precision)))
    cdf_lo = torch.where(
        x < cutoff[0], zero, approx_std_normal_cdf(inv_std * (residual - precision)))
    # max(mass, tol), written so that the clamp happens at zero
    bin_mass = torch.clamp(cdf_hi - cdf_lo - tol, min=0).add(tol)
    return torch.log(bin_mass)


@torch.jit.script
def categorical_kl(logits_1, logits_2, eps: float = 1e-12):
    """
    Element-wise KL terms p_1 * (log p_1 - log p_2) between two categorical
    distributions given by logits over the last dimension.

    The terms are returned without being summed over the last dimension.
    ``eps`` is added to the logits before the log-softmax.
    """
    probs_1 = torch.softmax(logits_1, dim=-1)
    log_probs_1 = torch.log_softmax(logits_1 + eps, dim=-1)
    log_probs_2 = torch.log_softmax(logits_2 + eps, dim=-1)
    return probs_1 * (log_probs_1 - log_probs_2)


def stable_log1mexp(x):
    """
    numerically stable version of log(1-exp(x)), x<0

    Uses log1p(-exp(x)) where x < -9 and log(-expm1(x)) elsewhere.
    """
    assert torch.all(x < 0.)
    far_branch = torch.log1p(-torch.exp(x))
    near_branch = torch.log(-torch.expm1(x))
    return torch.where(x < -9, far_branch, near_branch)


def parse_range(seq):
    """
    Convert a two-element iterable into a tuple; ``None`` is passed through.

    Anything that is not a sized, iterable object of length 2 fails the assertion.
    """
    if seq is not None:
        assert hasattr(seq, "__len__") and len(seq) == 2 and hasattr(seq, "__iter__")
        return tuple(seq)
    return None


def log_sigmoid(x):
    """
    Scalar log(sigmoid(x)) with asymptotic shortcuts for large |x|.

    :param x: a scalar number
    :return: ``x`` itself for x < -9 (log-sigmoid ~ x), ``-exp(-x)`` for
        x > 9 (log-sigmoid ~ -exp(-x)), and the exact
        ``-log(1 + exp(-x))`` in between
    """
    # The original condition `x < -9 or 9` was always truthy, so every input
    # was returned unchanged; only the lower tail may use the identity shortcut.
    if x < -9:
        out = x
    elif x > 9:
        out = -np.exp(-x)
    else:
        out = -np.log(1 + np.exp(-x))
    return out


def input_check(dtype=torch.float64, cont=False):
    """
    Decorator factory that validates the tensor argument of a one-argument function.

    :param dtype: dtype the argument tensor must have
    :param cont: when true, additionally require every entry to lie in [0, 1]
    :return: a decorator; the decorated function asserts the conditions above
        on its argument ``t`` and then returns ``func(t)``
    :raises AssertionError: (from the decorated function) when a check fails
    """
    import functools  # function-scope import keeps this block self-contained

    def check(func):
        # wraps() keeps the wrapped function's __name__/__doc__ visible
        @functools.wraps(func)
        def func_w_check(t):
            assert t.dtype == dtype
            if cont:
                assert torch.all(torch.logical_and(0 <= t, t <= 1))
            return func(t)
        return func_w_check
    return check


def _warmup_schedule(start, end, timesteps, warmup_frac):
    """
    Linear ramp from ``start`` to ``end`` followed by a constant ``end``.

    The ramp occupies the first ``int(timesteps * warmup_frac)`` entries of the
    returned float64 tensor of length ``timesteps``.
    """
    n_warmup = int(timesteps * warmup_frac)
    schedule = end * torch.ones(timesteps, dtype=torch.float64)
    ramp = torch.linspace(start, end, n_warmup, dtype=torch.float64)
    schedule[:n_warmup] = ramp
    return schedule


def _signal_decay_sequence(schedule, start, end, timesteps):
    """
    schedules of signal decay
    alpha: cumulative decay coefficient
    beta: 1 - decay ratio

    Returns a float64 tensor of shape ``(timesteps,)`` shaped by the named
    ``schedule``. Most schedules interpolate between ``start`` and ``end``;
    "const" only uses ``start``, while "jsd", "cosine_improved" and
    "invcdf_uniform" use neither. Unknown names raise ``NotImplementedError``.
    """
    def span(first, last):
        # evenly spaced float64 grid with one entry per timestep
        return torch.linspace(first, last, timesteps, dtype=torch.float64)

    if schedule == "linear":
        coefs = span(start, end)
    elif schedule == "quad":
        coefs = span(start ** 0.5, end ** 0.5) ** 2
    elif schedule == "linear2":
        lo, hi = 1 - math.sqrt(1 - start), 1 - math.sqrt(1 - end)
        ramp = span(lo, hi)
        coefs = ramp * (2 - ramp)
    elif schedule in ("warmup10", "warmup50"):
        warmup_frac = 0.1 if schedule == "warmup10" else 0.5
        coefs = _warmup_schedule(start, end, timesteps, warmup_frac)
    elif schedule == "const":
        coefs = torch.full((timesteps,), fill_value=start, dtype=torch.float64)
    elif schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        coefs = 1. / span(timesteps, 1)
    elif schedule == "cosine":
        coefs = end + (start - end) * torch.cos(0.5 * math.pi * span(0, 1))
    elif schedule == "cosine2":
        coefs = end + 0.5 * (start - end) * (torch.cos(math.pi * span(0, 1)) + 1)
    elif schedule == "cosine_improved":
        # from Nichol, Alexander Quinn, and Prafulla Dhariwal.
        # "Improved denoising diffusion probabilistic models." International Conference on Machine Learning. PMLR, 2021.
        phase = span(1. / timesteps, 1).add(0.008).div(1.008)
        coefs = torch.cos(0.5 * math.pi * phase) / math.cos(0.004 * math.pi / 1.008)
    elif schedule == "invcdf_uniform":
        numer = 2 * torch.arange(timesteps - 1, -1, -1, dtype=torch.float64) + 1
        denom = torch.arange(timesteps, 0, -1, dtype=torch.float64) ** 2
        coefs = numer / denom
    elif schedule == "sigmoid":
        # interpolate linearly in logit space, then map back with a sigmoid
        logit_lo, logit_hi = math.log(start) - math.log1p(-start), math.log(end) - math.log1p(-end)
        coefs = torch.sigmoid(span(logit_lo, logit_hi))
    else:
        raise NotImplementedError(schedule)

    assert coefs.shape == (timesteps, ) and coefs.dtype == torch.float64
    return coefs


def infer_alpha_from_beta(beta):
    """Return sqrt(cumprod(1 - beta)) along dimension 0."""
    retained = 1. - beta
    return torch.sqrt(torch.cumprod(retained, dim=0))


def infer_alpha_from_logsnr(logsnr, signal_stat=1.):
    """
    Return ``exp(logsnr) / signal_stat`` as a tensor.

    Non-tensor ``logsnr`` values are converted with ``torch.as_tensor`` first.
    """
    as_tensor = logsnr if isinstance(logsnr, torch.Tensor) else torch.as_tensor(logsnr)
    return as_tensor.exp() / signal_stat


def infer_beta_end_from_logsnr(logsnr_end, beta_start, signal_stat, timesteps=1000):
    """
    Numerically solve for the ``beta_end`` whose schedule reaches ``logsnr_end``.

    With a = 1 - beta_start and b = 1 - beta_end, finds the root of
    0.5 * timesteps * ((b log b - a log a) / (b - a) - 1)
        = logsnr_end - log(signal_stat)
    using ``scipy.optimize.fsolve`` started from 0.00015.

    :param logsnr_end: target value at the end of the schedule
    :param beta_start: first beta of the schedule
    :param signal_stat: signal statistic whose log is subtracted from the target
    :param timesteps: number of steps entering the closed-form expression
    :return: the solved ``beta_end`` as a Python float
    """
    # quantities that do not depend on the unknown are computed once
    target = logsnr_end - np.log(signal_stat)
    a = 1 - beta_start
    a_log_a = a * np.log(a)

    def logsnr_fn(beta_end):
        # fsolve passes a 1-element ndarray; np.log works on it element-wise,
        # whereas math.log needs the deprecated implicit array->scalar conversion
        b = 1 - beta_end
        fb = 0.5 * timesteps * ((b * np.log(b) - a_log_a) / (b - a) - 1)
        return fb - target
    return scipy.optimize.fsolve(logsnr_fn, np.array(0.00015), xtol=1e-6, maxfev=1000)[0].item()  # noqa


def _get_decay_sequence(schedule, start, end, timesteps, signal_stat=None, **kwargs):
    r"""
    returns decay coefficient alphas s.t. (z_t - \alpha_t x_0) \perp x_0, for any t
    and (legacy) betas, i.e. variance of additive noise in Gaussian Diffusion
    schedule must be in the format of {y_type}_{decay_schedule}
    y_type:
        alpha: cumulative signal decay factor
        beta: as implemented here, \beta_t = 1 - (\alpha_t / \alpha_{t-1})^2,
            equivalently \alpha_t = \sqrt{\prod_{s \le t} (1 - \beta_s)}
        alpha2: alpha square
        logsnr: as implemented here, \alpha = \exp(logsnr) / signal_stat
            (requires ``signal_stat``)

    For the "cosine_improved" and "invcdf_uniform" schedules the betas are
    clamped to at most 0.999 and the alphas are recomputed from them.

    :return: dict with numpy arrays under the keys "betas" and "alphas"
    :raises NotImplementedError: for an unknown y_type or decay schedule
    """
    # NOTE: the docstring is a raw string because it contains LaTeX-style
    # backslashes (`\alpha` would otherwise start with a BEL character and
    # `\perp` is an invalid escape sequence).
    y_type, schedule = schedule.split("_", maxsplit=1)
    coefs = _signal_decay_sequence(schedule, start, end, timesteps)

    if y_type == "beta":
        betas = coefs
        alphas = infer_alpha_from_beta(betas)
    else:
        if y_type == "alpha":
            alphas = coefs
        elif y_type == "alpha2":
            alphas = torch.sqrt(coefs)
        elif y_type == "logsnr":
            alphas = torch.exp(coefs) / signal_stat
        else:
            raise NotImplementedError(y_type)
        # per-step ratio alpha_t / alpha_{t-1}; the first step is alpha_0 itself
        step_ratio = torch.cat([torch.atleast_1d(alphas[0]), alphas[1:].div(alphas[:-1])])
        betas = 1. - step_ratio ** 2
    if schedule in ("cosine_improved", "invcdf_uniform"):
        betas.clamp_(max=0.999)
        alphas = infer_alpha_from_beta(betas)
    return {"betas": betas.numpy(), "alphas": alphas.numpy()}


def _get_decay_function(schedule, start, end, timesteps, signal_stat=None, **kwargs):
    """
    Build a continuous-time decay function for ``{y_type}_{decay_schedule}``.

    Each returned function is wrapped with ``input_check(cont=True)``. Supported
    decay schedules are "cosine", "cosine2", "cosine_improved" (alpha only)
    and "linear"; y_type "beta" is only accepted together with "linear".

    :return: a one-entry dict. The key is "alpha_fn" for y_type "beta" (the
        linear schedule function is used as alpha directly) and for y_type
        "logsnr" (the schedule function composed with
        ``infer_alpha_from_logsnr``); otherwise it is ``f"{y_type}_fn"``.
    :raises NotImplementedError: for an unknown decay schedule
    """
    y_type, schedule = schedule.split("_", maxsplit=1)
    if y_type == "beta":
        assert schedule in ("linear", )

    parametrizes_alpha = y_type == "alpha"

    if schedule == "cosine" and parametrizes_alpha:
        assert 0. <= start <= 1. and 0 <= end <= 1.
        # interpolate linearly in angle so the endpoints hit start/end exactly
        theta_0, theta_1 = math.acos(start), math.acos(end)

        @input_check(cont=True)
        def schedule_fn(t):
            return torch.cos(theta_0 + (theta_1 - theta_0) * t)

    elif schedule == "cosine":

        @input_check(cont=True)
        def schedule_fn(t):
            return end + (start - end) * torch.cos(0.5 * math.pi * t)

    elif schedule == "cosine2" and parametrizes_alpha:
        assert 0. <= start <= 1. and 0 <= end <= 1.
        theta_0, theta_1 = math.acos(2 * start - 1), math.acos(2 * end - 1)

        @input_check(cont=True)
        def schedule_fn(t):
            return torch.cos(theta_0 + (theta_1 - theta_0) * t).add(1.).div(2.)

    elif schedule == "cosine2":

        @input_check(cont=True)
        def schedule_fn(t):
            return end + 0.5 * (start - end) * (torch.cos(math.pi * t) + 1)

    elif schedule == "cosine_improved":
        assert y_type == "alpha"

        @input_check(cont=True)
        def schedule_fn(t):
            return torch.cos(0.5 * math.pi * (t + 0.008).div(1.008))

    elif schedule == "linear":
        # work with 1 - start / 1 - end: x(t) = slope * t + intercept
        intercept = 1 - start
        slope = (1 - end) - intercept

        @input_check(cont=True)
        def schedule_fn(t):
            x = slope * t + intercept
            return torch.exp(0.5 * timesteps / slope * (
                x * x.log() - intercept * math.log(intercept) - slope * t))

    else:
        raise NotImplementedError(schedule)

    if y_type == "beta":
        return {"alpha_fn": schedule_fn}

    if y_type == "logsnr":

        @input_check(cont=True)
        def alpha_fn(t):
            return infer_alpha_from_logsnr(
                schedule_fn(t), signal_stat=signal_stat)

        return {"alpha_fn": alpha_fn}

    return {f"{y_type}_fn": schedule_fn}


def get_decay_schedule(schedule, timesteps, return_function=False, **kwargs):
    """
    Entry point for building a decay schedule.

    :param schedule: name of the form ``{y_type}_{decay_schedule}`` where
        y_type is one of "beta", "alpha", "logsnr"
    :param timesteps: number of steps, forwarded to the builder
    :param return_function: build continuous functions (``_get_decay_function``)
        instead of discrete sequences (``_get_decay_sequence``)
    :param kwargs: must contain either ``start`` and ``end`` or the prefixed
        ``{y_type}_start`` and ``{y_type}_end``; the bare pair is used when both
        of its keys are present. ``end == "auto"`` requires ``logsnr_end`` and
        ``signal_stat`` and derives the end value for alpha/beta schedules.
        All remaining entries are forwarded to the builder.
    :return: whatever the selected builder returns
    """
    y_type = schedule.split("_", maxsplit=1)[0]
    assert y_type in ("beta", "alpha", "logsnr")
    var_type = "" if {"start", "end"}.issubset(kwargs) else y_type + "_"
    start, end = kwargs[var_type + "start"], kwargs[var_type + "end"]
    if isinstance(end, str) and end == "auto":
        assert "logsnr_end" in kwargs
        if y_type == "alpha":
            kwargs["alpha_end"] = end = infer_alpha_from_logsnr(kwargs["logsnr_end"], kwargs["signal_stat"])
        elif y_type == "beta":
            kwargs["beta_end"] = end = infer_beta_end_from_logsnr(
                kwargs["logsnr_end"], start, kwargs["signal_stat"], timesteps)

    # start/end are passed positionally below; forwarding the bare keys as well
    # would raise "got multiple values for argument 'start'" in the builder.
    forwarded = {key: value for key, value in kwargs.items() if key not in ("start", "end")}
    builder = _get_decay_function if return_function else _get_decay_sequence
    return builder(schedule, start, end, timesteps, **forwarded)