import sys

import math
import numpy

import torch


def compute_alpha(beta, t):
    """Look up the running product of ``1 - beta`` at the timesteps in ``t``.

    A single zero is placed in front of ``beta`` and the lookup index is
    ``t + 1``; as a result ``t == -1`` addresses the prepended entry, whose
    running product is exactly 1.

    Args:
        beta: 1-D tensor of beta values.
        t: Index tensor accepted by ``index_select`` (callers in this file
            pass ``.long()`` tensors).

    Returns:
        The selected products, reshaped to ``(len(t), 1, 1, 1)``.
    """
    leading_zero = torch.zeros(1, device=beta.device)
    padded = torch.cat((leading_zero, beta), dim=0)
    running_product = torch.cumprod(1 - padded, dim=0)
    # The +1 shift is the trick that makes t == -1 a valid lookup.
    selected = running_product.index_select(0, t + 1)
    return selected.view(-1, 1, 1, 1)

# Module-level trajectory buffer: the step functions in this file append
# intermediate states to it as NumPy arrays. Nothing in the code visible here
# ever clears it, so it keeps growing across calls.
curve = []
def generalized_steps(x, seq, model, b, **kwargs):
    """Run the deterministic sampling loop with the solver chosen by ``method``.

    Args:
        x: Initial state tensor. Its device is used for every step, and it is
            also forwarded to the step function (which reads its batch size
            and device).
        seq: Sequence of timesteps; it is walked in reverse order. Each step
            goes from ``seq[k]`` to ``seq[k - 1]``, and the last step goes to
            ``-1``.
        model: Callable forwarded unchanged to the step function.
        b: Beta tensor forwarded unchanged to the step function.
        **kwargs: ``method`` selects the solver; one of ``'DDIM'``,
            ``'S-PNDM'``, ``'FO'`` or ``'F-PNDM'``.

    Returns:
        ``(xs, x0_preds)``. ``xs`` is ``x`` followed by the state after every
        step (each moved to CPU). ``x0_preds`` holds the state that was fed
        *into* each step (moved to CPU); despite the name, no denoised
        prediction is computed here.

    Raises:
        ValueError: If ``method`` is missing or not a supported name.
    """
    method_dict = {
        'DDIM': gen_xt_next_1,
        'S-PNDM': gen_xt_next_2,
        'FO': gen_xt_next_4_,
        'F-PNDM': gen_xt_next_4,
    }
    method = kwargs.get('method')
    if method not in method_dict:
        # Fail immediately instead of printing and hitting a KeyError later.
        raise ValueError(f"No support method is named as {method}.")
    step_fn = method_dict[method]

    # The step functions create their timestep tensors on x.device, so the
    # working state has to live there as well (previously hard-coded 'cuda').
    device = x.device
    # The multistep solvers read at most the four newest history entries
    # (counting the one they have just appended); anything older is dead
    # weight that would otherwise stay alive for the whole run.
    history = 4

    with torch.no_grad():
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]
        et_list = []
        w_et_list = []
        for i, j in zip(reversed(seq), reversed(seq_next)):
            xt = xs[-1].to(device)
            xt_next = step_fn(xt, i, j, model, b, x, et_list, w_et_list)

            del et_list[:-history]
            del w_et_list[:-history]

            xs.append(xt_next.to('cpu'))
            x0_preds.append(xt.to('cpu'))

    return xs, x0_preds

def gen_xt_next_4_(xt, i, j, model, b, x, et_list, w_et_list):
    """Advance one step with the solver registered under ``'FO'``.

    The update is ``xt + slope * span``, where ``slope`` is a per-unit-step
    increment and ``span = i_list[0] - i_list[2]``. Because every entry of
    ``i_list`` is clamped at 0, ``i_list[2]`` equals ``max(j, 0)``; a final
    step towards ``j == -1`` therefore spans ``i`` units, not ``i + 1``.

    While ``et_list`` holds at most two entries the slope comes from
    ``runge_kutta_`` (which is handed both history lists). Afterwards a single
    ``gen_func`` evaluation for the step ``i -> i - 1`` is appended to
    ``et_list`` and blended with the three previous entries using the
    four-step Adams-Bashforth weights ``(55, -59, 37, -9) / 24``.

    Args:
        xt: Current state tensor.
        i: Current timestep.
        j: Target timestep.
        model: Callable forwarded to ``gen_func`` / ``runge_kutta_``.
        b: Beta tensor forwarded to ``gen_func`` / ``runge_kutta_``.
        x: Reference tensor forwarded to the helpers.
        et_list: Mutable history of slopes; appended to in place.
        w_et_list: Mutable companion history; in the multistep branch the
            noise estimate returned by ``gen_func`` is appended. It is not
            read here.

    Returns:
        The state after the step.

    Side effects:
        Appends ``int(span)`` intermediate states, ``xt + slope * m`` for
        ``m = 1 .. int(span)``, to the module-level ``curve`` list as NumPy
        arrays.
    """
    i_list = [max(i - (i - j) * k / 2, 0) for k in range(8)]

    if len(et_list) > 2:
        # Linear multistep phase: one new slope evaluation per step.
        slope_now, et_1, _ = gen_func(
            xt, xt, i_list[0], i_list[0], i_list[0] - 1, model, b, x)
        et_list.append(slope_now)
        w_et_list.append(et_1)
        x_delta = (1 / 24) * (55 * et_list[-1] - 59 * et_list[-2]
                              + 37 * et_list[-3] - 9 * et_list[-4])
    else:
        # Warm-up phase: Runge-Kutta supplies the slope until enough history
        # exists for the multistep formula.
        _, x_delta = runge_kutta_(xt, i_list, model, b, x, et_list, w_et_list)

    span = i_list[0] - i_list[2]
    xt_next = xt + x_delta * span

    # Record the straight-line path between xt and xt_next, one unit at a time.
    for k in range(int(span)):
        curve.append((xt + x_delta * (k + 1)).cpu().numpy())
    return xt_next

def runge_kutta_(xt, i_list, model, b, x, et_list, w_et_list):  # NOTE: flagged as problematic by the original author
    """Estimate a per-unit-step slope with four ``gen_func`` evaluations.

    Each evaluation asks ``gen_func`` for the increment of a single unit step
    (``step -> step - 1``) at a trial state. The trial states are built from
    ``xt`` by moving half the span (``i_list[0] - i_list[1]``) twice and the
    full span (``i_list[0] - i_list[2]``) once, and the four slopes are
    combined with the weights ``(1, 2, 2, 1) / 6``.

    The original author marked this routine as problematic without recording
    what is wrong; the numerical behaviour is kept exactly as it was.

    Args:
        xt: Current state tensor.
        i_list: Timestep grid; entries 0, 1 and 2 are read here.
        model: Callable forwarded to ``gen_func``.
        b: Beta tensor forwarded to ``gen_func``.
        x: Reference tensor forwarded to ``gen_func``.
        et_list: Mutable history; the first slope is appended.
        w_et_list: Mutable history; the first noise estimate returned by
            ``gen_func`` is appended.

    Returns:
        ``(kt_1, x_delta)``: the first slope and the weighted slope.
    """
    t_start, t_mid, t_end = i_list[0], i_list[1], i_list[2]
    half_span = t_start - t_mid
    full_span = t_start - t_end

    def unit_slope(state, step):
        # Increment for one unit step taken from `state` at time `step`.
        return gen_func(state, state, step, step, step - 1, model, b, x)

    kt_1, et_1, _ = unit_slope(xt, t_start)
    kt_2, _, _ = unit_slope(xt + kt_1 * half_span, t_mid)
    kt_3, _, _ = unit_slope(xt + kt_2 * half_span, t_mid)
    kt_4, _, _ = unit_slope(xt + kt_3 * full_span, t_end)

    x_delta = (1 / 6) * (kt_1 + 2 * kt_2 + 2 * kt_3 + kt_4)
    et_list.append(kt_1)
    w_et_list.append(et_1)

    return kt_1, x_delta

def gen_xt_next_4(xt, i, j, model, b, x, et_list, w_et_list):
    """Advance one step with the solver registered under ``'F-PNDM'``.

    While ``et_list`` holds at most two entries, the noise estimate comes from
    ``runge_kutta`` (which is handed both history lists). Afterwards ``model``
    is evaluated once at timestep ``i``, the result is appended to
    ``et_list``, and the estimate used for the update is the four-step
    Adams-Bashforth blend ``(55, -59, 37, -9) / 24`` of the newest four
    entries.

    Args:
        xt: Current state tensor.
        i: Current timestep.
        j: Target timestep.
        model: Callable ``model(xt, t)`` returning the noise estimate.
        b: Beta tensor passed to ``compute_alpha``.
        x: Reference tensor; only its batch size and device are read here.
        et_list: Mutable history of noise estimates; appended to in place.
        w_et_list: Mutable history of weighted noise estimates; appended to in
            place in the multistep branch. It is not read here.

    Returns:
        The state at timestep ``j``.

    Side effects:
        In the multistep branch, ``int(i - j)`` intermediate states (targets
        ``i, i - 1, ...``) are appended to the module-level ``curve`` list as
        NumPy arrays.
    """
    i_list = [max(i - (i - j) * k / 2, 0) for k in range(8)]
    n = x.size(0)
    t = (torch.ones(n) * i).to(x.device)
    next_t = (torch.ones(n) * j).to(x.device)
    at = compute_alpha(b, t.long())
    at_next = compute_alpha(b, next_t.long())

    def delta_to(a_target, eps):
        # Increment that carries xt from alpha level `at` to `a_target`
        # given the noise estimate `eps`.
        return (a_target - at) * (
            (1 / (at.sqrt() * (at.sqrt() + a_target.sqrt()))) * xt
            - 1 / (at.sqrt() * (((1 - a_target) * at).sqrt()
                                + ((1 - at) * a_target).sqrt())) * eps)

    if len(et_list) > 2:
        # Linear multistep phase: one model call per step.
        et_ = model(xt, t)
        w_et_ = (at_next - at) / (at.sqrt() * (((1 - at_next) * at).sqrt()
                                               + ((1 - at) * at_next).sqrt())) * et_
        et_list.append(et_)
        w_et_list.append(w_et_)
        et = (1 / 24) * (55 * et_list[-1] - 59 * et_list[-2]
                         + 37 * et_list[-3] - 9 * et_list[-4])

        # Record the trajectory at every integer timestep of this step.
        for k in range(int(i - j)):
            step_t = (torch.ones(n) * (i - k)).to(x.device)
            a_step = compute_alpha(b, step_t.long())
            curve.append((xt + delta_to(a_step, et)).cpu().numpy())
    else:
        # Warm-up phase: Runge-Kutta supplies the estimate until enough
        # history exists for the multistep formula.
        et, _ = runge_kutta(xt, i_list, model, b, x, et_list, w_et_list)

    return xt + delta_to(at_next, et)

def runge_kutta(xt, i_list, model, b, x, et_list, w_et_list):  # NOTE: flagged as problematic by the original author
    """Produce a noise estimate from four staged ``gen_func`` evaluations.

    Every stage computes the increment for the full step
    ``i_list[0] -> i_list[2]`` applied to the base state ``xt``; only the
    state and time at which ``model`` is evaluated change between stages
    (start, midpoint twice, end). The last stage receives the first three
    estimates so that ``gen_func`` returns their ``(1, 2, 2, 1) / 6``
    combination.

    The original author marked this routine as problematic without recording
    what is wrong; the numerical behaviour is kept exactly as it was.

    Args:
        xt: Current state tensor.
        i_list: Timestep grid; entries 0, 1 and 2 are read here.
        model: Callable forwarded to ``gen_func``.
        b: Beta tensor forwarded to ``gen_func``.
        x: Reference tensor forwarded to ``gen_func``.
        et_list: Mutable history; the first-stage noise estimate is appended.
        w_et_list: Mutable history; the first-stage weighted estimate is
            appended.

    Returns:
        ``(et_4, w_et_4)``: the combined noise estimate and the combined
        weighted estimate returned by the final stage.
    """
    t_start, t_mid, t_end = i_list[0], i_list[1], i_list[2]

    def stage(state, k, tri_et=None, tri_w_et=None):
        # Evaluate the model at (state, k); the increment is always taken
        # from the base state xt over the whole step t_start -> t_end.
        return gen_func(state, xt, k, t_start, t_end, model, b, x,
                        tri_et, tri_w_et)

    kt_1, et_1, w_et_1 = stage(xt, t_start)
    kt_2, et_2, w_et_2 = stage(xt + kt_1 / 2, t_mid)
    kt_3, et_3, w_et_3 = stage(xt + kt_2 / 2, t_mid)
    _, et_4, w_et_4 = stage(xt + kt_3, t_end,
                            [et_1, et_2, et_3], [w_et_1, w_et_2, w_et_3])

    et_list.append(et_1)
    w_et_list.append(w_et_1)

    return et_4, w_et_4


def gen_xt_next_2(xt, i, j, model, b, x, et_list, w_et_list):
    """Advance one step with the solver registered under ``'S-PNDM'``.

    With an empty ``et_list`` the noise estimate comes from
    ``improved_eular`` and the first-stage estimate is stored as history.
    Otherwise ``model`` is evaluated once at timestep ``i`` and the estimate
    used for the update is ``0.5 * (3 * newest - previous)``.

    Args:
        xt: Current state tensor.
        i: Current timestep.
        j: Target timestep.
        model: Callable ``model(xt, t)`` returning the noise estimate.
        b: Beta tensor passed to ``compute_alpha``.
        x: Reference tensor; only its batch size and device are read here.
        et_list: Mutable history of noise estimates; appended to in place.
        w_et_list: Mutable history of weighted estimates; appended to only
            when ``et_list`` was non-empty on entry.

    Returns:
        The state at timestep ``j``.
    """
    batch = x.size(0)
    device = x.device
    t = (torch.ones(batch) * i).to(device)
    next_t = (torch.ones(batch) * j).to(device)
    at = compute_alpha(b, t.long())
    at_next = compute_alpha(b, next_t.long())

    alpha_gap = at_next - at
    root_at = at.sqrt()
    noise_denom = root_at * (((1 - at_next) * at).sqrt()
                             + ((1 - at) * at_next).sqrt())

    if not et_list:
        # No history yet: bootstrap with the two-stage predictor.
        et, first_stage_et = improved_eular(xt, i, j, model, b, x)
        et_list.append(first_stage_et)
    else:
        current_et = model(xt, t)
        weighted_et = alpha_gap / noise_denom * current_et
        et_list.append(current_et)
        w_et_list.append(weighted_et)
        et = 0.5 * (3 * et_list[-1] - et_list[-2])

    state_coef = 1 / (root_at * (root_at + at_next.sqrt()))
    increment = alpha_gap * (state_coef * xt - 1 / noise_denom * et)
    return xt + increment


def improved_eular(xt, i, j, model, b, x):
    """Average the noise estimates at the start and at a trial end state.

    The first ``gen_func`` call evaluates the model at ``(xt, i)`` and yields
    the increment for the step ``i -> j``. The second call evaluates the
    model at the trial state ``xt + increment`` and time ``j``.

    Args:
        xt: Current state tensor.
        i: Current timestep.
        j: Target timestep.
        model: Callable forwarded to ``gen_func``.
        b: Beta tensor forwarded to ``gen_func``.
        x: Reference tensor forwarded to ``gen_func``.

    Returns:
        ``(et, et_1)``: the mean of the two noise estimates and the first
        estimate on its own.
    """
    trial_step, start_et, _ = gen_func(xt, xt, i, i, j, model, b, x)
    trial_state = xt + trial_step
    _, end_et, _ = gen_func(trial_state, xt, j, i, j, model, b, x)
    return (start_et + end_et) / 2, start_et


def gen_xt_next_1(xt, i, j, model, b, x, et_list, w_et_list):
    """Advance one step with the solver registered under ``'DDIM'``.

    ``model`` is evaluated once at timestep ``i`` and its output is used
    directly as the noise estimate for the step ``i -> j``.

    Args:
        xt: Current state tensor.
        i: Current timestep.
        j: Target timestep.
        model: Callable ``model(xt, t)`` returning the noise estimate.
        b: Beta tensor passed to ``compute_alpha``.
        x: Reference tensor; only its batch size and device are read here.
        et_list: Mutable history; the noise estimate is appended. It is not
            read here.
        w_et_list: Unused; accepted so that all step functions share one
            signature.

    Returns:
        The state at timestep ``j``.

    Side effects:
        Appends ``int(i - j)`` intermediate states (targets ``i, i - 1, ...``)
        to the module-level ``curve`` list as NumPy arrays.
    """
    n = x.size(0)
    t = (torch.ones(n) * i).to(x.device)
    next_t = (torch.ones(n) * j).to(x.device)
    at = compute_alpha(b, t.long())
    at_next = compute_alpha(b, next_t.long())

    et = model(xt, t)
    et_list.append(et)

    def delta_to(a_target):
        # Increment that carries xt from alpha level `at` to `a_target`
        # using the noise estimate of this step.
        return (a_target - at) * (
            (1 / (at.sqrt() * (at.sqrt() + a_target.sqrt()))) * xt
            - 1 / (at.sqrt() * (((1 - a_target) * at).sqrt()
                                + ((1 - at) * a_target).sqrt())) * et)

    # Record the trajectory at every integer timestep of this step.
    for k in range(int(i - j)):
        step_t = (torch.ones(n) * i - k).to(x.device)
        a_step = compute_alpha(b, step_t.long())
        curve.append((xt + delta_to(a_step)).cpu().numpy())

    return xt + delta_to(at_next)

def gen_func(xt, xt_b, k, i, j, model, b, x, tri_et=None, tri_w_et=None):
    """Wrap one evaluation of the update ("gradient") function.

    ``model`` is evaluated at state ``xt`` and time ``k``; the resulting noise
    estimate is turned into the increment that moves the base state ``xt_b``
    from timestep ``i`` to timestep ``j``.

    Args:
        xt: State at which the model is evaluated.
        xt_b: Base state the increment is computed for.
        k: Timestep handed to the model.
        i: Timestep the step starts from.
        j: Timestep the step ends at.
        model: Callable ``model(xt, t)`` returning the noise estimate.
        b: Beta tensor passed to ``compute_alpha``.
        x: Reference tensor; only its batch size and device are read here.
        tri_et: Optional sequence of three earlier noise estimates. When
            given, the estimate is replaced by the ``(1, 2, 2, 1) / 6``
            combination of those three and the new one.
        tri_w_et: Matching sequence of three earlier weighted estimates;
            indexed whenever ``tri_et`` is given.

    Returns:
        ``(x_delta, et, w_et)``: the increment for ``xt_b``, the (possibly
        combined) noise estimate, and the (possibly combined) weighted
        estimate.

    Side effects:
        When ``tri_et`` is given, ``int(i - j)`` intermediate states (targets
        ``i, i - 1, ...``) are appended to the module-level ``curve`` list as
        NumPy arrays.
    """
    batch = x.size(0)
    device = x.device

    def as_timesteps(value):
        # One float entry per batch element, placed on the device of x.
        return (torch.ones(batch) * value).to(device)

    at = compute_alpha(b, as_timesteps(i).long())
    at_next = compute_alpha(b, as_timesteps(j).long())
    root_at = at.sqrt()

    def transfer(a_target, eps):
        # Increment that carries xt_b from alpha level `at` to `a_target`.
        return (a_target - at) * (
            (1 / (root_at * (root_at + a_target.sqrt()))) * xt_b
            - 1 / (root_at * (((1 - a_target) * at).sqrt()
                              + ((1 - at) * a_target).sqrt())) * eps)

    et = model(xt, as_timesteps(k))  # noise estimate
    w_et = (at_next - at) / (root_at * (((1 - at_next) * at).sqrt()
                                        + ((1 - at) * at_next).sqrt())) * et

    combine = tri_et is not None
    if combine:
        # Blend with the three earlier stage estimates.
        et = (1 / 6) * (tri_et[0] + 2 * tri_et[1] + 2 * tri_et[2] + et)
        w_et = (1 / 6) * (tri_w_et[0] + 2 * tri_w_et[1] + 2 * tri_w_et[2] + w_et)

    x_delta = transfer(at_next, et)

    if combine:
        # Record the trajectory at every integer timestep of this step.
        for offset in range(int(i - j)):
            a_step = compute_alpha(b, as_timesteps(i - offset).long())
            curve.append((xt_b + transfer(a_step, et)).cpu().numpy())

    return x_delta, et, w_et


def ddpm_steps(x, seq, model, b, **kwargs):
    """Run the stochastic (ancestral) sampling loop.

    Args:
        x: Initial state tensor. Its device is used for every step.
        seq: Sequence of timesteps; it is walked in reverse order. Each step
            goes from ``seq[k]`` to ``seq[k - 1]``, and the last step goes to
            ``-1``.
        model: Callable ``model(x, t)`` returning the noise estimate.
        b: Beta tensor passed to ``compute_alpha``.
        **kwargs: Accepted for signature compatibility; not read.

    Returns:
        ``(xs, x0_preds)``. ``xs`` is ``x`` followed by the sample after every
        step (each moved to CPU). ``x0_preds`` holds the clean-sample
        prediction of every step, clamped to ``[-1, 1]`` (moved to CPU).
    """
    # Timestep tensors and the working state must share a device; use the
    # input's device rather than a hard-coded 'cuda'.
    device = x.device
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        xs = [x]
        x0_preds = []
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(device)
            next_t = (torch.ones(n) * j).to(device)
            at = compute_alpha(b, t.long())
            atm1 = compute_alpha(b, next_t.long())
            beta_t = 1 - at / atm1
            xt = xs[-1].to(device)

            e = model(xt, t.float())

            # Clean-sample prediction implied by the noise estimate.
            x0_from_e = (1.0 / at).sqrt() * xt - (1.0 / at - 1).sqrt() * e
            x0_from_e = torch.clamp(x0_from_e, -1, 1)
            x0_preds.append(x0_from_e.to('cpu'))

            mean = (
                (atm1.sqrt() * beta_t) * x0_from_e
                + ((1 - beta_t).sqrt() * (1 - atm1)) * xt
            ) / (1.0 - at)

            noise = torch.randn_like(xt)
            # No noise is added for batch entries whose timestep is 0.
            mask = 1 - (t == 0).float()
            mask = mask.view(-1, 1, 1, 1)
            logvar = beta_t.log()
            sample = mean + mask * torch.exp(0.5 * logvar) * noise
            xs.append(sample.to('cpu'))
    return xs, x0_preds
