import torch
import math
import numpy as np
import tqdm


def LIM_sampler(args, config, x, y, model, sde, levy, sde_clamp=None,
                masked_data=None, mask=None, t0=None, device='cuda'):
    """Draw samples by stepping the state backwards in time from ``sde.T``.

    Starting from ``x`` at time ``sde.T``, the state is updated ``args.nfe``
    times down to ``eps = 1e-5``. Each step uses either a deterministic (ODE)
    update or a stochastic update driven by noise drawn from ``levy``.

    Args:
        args: Run options. Reads ``sample_type`` (``'ode'``, ``'sde_vanilla'``
            or ``'sde_exact'``), ``solver_type`` (``'euler_maruyama'`` or
            ``'exponential_integrator'``) and ``nfe`` (number of steps). Reads
            ``eta`` and ``disable_large_eta`` only for ``'sde_exact'``.
        config: Reads ``diffusion.is_isotropic``, ``model.is_conditional`` and
            ``sampling.clamp_threshold``. Reads ``sampling.sde_clamp`` when
            ``sde_clamp`` is None.
        x: Initial state at time ``sde.T``. Dim 0 is treated as the batch
            dimension; per-sample coefficients are broadcast over the rest.
        y: Conditioning input. It is passed to ``model`` only when
            ``config.model.is_conditional`` is true.
        model: Called as ``model(x, t)`` or ``model(x, t, y)``. Its output is
            rescaled by ``sde.marginal_std(t) ** -(sde.alpha - 1)`` and used
            as the score.
        sde: Must provide ``T``, ``alpha``, ``beta(t)``, ``marginal_std(t)``
            and ``marginal_log_mean_coeff(t)``.
        levy: Noise source exposing
            ``sample(alpha=, size=, is_isotropic=, clamp=)``.
        sde_clamp: Bound applied to the Levy noise. Defaults to
            ``config.sampling.sde_clamp``.
        masked_data, mask: Accepted for interface compatibility; not used here.
        t0: Time above which the adaptive eta (``'sde_exact'`` only) ramps
            linearly towards 1 as t approaches 1. If None, it is chosen from
            ``sde.alpha``: 0.70 for alpha >= 1.7, 0.85 for alpha >= 1.4,
            otherwise 0.95.
        device: Device for the time grid and the per-step time vectors.

    Returns:
        The state after the final step (time ``eps``).

    Raises:
        ValueError: If ``sample_type`` or ``solver_type`` is unknown, or if
            ``t0 >= 1`` while the large-eta ramp is in use.
    """
    if args.sample_type not in ['ode', 'sde_vanilla', 'sde_exact']:
        raise ValueError("Invalid sample type")

    if args.solver_type not in ['euler_maruyama', 'exponential_integrator']:
        raise ValueError("Invalid solver type")

    is_isotropic = config.diffusion.is_isotropic
    steps = args.nfe
    method = args.sample_type
    exponential = args.solver_type == 'exponential_integrator'
    # Only the 'sde_exact' update is eta-weighted; the other samplers use None.
    eta = args.eta if method == 'sde_exact' else None
    eps = 1e-5

    if sde_clamp is None:
        sde_clamp = config.sampling.sde_clamp

    if exponential:
        # Uniform in sqrt(t), so the steps get finer as t approaches eps.
        timesteps = torch.pow(torch.linspace(np.sqrt(sde.T), np.sqrt(eps), steps + 1), 2).to(device)
    else:
        timesteps = torch.linspace(sde.T, eps, steps + 1).to(device)

    # A caller-supplied t0 takes precedence over the alpha-based default.
    if t0 is None:
        if sde.alpha >= 1.7:
            t0 = 0.70
        elif sde.alpha >= 1.4:
            t0 = 0.85
        else:
            t0 = 0.95
    # (t - t0) / (1 - t0) below is undefined or sign-flipped unless t0 < 1.
    if eta is not None and not args.disable_large_eta and not t0 < 1.0:
        raise ValueError("t0 must be < 1 when the large-eta ramp is enabled")

    # Shape that turns a per-sample vector into one broadcastable against x.
    # Assumes sde.* coefficients come back with shape (batch,) -- the original
    # indexing [:, None, None, None] relied on the same thing.
    bcast_shape = (-1,) + (1,) * (x.dim() - 1)

    def per_sample(v):
        return v.reshape(bcast_shape)

    def eta_adaptive(t, eta):
        # Base eta, plus (unless disabled) a linear ramp towards 1 for t > t0.
        # The sigmoid term contributes -eta at t = 0, driving eta_t to 0 there.
        eta_t = eta * torch.ones_like(t)
        if not args.disable_large_eta:
            eta_t += (1.0 - eta) * torch.relu((t - t0) / (1.0 - t0))
        eta_t += 2 * eta * (torch.sigmoid(20 * t) - 1.0)
        return eta_t

    def score_model(x, t):
        if config.model.is_conditional:
            return model(x, t, y)
        return model(x, t)

    def scaled_score(x, s):
        # Model output rescaled by marginal_std(s) ** -(alpha - 1).
        return score_model(x, s) * per_sample(torch.pow(sde.marginal_std(s), -(sde.alpha - 1)))

    def step_terms(s, t):
        # a: ratio of marginal mean coefficients between t and s.
        # beta_step: beta(s) times the (positive) step size s - t.
        a = torch.exp(sde.marginal_log_mean_coeff(t) - sde.marginal_log_mean_coeff(s))
        beta_step = sde.beta(s) * (s - t)
        return a, beta_step

    def sample_noise(shape):
        if is_isotropic:
            return levy.sample(alpha=sde.alpha, size=shape, is_isotropic=True, clamp=sde_clamp).to(device)
        noise = levy.sample(alpha=sde.alpha, size=shape, is_isotropic=False, clamp=None).to(device)
        # Non-isotropic noise is clamped element-wise here rather than by levy.
        return torch.clamp(noise, min=-sde_clamp, max=sde_clamp)

    def ode_score_update(x, s, t, exponential=True):
        score_s = scaled_score(x, s)
        a, beta_step = step_terms(s, t)

        if exponential:
            x_coeff = a
            score_coeff = sde.alpha * (a - 1.0)
        else:
            x_coeff = 1.0 + beta_step / sde.alpha
            score_coeff = beta_step

        return per_sample(x_coeff) * x + per_sample(score_coeff) * score_s

    def sde_score_update(x, s, t, exponential=True, eta=None, is_terminal_step=False):
        score_s = scaled_score(x, s)
        a, beta_step = step_terms(s, t)
        e_L = sample_noise(x.shape)

        if eta is not None:
            eta = eta_adaptive(s, eta)
            if is_terminal_step:
                # With eta = 0 the noise term vanishes and the score
                # coefficient matches ode_score_update: a deterministic last step.
                eta = torch.zeros_like(eta)

        if exponential:
            x_coeff = a
            noise_coeff = torch.pow(torch.pow(a, sde.alpha) - 1.0, 1.0 / sde.alpha)
            if eta is None:
                score_coeff = sde.alpha ** 2 * (a - 1.0)
            else:
                score_coeff = sde.alpha * (1.0 + eta) * (a - 1.0)
        else:
            x_coeff = 1.0 + beta_step / sde.alpha
            noise_coeff = torch.pow(beta_step, 1.0 / sde.alpha)
            if eta is None:
                score_coeff = sde.alpha * beta_step
            else:
                score_coeff = (1.0 + eta) * beta_step

        if eta is not None:
            noise_coeff = noise_coeff * torch.pow(eta, 1 / sde.alpha)

        return (per_sample(x_coeff) * x
                + per_sample(score_coeff) * score_s
                + per_sample(noise_coeff) * e_L)

    def clamp_norm(x, threshold):
        # Rescale every sample whose flattened L2 norm exceeds threshold so
        # that its norm becomes exactly threshold; other samples are untouched.
        flat = x.reshape(len(x), -1)
        norms = flat.norm(dim=1)
        over = norms > threshold
        flat[over] = flat[over] / norms[over][:, None] * threshold
        return flat.reshape(x.shape)

    with torch.no_grad():
        # Loop-invariant batch vector used to broadcast the scalar times.
        ones = torch.ones((x.shape[0],), device=device)
        for i in tqdm.tqdm(range(steps)):
            vec_s = ones * timesteps[i]
            vec_t = ones * timesteps[i + 1]

            if method == 'ode':
                x = ode_score_update(x, vec_s, vec_t, exponential=exponential)
            else:
                x = sde_score_update(x, vec_s, vec_t, exponential=exponential,
                                     eta=eta, is_terminal_step=(i == steps - 1))

                # Norm clamping is applied only on the stochastic paths.
                if config.sampling.clamp_threshold:
                    x = clamp_norm(x, config.sampling.clamp_threshold)

    return x
