import numpy as np
import torch


def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
    """Return a linearly spaced schedule of beta values.

    Args:
        beta_start: First value of the schedule.
        beta_end: Last value of the schedule (included).
        num_diffusion_timesteps: Number of entries to generate.

    Returns:
        A 1-D ``np.float64`` array with ``num_diffusion_timesteps`` entries.
    """
    schedule = np.linspace(
        start=beta_start,
        stop=beta_end,
        num=num_diffusion_timesteps,
        dtype=np.float64,
    )
    expected_shape = (num_diffusion_timesteps,)
    assert schedule.shape == expected_shape
    return schedule


def extract(a, t, x_shape):
    """Gather per-sample coefficients ``a[t]`` and reshape them for broadcasting.

    Args:
        a: 1-D table of coefficients (tensor, NumPy array or sequence).
        t: 1-D tensor of indices into ``a``, one per batch element.
        x_shape: Shape of the tensor the result will be broadcast against;
            its first entry must equal ``t.shape[0]``.

    Returns:
        A ``torch.float`` tensor on ``t.device`` with shape
        ``(batch, 1, ..., 1)`` and ``len(x_shape)`` dimensions.
    """
    bs, = t.shape
    assert x_shape[0] == bs, f"{x_shape[0]}, {t.shape}"
    if isinstance(a, torch.Tensor):
        # torch.tensor(existing_tensor) emits a UserWarning and makes an
        # extra copy. detach() + to() yields the same values, still cut off
        # from autograd, without the warning.
        coeffs = a.detach().to(device=t.device, dtype=torch.float)
    else:
        coeffs = torch.as_tensor(a, dtype=torch.float, device=t.device)
    out = torch.gather(coeffs, 0, t.long())
    assert out.shape == (bs,)
    # One trailing singleton axis per non-batch dimension of x_shape.
    out = out.reshape((bs,) + (1,) * (len(x_shape) - 1))
    return out



def denoising_step(xt, t, t_next, *,
                   models,
                   logvars,
                   b,
                   img_prompt=None,
                   sampling_type='ddim',
                   eta=0.0,
                   learn_sigma=False,
                   index=None,
                   t_edit=0,
                   hs_coeff=(1.0),
                   delta_h=None,
                   ignore_timestep=False,
                   image_space_noise=0,
                   ):
    """Advance a batch by one reverse step using DDPM or DDIM sampling.

    The model is called once to obtain the noise prediction ``et`` and an
    edited variant ``et_modified``. When ``index`` is set (or image-space
    noise is active), ``et_modified`` is used only for the predicted ``x0``;
    the term multiplied by ``sqrt(1 - at_next)`` always uses ``et``.

    Args:
        xt: Current sample batch.
        t: 1-D tensor of current timestep indices, one per batch element.
        t_next: 1-D tensor of next timestep indices. When its entries sum to
            ``-len(t_next)`` the next cumulative alpha is taken to be 1.
        models: Callable invoked as ``models(xt, t, img_prompt, index=...,
            t_edit=..., hs_coeff=..., delta_h=..., ignore_timestep=...)`` and
            returning ``(et, et_modified, delta_h, middle_h)``.
        logvars: Table of log-variances indexed by ``t``; read only on the
            ``'ddpm'`` path when ``learn_sigma`` is false.
        b: 1-D tensor of betas (must support ``.cumprod(dim=0)``).
        img_prompt: Forwarded to the model unchanged.
        sampling_type: ``'ddim'`` or ``'ddpm'``.
        eta: DDIM stochasticity; ``0`` gives the deterministic update.
        learn_sigma: If true, the model output is split in half along dim 1
            into noise prediction and learned log-variance.
        index: Forwarded to the model; when not ``None`` the predicted ``x0``
            is built from ``et_modified``.
        t_edit: Forwarded to the model; image-space noise is applied only
            while ``t[0] >= t_edit``.
        hs_coeff: Forwarded to the model. NOTE: the default ``(1.0)`` is a
            plain float, but the ``Parameter`` image-space-noise path reads
            ``hs_coeff[1]``, so that path needs an indexable value.
        delta_h: Forwarded to the model.
        ignore_timestep: Forwarded to the model.
        image_space_noise: An ``int`` disables the feature. Otherwise either
            an ``nn.Parameter`` added to ``et`` (scaled by ``hs_coeff[1]``)
            or a callable ``f(et, temb)`` whose output is scaled by ``0.01``.

    Returns:
        ``(xt_next, x0_t, delta_h, middle_h)`` where ``delta_h`` and
        ``middle_h`` are the values returned by the model.

    Raises:
        ValueError: If ``sampling_type`` is neither ``'ddpm'`` nor ``'ddim'``.
    """
    # Compute noise and (optionally) the learned variance.
    model = models
    et, et_modified, delta_h, middle_h = model(
        xt, t, img_prompt, index=index, t_edit=t_edit, hs_coeff=hs_coeff,
        delta_h=delta_h, ignore_timestep=ignore_timestep)

    logvar_learned = None
    if learn_sigma:
        et, logvar_learned = torch.split(et, et.shape[1] // 2, dim=1)
        if index is not None:
            et_modified, _ = torch.split(et_modified, et_modified.shape[1] // 2, dim=1)

    if not isinstance(image_space_noise, int) and t[0] >= t_edit:
        # Force the "edited" DDIM path below to pick up et_modified.
        index = 0
        if isinstance(image_space_noise, torch.nn.Parameter):
            et_modified = et + image_space_noise * hs_coeff[1]
        else:
            # Accept both a wrapped model (exposing .module) and a bare one.
            net = getattr(models, 'module', models)
            temb = net.get_temb(t)
            et_modified = et + image_space_noise(et, temb) * 0.01

    # Cumulative alpha products at the current and next timestep.
    alphas_cumprod = (1.0 - b).cumprod(dim=0)
    at = extract(alphas_cumprod, t, xt.shape)
    if t_next.sum() == -t_next.shape[0]:
        at_next = torch.ones_like(at)
    else:
        at_next = extract(alphas_cumprod, t_next, xt.shape)

    if sampling_type == 'ddpm':
        bt = extract(b, t, xt.shape)
        if learn_sigma:
            logvar = logvar_learned
        else:
            logvar = extract(logvars, t, xt.shape)

        weight = bt / torch.sqrt(1 - at)
        mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et)
        noise = torch.randn_like(xt)
        # No noise is added for samples already at t == 0.
        mask = 1 - (t == 0).float()
        mask = mask.reshape((xt.shape[0],) + (1,) * (len(xt.shape) - 1))
        xt_next = (mean + mask * torch.exp(0.5 * logvar) * noise).float()
        # The return value always includes x0_t; previously this path left
        # it unassigned and the function failed at the return statement.
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

    elif sampling_type == 'ddim':
        eps_for_x0 = et_modified if index is not None else et
        x0_t = (xt - eps_for_x0 * (1 - at).sqrt()) / at.sqrt()

        if eta == 0:
            # Deterministic.
            xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et
        else:
            # Add noise. When eta is 1 and time step is 1000, it is equal to ddpm.
            c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt()
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            xt_next = at_next.sqrt() * x0_t + c2 * et + c1 * torch.randn_like(xt)

    else:
        raise ValueError(f"Unsupported sampling_type: {sampling_type!r}")

    return xt_next, x0_t, delta_h, middle_h



def ddim_interpolation_step(xt, t, t_next, *,
                   models,
                   b,
                   interpolation_step=4,
                   maintain=500,
                   ):
    """Deterministic DDIM step with spherically interpolated noise.

    While ``t[0] >= maintain``, the predicted ``x0`` is built from noise
    predictions interpolated (slerp) between the first and last element of
    the batch. The term multiplied by ``sqrt(1 - at_next)`` always uses the
    model's own prediction.

    Args:
        xt: Current sample batch. In the interpolation branch its batch size
            must broadcast against ``interpolation_step``.
        t: 1-D tensor of current timestep indices, one per batch element.
        t_next: 1-D tensor of next timestep indices. When its entries sum to
            ``-len(t_next)`` the next cumulative alpha is taken to be 1.
        models: Callable invoked as ``models(xt, t)`` returning the noise
            prediction.
        b: 1-D tensor of betas (must support ``.cumprod(dim=0)``).
        interpolation_step: Number of interpolation weights, evenly spaced
            in ``[0, 1]``.
        maintain: Timestep threshold at or above which interpolation is used.

    Returns:
        The sample batch at ``t_next``.
    """
    et = models(xt, t)

    # Cumulative alpha products at the current and next timestep.
    alphas_cumprod = (1.0 - b).cumprod(dim=0)
    at = extract(alphas_cumprod, t, xt.shape)
    if t_next.sum() == -t_next.shape[0]:
        at_next = torch.ones_like(at)
    else:
        at_next = extract(alphas_cumprod, t_next, xt.shape)

    if t[0] >= maintain:
        print("Maintain")

        def slerp(z1, z2, alpha):
            # Clamp so rounding cannot push the cosine outside acos' domain.
            cos_theta = torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2))
            theta = torch.acos(cos_theta.clamp(-1.0, 1.0))
            sin_theta = torch.sin(theta)
            if sin_theta.abs() < 1e-6:
                # (Anti)parallel endpoints: the slerp weights become 0/0, so
                # fall back to linear interpolation.
                return (1 - alpha) * z1 + alpha * z2
            return (
                torch.sin((1 - alpha) * theta) / sin_theta * z1
                + torch.sin(alpha * theta) / sin_theta * z2
            )

        weights = np.linspace(0.0, 1.0, interpolation_step)
        et_slerp = torch.stack(
            [slerp(et[0], et[-1], float(w)) for w in weights], dim=0)

        x0_t = (xt - et_slerp * (1 - at).sqrt()) / at.sqrt()
    else:
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

    xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et

    return xt_next

def ddim_interpolation_step2(xt, t, t_next, *,
                   models,
                   b,
                   interpolation_step=4,
                   maintain=500,
                   alpha=None
                   ):
    """Deterministic DDIM step using the model's ``interpolation2`` method.

    While ``t[0] >= maintain``, the model is asked for both the plain noise
    prediction and a modified one. The modified prediction is used only for
    the predicted ``x0``; the term multiplied by ``sqrt(1 - at_next)`` uses
    the plain prediction.

    Args:
        xt: Current sample batch.
        t: 1-D tensor of current timestep indices, one per batch element.
        t_next: 1-D tensor of next timestep indices. When its entries sum to
            ``-len(t_next)`` the next cumulative alpha is taken to be 1.
        models: Model exposing ``interpolation2``; if it has a ``module``
            attribute, that inner object is used instead.
        b: 1-D tensor of betas (must support ``.cumprod(dim=0)``).
        interpolation_step: Not read by this function.
        maintain: Timestep threshold at or above which the modified
            prediction is requested; also forwarded to ``interpolation2``.
        alpha: Forwarded to ``interpolation2`` in the modified branch.

    Returns:
        The sample batch at ``t_next``.
    """
    model = models.module if hasattr(models, 'module') else models

    # Evaluate the threshold once; both the model call and the x0 prediction
    # depend on it.
    use_modified = bool(t[0] >= maintain)
    if use_modified:
        et, et_modified = model.interpolation2(xt, t, index=1, maintain=maintain, alpha=alpha)
    else:
        et = model.interpolation2(xt, t)

    # Cumulative alpha products at the current and next timestep.
    alphas_cumprod = (1.0 - b).cumprod(dim=0)
    at = extract(alphas_cumprod, t, xt.shape)
    if t_next.sum() == -t_next.shape[0]:
        at_next = torch.ones_like(at)
    else:
        at_next = extract(alphas_cumprod, t_next, xt.shape)

    if use_modified:
        print("Maintain")
        x0_t = (xt - et_modified * (1 - at).sqrt()) / at.sqrt()
    else:
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

    xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et

    return xt_next




# Three comparisons need to be set up:
# 1. e + random_like(e), at both
# 2. e + random_like(e), at predicted_x0
# 3. e + modified_e <- (h + random_like(h)), at predicted_x0
# To match the std across all three, multiply by 1/sqrt(2).

def denoising_step_three_compare(xt, t, t_next, *,
                   models,
                   logvars,
                   b,
                   sampling_type='ddim',
                   eta=0.0,
                   learn_sigma=False,
                   hybrid=False,
                   hybrid_config=None,
                   ratio=1.0,
                   out_x0_t=False,
                   index=None,
                   maintain=0,
                   out_et=False,
                   rambda=1.0,
                   add_noise_both=False,
                   add_noise_predicted_x0=False,
                   add_noise_h_space_both=False,
                   scaling=1.0,
                   mean=0.0,
                   ):
    """One reverse step with optional noise injection for comparison runs.

    Noise prediction comes from one of three sources:

    * a single model (``models`` is not a list), called so that it returns
      ``(et, et_modified, delta_h, middle_h)``; when ``index`` is set, a
      random ``delta_h`` of hard-coded shape ``(batch, 512, 8, 8)`` is
      passed in;
    * a list of models blended by ``ratio`` (``models[1]`` weighted
      ``ratio``, ``models[0]`` weighted ``1 - ratio``);
    * a list of models blended per ``hybrid_config`` when ``hybrid`` is true.

    On the DDIM path the branch taken is, in priority order:

    1. ``index is not None``: the predicted ``x0`` uses ``et_modified``; with
       ``add_noise_h_space_both`` the direction term uses it as well.
    2. ``add_noise_predicted_x0``: norm-scaled noise is added to ``et`` for
       the predicted ``x0`` only.
    3. ``add_noise_both``: noise is added to ``et`` for both terms.
    4. otherwise: plain DDIM.

    NOTE: the list-of-models branches never define ``et_modified``, so they
    cannot be combined with ``index is not None`` on the DDIM path.

    Args:
        xt: Current sample batch.
        t: 1-D tensor of current timestep indices; the hybrid branch calls
            ``t.item()`` and therefore needs a single element.
        t_next: 1-D tensor of next timestep indices. When its entries sum to
            ``-len(t_next)`` the next cumulative alpha is taken to be 1.
        models: A single model or a ``list`` of models (see above).
        logvars: Table of log-variances indexed by ``t``; read when
            ``learn_sigma`` is false.
        b: 1-D tensor of betas (must support ``.cumprod(dim=0)``).
        sampling_type: ``'ddim'`` or ``'ddpm'``.
        eta: DDIM stochasticity; ``0`` gives the deterministic update.
        learn_sigma: If true, model outputs are split in half along dim 1
            into noise prediction and learned log-variance.
        hybrid: Selects the ``hybrid_config`` blending for a model list.
        hybrid_config: Mapping of timestep threshold to a sequence of
            weights; the first threshold (in iteration order) with
            ``t >= threshold`` is used, and weight ``i`` applies to
            ``models[i + 1]`` after normalising by the weight sum.
        ratio: Blend weight for the non-hybrid list branch.
        out_x0_t: If truthy, also return the predicted ``x0``.
        index: Forwarded to the model(s); selects the ``et_modified`` path.
        maintain: Forwarded to the model(s) (as ``t_edit`` for a single
            model).
        out_et: Not read by this function.
        rambda: Not read by this function.
        add_noise_both: See DDIM branch 3.
        add_noise_predicted_x0: See DDIM branch 2.
        add_noise_h_space_both: See DDIM branch 1.
        scaling: Scale of the injected noise / random ``delta_h``.
        mean: Offset added to the injected noise / random ``delta_h``.

    Returns:
        ``xt_next``, or ``(xt_next, x0_t)`` when ``out_x0_t`` is truthy.

    Raises:
        ValueError: If ``sampling_type`` is unsupported, or if ``hybrid`` is
            set and no threshold in ``hybrid_config`` matches ``t``.
    """
    # Compute noise and variance.
    if not isinstance(models, list):
        model = models

        if index is not None:
            # Random perturbation rescaled to a fixed overall L2 norm.
            direct_delta_h = torch.randn((xt.shape[0], 512, 8, 8)) * 1.5
            direct_delta_h = direct_delta_h / torch.norm(direct_delta_h, p=2) * scaling + mean
            # Drawn on the CPU as before (same random stream), then moved so
            # it matches the input's device.
            direct_delta_h = direct_delta_h.to(xt.device)
            et, et_modified, delta_h, middle_h = model(
                xt, t, index=index, t_edit=maintain, hs_coeff=(1.0, 1.0),
                delta_h=direct_delta_h, ignore_timestep=False)
        else:
            et, et_modified, delta_h, middle_h = model(xt, t)

        if learn_sigma:
            et, logvar = torch.split(et, et.shape[1] // 2, dim=1)
            if index is not None:
                et_modified, _ = torch.split(et_modified, et_modified.shape[1] // 2, dim=1)
        else:
            logvar = extract(logvars, t, xt.shape)

    elif not hybrid:
        et = 0
        logvar = 0
        if ratio != 0.0:
            et_i = ratio * models[1](xt, t, index, maintain)
            if learn_sigma:
                et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                logvar += logvar_learned
            else:
                logvar += ratio * extract(logvars, t, xt.shape)
            et += et_i

        if ratio != 1.0:
            et_i = (1 - ratio) * models[0](xt, t, index, maintain)
            if learn_sigma:
                et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                logvar += logvar_learned
            else:
                logvar += (1 - ratio) * extract(logvars, t, xt.shape)
            et += et_i

    else:
        t_value = t.item()
        for thr in hybrid_config:
            if t_value < thr:
                continue
            weights = hybrid_config[thr]
            weight_total = sum(weights)
            et = 0
            logvar = 0
            for i, raw_weight in enumerate(weights):
                weight_i = raw_weight / weight_total
                et_i = models[i + 1](xt, t, index, maintain)
                if learn_sigma:
                    et_i, logvar_i = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                else:
                    logvar_i = extract(logvars, t, xt.shape)
                et += weight_i * et_i
                logvar += weight_i * logvar_i
            break
        else:
            # Previously this fell through and failed later on an undefined
            # noise prediction.
            raise ValueError(
                f"No hybrid_config threshold matches timestep {t_value}")

    # Cumulative alpha products at the current and next timestep.
    alphas_cumprod = (1.0 - b).cumprod(dim=0)
    at = extract(alphas_cumprod, t, xt.shape)
    if t_next.sum() == -t_next.shape[0]:
        at_next = torch.ones_like(at)
    else:
        at_next = extract(alphas_cumprod, t_next, xt.shape)

    x0_t = None
    if sampling_type == 'ddpm':
        bt = extract(b, t, xt.shape)
        weight = bt / torch.sqrt(1 - at)

        # Named posterior_mean so it does not shadow the `mean` parameter.
        posterior_mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et)
        noise = torch.randn_like(xt)
        # No noise is added for samples already at t == 0.
        mask = 1 - (t == 0).float()
        mask = mask.reshape((xt.shape[0],) + (1,) * (len(xt.shape) - 1))
        xt_next = (posterior_mean + mask * torch.exp(0.5 * logvar) * noise).float()
        if out_x0_t:
            # Previously unassigned on this path, which made out_x0_t fail.
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

    elif sampling_type == 'ddim':
        if index is not None:
            x0_t = (xt - et_modified * (1 - at).sqrt()) / at.sqrt()
            if add_noise_h_space_both:
                et = et_modified
        elif add_noise_predicted_x0:
            add_noise = torch.randn_like(et) * 1.5
            add_noise = add_noise / torch.norm(add_noise, p=2) * scaling + mean
            # Rescale so the sum keeps roughly unit variance.
            et2 = (et + add_noise) / (1 + scaling ** 2) ** (1 / 2)
            x0_t = (xt - et2 * (1 - at).sqrt()) / at.sqrt()
        elif add_noise_both:
            add_noise = torch.randn_like(et) * scaling
            et = (et + add_noise) / (1 + scaling ** 2) ** (1 / 2)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
        else:
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

        if eta == 0:
            xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et
        else:
            c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt()
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            xt_next = at_next.sqrt() * x0_t + c2 * et + c1 * torch.randn_like(xt)

    else:
        raise ValueError(f"Unsupported sampling_type: {sampling_type!r}")

    if out_x0_t:
        return xt_next, x0_t
    return xt_next