import torch
import torch.nn as nn
import numpy as np
# from mlp_mix_switch import MlpMixer, regularize,ridge_regularize
from mlp_mix_best import MlpMixer, regularize,ridge_regularize
# from mlp_mix_residual import MlpMixer, regularize,ridge_regularize
from synthetic import simulate_var, data_segmentation, simulate_lorenz_96
import datetime
from GC_draw_manager import GC_draw_manager
import matplotlib.pyplot as plt

# Start-up marker: printed as soon as the imports above have completed.
print("begin")


def _draw_lag_panel(ax, image, title, n_series, n_lags):
    """Render one lag-by-series heat map on *ax* with centred tick labels."""
    ax.imshow(image, cmap='Blues', extent=(0, n_series, n_lags, 0))
    ax.set_title(title)
    ax.set_ylabel('Lag')
    ax.set_xlabel('Series')
    # Ticks sit in the middle of each cell (hence the +0.5 offset).
    ax.set_xticks(np.arange(n_series) + 0.5)
    ax.set_xticklabels(range(n_series))
    ax.set_yticks(np.arange(n_lags) + 0.5)
    ax.set_yticklabels(range(1, n_lags + 1))
    ax.tick_params(axis='both', length=0)


def make_lag_config(GC, GC_est_lag_list, n_lags=5, true_lag=3):
    """Save, for every target series, a true-vs-estimated lag figure.

    For each target series ``i`` a two-panel figure is written to
    ``img/GC_lag3_<i>.png``: the left panel shows the ground-truth pattern
    (the first ``true_lag`` lags are marked for every series ``j`` with a
    non-zero ``GC[i][j]``), the right panel shows the estimated per-lag
    values taken from ``GC_est_lag_list[i]``.

    Args:
        GC: ground-truth causality matrix (NumPy array indexed ``GC[i]``),
            one row per target series.
        GC_est_lag_list: indexable collection of torch tensors, one per target
            series. Each tensor is transposed and flipped along the lag axis
            before plotting -- presumably shaped (series, lag); confirm
            against the caller.
        n_lags: number of lag rows drawn in both panels (default 5, the value
            that used to be hard-coded).
        true_lag: number of lags marked as causal in the ground-truth panel
            (default 3, the value that used to be hard-coded).
    """
    import os  # local import: only needed to make sure the output folder exists

    n_series = len(GC)
    os.makedirs('img', exist_ok=True)

    for i in range(n_series):
        # Ground truth: the first `true_lag` lags of every causal parent are set.
        GC_lag = np.zeros((n_lags, n_series))
        GC_lag[:true_lag, GC[i].astype(bool)] = 1.0

        # Estimate: transpose, then reverse the first (lag) axis.
        GC_est_lag = GC_est_lag_list[i].detach().cpu().numpy().T[::-1]

        fig, axarr = plt.subplots(1, 2, figsize=(16, 5))
        _draw_lag_panel(axarr[0], GC_lag,
                        'Series %d true GC' % (i + 1), n_series, n_lags)
        _draw_lag_panel(axarr[1], GC_est_lag,
                        'Series %d estimated GC' % (i + 1), n_series, n_lags)

        # The original name lacked the '.' before the extension, which made
        # matplotlib write "GC_lag3_<i>png.png".
        fig.savefig('img/GC_lag3_' + str(i) + '.png')
        # Close the figure so one figure per series is not kept alive.
        plt.close(fig)
# Generate the different kinds of synthetic data
def generate_data(type):
    """Build a two-regime synthetic series and save both ground-truth graphs.

    Two simulated series of 10 channels and 500 steps each are generated and
    concatenated along axis 0. Both ground-truth matrices are written to
    ``GC1_seq_true_type=<type>.txt`` and ``GC2_seq_true_type=<type>.txt``.

    Args:
        type: scenario selector (the name shadows the builtin but is kept
            because callers pass it by keyword):
            1 -- two VAR processes with different lags (2, then 3);
            2 -- two VAR processes with different seed/sparsity (0.2, then 0.3);
            3 -- a VAR process followed by a Lorenz-96 system.

    Returns:
        ``(X, GC_true)`` where ``X`` is a float32 tensor created on the
        module-level ``device`` and ``GC_true`` is the ground truth of the
        FIRST regime only (the second one is only written to disk).

    Raises:
        ValueError: if ``type`` is not 1, 2 or 3. Previously this fell through
            every branch and failed later with an ``UnboundLocalError``.
    """
    if type not in (1, 2, 3):
        raise ValueError(
            "unsupported data type {0!r}; expected 1, 2 or 3".format(type))

    if type == 1:
        # Two regimes that differ in the generating lag.
        X_np1, _, GC1 = simulate_var(p=10, T=500, lag=2, seed=0, sparsity=0.2)
        X_np2, _, GC2 = simulate_var(p=10, T=500, lag=3, seed=1, sparsity=0.2)
    elif type == 2:
        # Two VAR regimes generated with different seed and sparsity.
        X_np1, _, GC1 = simulate_var(p=10, T=500, lag=3, seed=0, sparsity=0.2)
        X_np2, _, GC2 = simulate_var(p=10, T=500, lag=3, seed=1, sparsity=0.3)
    else:
        # VAR regime followed by a Lorenz-96 regime.
        X_np1, _, GC1 = simulate_var(p=10, T=500, lag=3, seed=0, sparsity=0.2)
        X_np2, GC2 = simulate_lorenz_96(p=10, T=500, F=10)

    X_np = np.concatenate((X_np1, X_np2), axis=0)
    np.savetxt("GC1_seq_true_type=%d.txt" % type, GC1, fmt='%d')
    np.savetxt("GC2_seq_true_type=%d.txt" % type, GC2, fmt='%d')

    X = torch.tensor(X_np, dtype=torch.float32, device=device)
    GC_true = GC1
    return X, GC_true

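# Version 3: once training has plateaued (the loss no longer decreases) the network is duplicated; includes a bug fix. (The original note is truncated at this point.)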
def GC_var(penalty_sum, seg):
    """Collapse per-sample penalties into one causality vector per segment.

    Every element of ``penalty_sum`` is cut along its first axis into ``seg``
    equally sized consecutive chunks; for each chunk the norm over axes 0 and 2
    is taken, leaving one value per index of axis 1.

    Args:
        penalty_sum: iterable of 3-D tensors, one per target series
            (the original note gives [100, 10, 5] as an example shape --
            presumably (samples, series, lag); confirm against the caller).
        seg: number of consecutive chunks. If the first axis is not divisible
            by ``seg`` the trailing remainder rows are ignored.

    Returns:
        Tensor indexed ``[segment, target, source]`` (the per-target results
        stacked and then transposed on the first two axes).
    """
    per_target = []
    for penalty_t in penalty_sum:
        # Integer chunk length; same value as the former int(T / seg).
        chunk = len(penalty_t) // seg
        segment_norms = [
            torch.norm(penalty_t[i * chunk:(i + 1) * chunk], dim=(0, 2))
            for i in range(seg)
        ]
        per_target.append(torch.stack(segment_norms))

    # Stack once instead of growing an empty tensor with torch.cat per step.
    GC = torch.stack(per_target)   # [target, segment, source]
    GC_seg = GC.transpose(0, 1)    # [segment, target, source]
    return GC_seg


def GC_make_lag(penalty_sum, seg):
    """Per-segment penalty norms that keep the trailing (lag) axes.

    Every element of ``penalty_sum`` is cut along its first axis into ``seg``
    equally sized consecutive chunks and the norm over axis 0 only is taken,
    so all remaining axes of the element are preserved.

    Args:
        penalty_sum: iterable of tensors, one per target series; the first
            axis of each is the one being chunked.
        seg: number of consecutive chunks. If the first axis is not divisible
            by ``seg`` the trailing remainder rows are ignored.

    Returns:
        Tensor indexed ``[segment, target, ...]`` where ``...`` are the
        remaining axes of each element after axis 0 has been reduced.
    """
    per_target = []
    for penalty_t in penalty_sum:
        # Integer chunk length; same value as the former int(T / seg).
        chunk = len(penalty_t) // seg
        segment_norms = [
            torch.norm(penalty_t[i * chunk:(i + 1) * chunk], dim=0)
            for i in range(seg)
        ]
        per_target.append(torch.stack(segment_norms))

    # Stack once instead of growing an empty tensor with torch.cat per step.
    return torch.stack(per_target).transpose(0, 1)

# -------------------- Configuration --------------------
# device = torch.device('cuda')
device = torch.device('cpu')
# GC is the Granger-causality graph; an entry of 1 means a causal link exists.
lag = 5  # number of past steps (lags) the network uses for each prediction
lag_data = 3  # lag used to generate the data in the linear (VAR) model
p = 10  # number of series
# Penalty on the causal weights (was 0.002 originally); None disables it.
lam = 0.00002
# Ridge penalty on the network weights; None disables it (e.g. 0.00001).
lam_ridge = None
epoch = 0  # number of epochs trained so far
val_rate = 0.2   # fraction of the data held out for validation
max_seg = 2  # upper bound on the number of segments the data is split into

print("lag:{0},lam:{1},lam_ridge:{2}".format(lag, lam, lam_ridge))

# Build the initial model. `.to(device)` also handles indexed devices such as
# "cuda:0", which the former `str(device) == "cuda"` test sent to the CPU.
mlp_mixer_list = [MlpMixer(p, lag).to(device)]

# -------------------- Build the input --------------------
# Generate the data (scenario 3: VAR followed by Lorenz-96).
X, GC_true = generate_data(type=3)

# -------------------- Prepare training --------------------
# Loss function
loss_fn = nn.MSELoss(reduction='mean')
# One optimizer per model
optimizer_list = [torch.optim.Adam(model.parameters(), lr=0.001)
                  for model in mlp_mixer_list]

# Start with a single segment covering the whole series; use the configured
# `lag` so the windows always match the model's input size.
seg = 1
train_x, train_y, val_x, val_y = data_segmentation(data=X, lag=lag, seg=seg, val_rate=val_rate)
best_GC_seg = None
best_GC_lag = None

# Per-segment training state.
train_flag = [1]*seg                 # 1 while the segment's model is still training
best_val_loss = [torch.inf]*seg      # lowest validation loss seen so far
not_best_val_loss_count = [0]*seg    # epochs since the last improvement
penalty_sum = [torch.Tensor()]*seg   # latest penalty tensor of each segment
# Main loop: train one model per data segment with early stopping on the
# validation loss. When every model has stopped, either double the number of
# segments (each model is cloned into two children) or, once `max_seg` is
# reached, compute and save the causality results and terminate.
while True:
    # Clear the penalty tensors computed during the previous epoch.
    for mlp_mixer in mlp_mixer_list:
        for block in mlp_mixer.mixer_networks:
            block.penalty_x = torch.Tensor()
    val_loss_sum = 0

    # Each segment of the data is predicted by its own model.
    for n, mlp_mixer in enumerate(mlp_mixer_list):
        # Skip models that have already been flagged as converged.
        if train_flag[n] == 0:
            continue

        # --------------- Forward pass -------------------
        # One sub-network per target series.
        losses = []
        penalty_parts = []
        for j in range(p):
            network = mlp_mixer.mixer_networks[j]
            network_output = network(train_x[n]).view(-1)
            # Prediction loss of sub-network j against target series j.
            losses.append(loss_fn(network_output, train_y[n][:, j]))
            penalty_parts.append(network.penalty_x)

        # Snapshot this epoch's penalties (one entry per target series).
        # Detached: it is only used for the final analysis, never for gradients.
        penalty_sum[n] = torch.stack(penalty_parts).detach()

        # --------------- Loss ---------------------
        predict_loss = sum(losses)
        sum_loss = predict_loss.clone()

        # 1. Sparsity penalty on the causal weights.
        regularize_loss = None
        if lam is not None and lam > 0:
            regularize_loss = sum([regularize(mixer_networks, lam, "H")
                                   for mixer_networks in mlp_mixer.mixer_networks])
            sum_loss += regularize_loss
        # 2. Ridge penalty on the remaining network weights.
        ridge_regularize_loss = None
        if lam_ridge is not None and lam_ridge > 0:
            ridge_regularize_loss = sum([ridge_regularize(mixer_networks, lam_ridge)
                                         for mixer_networks in mlp_mixer.mixer_networks])
            sum_loss += ridge_regularize_loss

        # Training log line.
        print("{0}:epoch:{1},predict_loss:{2}".format(datetime.datetime.now().strftime("%H:%M:%S"), epoch, predict_loss), end=' ')
        if regularize_loss is not None:
            print("r1:{0:.5f}".format(regularize_loss), end=' ')
        if ridge_regularize_loss is not None:
            print("r2:{0:.5f}".format(ridge_regularize_loss), end=' ')
        print("sum_loss:{0:.5f}".format(sum_loss))

        # 3. Zero the gradients, 4. back-propagate, 5. update the parameters.
        optimizer_list[n].zero_grad()
        sum_loss.backward()
        optimizer_list[n].step()

        # --------------- Validation loss (guards against over-fitting) ------
        # NOTE(review): deliberately not wrapped in torch.no_grad() because it
        # is not visible here whether the MlpMixer forward pass needs autograd.
        val_loss = []
        for j in range(p):
            network_output = mlp_mixer.mixer_networks[j](val_x[n]).view(-1)
            val_loss.append(loss_fn(network_output, val_y[n][:, j]))
        val_loss_n = sum(val_loss).item()
        val_loss_sum += val_loss_n

        # Early stopping: keep training while the validation loss improves.
        if best_val_loss[n] > val_loss_n:
            best_val_loss[n] = val_loss_n
            not_best_val_loss_count[n] = 0
        else:
            not_best_val_loss_count[n] += 1
            # Stop this model after more than 10 epochs without improvement.
            if not_best_val_loss_count[n] > 10:
                train_flag[n] = 0

        print("val_loss=" + str(val_loss_n))

    # All models have converged.
    if sum(train_flag) == 0:

        # Maximum number of segments reached: compute and save final results.
        # `>=` (not `==`) so a max_seg that is not a power of two still stops.
        if seg >= max_seg:
            # Merge the per-segment snapshots into one tensor per target
            # series whose first axis concatenates the segments in order.
            penalty_sum = torch.stack(penalty_sum)
            penalty_sum = penalty_sum.transpose(1, 2)
            penalty_sum = penalty_sum.reshape(
                penalty_sum.shape[0] * penalty_sum.shape[1],
                penalty_sum.shape[2],
                penalty_sum.shape[3],
                penalty_sum.shape[4])
            penalty_sum = penalty_sum.transpose(0, 1)
            best_GC_seg = GC_var(penalty_sum, seg)

            best_GC_lag = GC_make_lag(penalty_sum, seg)
            make_lag_config(GC_true, best_GC_lag[0])    # lag figures of segment 0

            array = best_GC_seg.cpu().detach().numpy()
            with open("GC_seg_pre_re.txt", 'w') as outfile:
                for slice_2d in array:
                    np.savetxt(outfile, slice_2d, fmt='%f', delimiter=',')

            # Same effect as exit(), without relying on the `site` builtin.
            raise SystemExit
        else:
            # Otherwise split every segment in two and keep training.
            seg = seg * 2
            mlp_mixer_list_ = []
            for mlp_mixer in mlp_mixer_list:
                # Each pair of children inherits the weights of ITS parent
                # (previously every child was a copy of model 0).
                state = mlp_mixer.state_dict()
                for _ in range(2):
                    child = MlpMixer(p, lag).to(device)
                    child.load_state_dict(state)
                    mlp_mixer_list_.append(child)

            mlp_mixer_list = mlp_mixer_list_
            train_x, train_y, val_x, val_y = data_segmentation(data=X, lag=lag, seg=seg, val_rate=val_rate)
            # Fresh optimizers for the new models.
            optimizer_list = []
            for model in mlp_mixer_list:
                optimizer_list.append(torch.optim.Adam(model.parameters(), lr=0.001))
            # Reset the per-segment training state.
            train_flag = [1] * seg
            best_val_loss = [torch.inf] * seg
            not_best_val_loss_count = [0] * seg
            penalty_sum = [torch.Tensor()] * seg

    print("----------------------------------")

    epoch = epoch + 1

