import numpy as np
import torch


def compute_alpha(beta, t):
    """Look up the running product of ``1 - beta`` at the given timesteps.

    A zero is prepended to ``beta`` so that entry 0 of the running product
    is exactly 1. ``t`` is shifted by one to compensate, which makes
    ``t == -1`` a valid lookup that returns 1.

    Args:
        beta: 1-D tensor of per-step values.
        t: 1-D integer tensor of indices, each in ``[-1, len(beta) - 1]``.

    Returns:
        Tensor of shape ``(len(t), 1, 1, 1)`` holding the selected products.
    """
    padded = torch.cat([torch.zeros(1, device=beta.device), beta], dim=0)
    running_product = torch.cumprod(1 - padded, dim=0)
    selected = torch.index_select(running_product, 0, t + 1)
    return selected.view(-1, 1, 1, 1)
# from torch_utils import distributed as dist

#----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2).

def edm_sampler(
    net, con_img, mask, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Run an Euler sampling loop with a 2nd-order correction, conditioned on ``con_img``.

    The noise levels are rho-spaced from ``sigma_max`` down to ``sigma_min``
    with a final level of 0 appended. The starting sample is
    ``con_img * sigma_max``. At every step the network is called as
    ``net(inp, 1.0 - t)`` where ``inp`` is ``con_img`` concatenated along
    dim 1 with the current sample blended as
    ``sample * mask + con_img * (1 - mask)``.

    The step direction is the raw residual ``x_hat - denoised``; it is not
    divided by the noise level.

    Args:
        net: Callable ``net(inp, t)`` returning a tensor that can be
            subtracted from the current sample.
        con_img: Conditioning tensor; also fixes the device of the schedule.
        mask: Blend weights used only when building the network input.
        randn_like: Noise generator with the signature of ``torch.randn_like``.
        num_steps: Number of noise levels before the appended 0 (>= 1).
        sigma_min: Smallest non-zero noise level.
        sigma_max: Largest noise level.
        rho: Exponent controlling the spacing of the noise levels.
        S_churn: Total amount of temporary noise increase.
        S_min: Lower bound of the noise levels where churn is applied.
        S_max: Upper bound of the noise levels where churn is applied.
        S_noise: Scale of the noise added during churn.

    Returns:
        The sample after the last step. It is returned as computed, i.e. it
        is not re-blended with ``con_img`` outside the mask.
    """
    device = con_img.device

    def net_input(sample):
        # Outside the mask the network always sees the conditioning image.
        blended = sample * mask + con_img * (1 - mask)
        return torch.cat([con_img, blended], dim=1)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float32, device=device)
    # With a single step the interpolation fraction would be 0 / 0 and turn
    # the whole schedule (and therefore the output) into NaN; clamping the
    # divisor yields the one-entry schedule [sigma_max] instead.
    last_index = max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_indices / last_index * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    x_next = con_img * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = t_cur + gamma * t_cur
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(net_input(x_hat), 1.0 - t_hat)
        d_cur = x_hat - denoised
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the final step, where t_next is 0).
        if i < num_steps - 1:
            denoised = net(net_input(x_next), 1.0 - t_next)
            d_prime = x_next - denoised
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next

def ddpm_steps(x_cond, x, m, seq, model, b, gammas):
    """Run a stochastic reverse pass over the timesteps in ``seq``.

    The timesteps are visited from the last entry of ``seq`` to the first.
    At each step the model output is treated as a noise estimate, a clamped
    estimate of the clean sample is derived from it, and a new sample is
    drawn around the resulting mean. No noise is added at timestep 0. After
    every step the sample is blended as ``x_cond * (1 - m) + sample * m``.

    All work happens on the device of the incoming ``x`` and under
    ``torch.no_grad()``.

    Args:
        x_cond: Conditioning tensor, concatenated with the current sample
            along dim 1 to form the model input.
        x: Starting sample; its first dimension is the batch size.
        m: Blend weights; where ``m`` is 0 the result equals ``x_cond``.
        seq: Sequence of timestep indices (must support slicing and
            ``reversed``).
        model: Callable ``model(inp, level)`` returning the noise estimate.
        b: 1-D tensor passed to ``compute_alpha``.
        gammas: Indexable by timestep; each entry must be a tensor. The entry
            is broadcast to one value per batch element and passed as the
            model's second argument.

    Returns:
        List of tensors: the input ``x`` as given, followed by a CPU copy of
        the sample after each step.
    """
    with torch.no_grad():
        # Use the caller's device rather than a hard-coded 'cuda' so the
        # sampler also runs on CPU-only machines and non-default GPUs.
        device = x.device
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        xs = [x]
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(device)
            next_t = (torch.ones(n) * j).to(device)
            at = compute_alpha(b, t.long())
            atm1 = compute_alpha(b, next_t.long())
            beta_t = 1 - at / atm1

            # Earlier samples are stored on the CPU; bring the latest back.
            x = xs[-1].to(device).type(torch.float32)
            level = (torch.ones(n, device=device) * gammas[i].to(device)).float()
            e = model(torch.cat([x_cond, x], dim=1), level)

            x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e
            x0_from_e = torch.clamp(x0_from_e, -1, 1)
            mean = (
                (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x
            ) / (1.0 - at)

            noise = torch.randn_like(x)
            # Gate that is 0 at timestep 0, so the final step is noise-free.
            noise_gate = (1 - (t == 0).float()).view(-1, 1, 1, 1)
            logvar = beta_t.log()
            sample = mean + noise_gate * torch.exp(0.5 * logvar) * noise
            sample = x_cond * (1.0 - m) + sample * m
            xs.append(sample.to('cpu'))
    return xs

def generalized_steps(x_cond, x, mask, seq, model, b, gammas, **kwargs):
    """Run a reverse pass over ``seq`` with an ``eta``-controlled noise term.

    The timesteps are visited from the last entry of ``seq`` to the first.
    At each step the model output is treated as a noise estimate, an
    estimate of the clean sample is derived from it, and the next sample is
    formed from that estimate, the noise estimate and fresh Gaussian noise
    scaled by ``c1``. With the default ``eta=0`` the coefficient ``c1`` is
    0, so the fresh noise does not contribute. After every step the sample
    is blended as ``x_cond * (1 - mask) + sample * mask``.

    All work happens on the device of ``x`` and under ``torch.no_grad()``.

    Args:
        x_cond: Conditioning tensor, concatenated with the current sample
            along dim 1 to form the model input.
        x: Starting sample; its first dimension is the batch size.
        mask: Blend weights; where ``mask`` is 0 the result equals ``x_cond``.
        seq: Sequence of timestep indices (must support slicing and
            ``reversed``).
        model: Callable ``model(inp, level)`` returning the noise estimate.
        b: 1-D tensor passed to ``compute_alpha``.
        gammas: Indexable by timestep; each entry must be a tensor. The entry
            is broadcast to one value per batch element and passed as the
            model's second argument.
        **kwargs: Only ``eta`` (default 0) is read.

    Returns:
        List of tensors: the input ``x`` as given, followed by a CPU copy of
        the sample after each step.
    """
    with torch.no_grad():
        # Use the caller's device directly; the former detour through a
        # hard-coded 'cuda' failed on machines without CUDA.
        device = x.device
        eta = kwargs.get("eta", 0)
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        xs = [x]
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(device)
            next_t = (torch.ones(n) * j).to(device)
            at = compute_alpha(b, t.long())
            at_next = compute_alpha(b, next_t.long())

            # Earlier samples are stored on the CPU; bring the latest back.
            xt = xs[-1].to(device).type(torch.float32)
            level = (torch.ones(n, device=device) * gammas[i].to(device)).float()
            et = model(torch.cat([x_cond, xt], dim=1), level)

            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            # Noise is always drawn, even when c1 is 0, so the random stream
            # consumed per step does not depend on eta.
            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
            xt_next = x_cond * (1. - mask) + xt_next * mask
            xs.append(xt_next.to('cpu'))

    return xs
