import torch
import torch.nn as nn
import numpy as np
# from mlp_mix_switch import MlpMixer, regularize,ridge_regularize
from mlp_mix_best import MlpMixer, regularize,ridge_regularize
# from mlp_mix_residual import MlpMixer, regularize,ridge_regularize
from synthetic import simulate_var, data_segmentation, simulate_lorenz_96
import datetime
from GC_draw_manager import GC_draw_manager
import matplotlib.pyplot as plt

print("begin")
# Split the series into segments up front, then train a separate network on each segment.
# Generate different kinds of two-regime test data.
def generate_data(type):
    """Simulate two regimes and stack them along the time axis.

    The ground-truth GC matrices of both regimes are written to
    ``GC1_seq_true_type=<type>.txt`` and ``GC2_seq_true_type=<type>.txt``.

    Args:
        type: which regime pair to build:
            1 -- two VAR processes with different generating lags (2 vs 3);
            2 -- two VAR processes with different causal structure
                 (seed/sparsity differ);
            3 -- a VAR process followed by a Lorenz-96 process.

    Returns:
        ``(X, GC_true)``: ``X`` is a float32 tensor on the module-level
        ``device`` holding both regimes concatenated along axis 0, and
        ``GC_true`` is the GC matrix of the FIRST regime only.

    Raises:
        ValueError: if ``type`` is not 1, 2 or 3.
    """
    if type == 1:
        # Two VAR processes with different lags.
        X_np1, _, GC1 = simulate_var(p=10, T=500, lag=2, seed=0, sparsity=0.2)
        X_np2, _, GC2 = simulate_var(p=10, T=500, lag=3, seed=1, sparsity=0.2)
    elif type == 2:
        # Two VAR processes with different causal relationships.
        X_np1, _, GC1 = simulate_var(p=10, T=500, lag=3, seed=0, sparsity=0.2)
        X_np2, _, GC2 = simulate_var(p=10, T=500, lag=3, seed=1, sparsity=0.3)
    elif type == 3:
        # VAR followed by Lorenz-96.
        X_np1, _, GC1 = simulate_var(p=10, T=500, lag=3, seed=0, sparsity=0.2)
        # F (forcing) and T (length) were previously swapped (F=500, T=10),
        # which produced only 10 samples under an implausible forcing.
        # T=500 matches the first regime's length.
        # NOTE(review): confirm F=10 is the intended forcing.
        X_np2, GC2 = simulate_lorenz_96(p=10, F=10, T=500)
    else:
        raise ValueError(f"unknown data type {type!r}; expected 1, 2 or 3")

    X_np = np.concatenate((X_np1, X_np2), axis=0)
    np.savetxt(f"GC1_seq_true_type={type}.txt", GC1, fmt='%d')
    np.savetxt(f"GC2_seq_true_type={type}.txt", GC2, fmt='%d')
    # `device` is the module-level device chosen in the settings section.
    X = torch.tensor(X_np, dtype=torch.float32, device=device)
    return X, GC1

# Turn per-target causal penalties into per-segment Granger-causality matrices.
def GC_var(penalty_sum, seg):
    """Compute a Granger-causality strength matrix for each time segment.

    Args:
        penalty_sum: sequence (or tensor) indexed by target series; element
            ``i`` is a tensor whose dim 0 indexes samples and whose dim 1
            indexes input series (e.g. ``[samples, p, lag]``). Must be
            non-empty.
        seg: number of equal-length segments to cut the sample axis into.
            If the sample count is not divisible by ``seg``, the trailing
            remainder samples are ignored.

    Returns:
        Tensor of shape ``[seg, n_targets, n_inputs]``. Entry ``[s, i, k]``
        is the Frobenius norm (over dims 0 and 2) of target ``i``'s penalty
        restricted to segment ``s`` and input ``k``.

    Raises:
        ValueError: if ``seg`` is not positive or ``penalty_sum`` is empty.
    """
    if seg <= 0:
        raise ValueError(f"seg must be a positive integer, got {seg!r}")

    per_target = []
    for penalty_t in penalty_sum:
        # Integer segment length; remainder samples at the end are dropped.
        seg_len = len(penalty_t) // seg
        segment_norms = [
            torch.norm(penalty_t[i * seg_len:(i + 1) * seg_len], dim=(0, 2))
            for i in range(seg)
        ]
        per_target.append(torch.stack(segment_norms))  # [seg, n_inputs]

    if not per_target:
        raise ValueError("penalty_sum must contain at least one target series")

    # [n_targets, seg, n_inputs] -> [seg, n_targets, n_inputs]
    return torch.stack(per_target).transpose(0, 1)

# -------------------- Settings --------------------------
# device = torch.device('cuda')
device = torch.device('cpu')
# GC is the Granger-causality graph; an entry of 1 marks a causal link.
lag = 5  # number of lagged samples each network uses for prediction
lag_data = 3  # lag used when generating the linear (VAR) data; not referenced below
p = 10  # number of series (channels)
lam = 0.00002  # penalty on the causal weights (originally 0.002); None or 0 disables it
# lam = 0.025
lam_ridge = None  # ridge penalty on the other network weights; None or 0 disables it
# lam_ridge = 0.00001
epoch = 0  # number of epochs trained so far
val_rate = 0.2  # fraction of each segment held out for validation
max_seg = 2  # number of segments; one network is trained per segment
patience = 5  # a network stops after more than this many epochs without val improvement
print("lag:{0},lam:{1},lam_ridge:{2}".format(lag, lam, lam_ridge))

# One model per segment. `.to(device)` handles "cpu", "cuda" and indexed
# devices such as "cuda:0" (a string comparison against "cuda" does not).
mlp_mixer_list = [MlpMixer(p, lag).to(device) for _ in range(max_seg)]

# -------------------- Inputs --------------------------
X, GC_true = generate_data(type=3)

# -------------------- Training setup --------------------------
loss_fn = nn.MSELoss(reduction='mean')
optimizer_list = [torch.optim.Adam(model.parameters(), lr=0.001) for model in mlp_mixer_list]

seg = max_seg
# Per-segment train/validation splits, indexed by segment n.
# NOTE(review): per earlier comments train_x[n] is [bs, p, lag] and
# train_y[n] is [bs, p]; confirm against data_segmentation.
train_x, train_y, val_x, val_y = data_segmentation(data=X, lag=lag, seg=seg, val_rate=val_rate)
# Per segment: GC matrix captured at that network's best validation epoch.
best_GC_seg = [None] * seg
train_flag = [1] * seg
best_val_loss = [torch.inf] * seg
not_best_val_loss_count = [0] * seg

while True:
    # Clear the causal penalties accumulated during the previous epoch.
    for mlp_mixer in mlp_mixer_list:
        for block in mlp_mixer.mixer_networks:
            block.penalty_x = torch.Tensor()

    # Train each segment's dedicated network on its own segment.
    for n, mlp_mixer in enumerate(mlp_mixer_list):
        # Early-stopped networks are not trained any further.
        if train_flag[n] == 0:
            continue

        # --------------- Forward pass ---------------
        losses = []
        penalties = []
        for j in range(p):
            network = mlp_mixer.mixer_networks[j]
            network_output = network(train_x[n]).view(-1)
            losses.append(loss_fn(network_output, train_y[n][:, j]))
            # penalty_x is populated by the forward pass just executed.
            penalties.append(network.penalty_x)
        # [p, ...]: this network's causal penalties on its own segment only.
        penalty_t = torch.stack(penalties)

        # --------------- Loss ---------------
        predict_loss = sum(losses)
        sum_loss = predict_loss.clone()

        # 1. Penalty on the causal weights.
        regularize_loss = None
        if lam is not None and lam > 0:
            regularize_loss = sum(regularize(net, lam, "H") for net in mlp_mixer.mixer_networks)
            sum_loss += regularize_loss
        # 2. Ridge penalty on the remaining weights.
        ridge_regularize_loss = None
        if lam_ridge is not None and lam_ridge > 0:
            ridge_regularize_loss = sum(ridge_regularize(net, lam_ridge) for net in mlp_mixer.mixer_networks)
            sum_loss += ridge_regularize_loss

        print("{0}:epoch:{1},predict_loss:{2:.5f}".format(
            datetime.datetime.now().strftime("%H:%M:%S"), epoch, predict_loss), end=' ')
        if regularize_loss is not None:
            print("r1:{0:.5f}".format(regularize_loss), end=' ')
        if ridge_regularize_loss is not None:
            print("r2:{0:.5f}".format(ridge_regularize_loss), end=' ')
        print("sum_loss:{0:.5f}".format(sum_loss))

        optimizer_list[n].zero_grad()
        sum_loss.backward()
        optimizer_list[n].step()

        # --------------- Validation loss (early stopping) ---------------
        # No gradients are needed here; avoid building an autograd graph.
        with torch.no_grad():
            val_loss = sum([
                loss_fn(mlp_mixer.mixer_networks[j](val_x[n]).view(-1), val_y[n][:, j])
                for j in range(p)
            ]).item()

        if best_val_loss[n] > val_loss:
            # Record the GC matrix of THIS network's segment at its best epoch.
            # (Previously a single shared result was computed from whichever
            # networks had trained so far this epoch, and was split into `seg`
            # pieces regardless of which networks were present.)
            best_GC_seg[n] = GC_var(penalty_t.detach(), 1)[0]
            best_val_loss[n] = val_loss
            not_best_val_loss_count[n] = 0
        else:
            not_best_val_loss_count[n] += 1
            if not_best_val_loss_count[n] > patience:
                train_flag[n] = 0
        print("val_loss=" + str(val_loss))

    # Every network has stopped improving on its validation set: save and finish.
    # (The GC_draw_manager plotting that used to follow exit() here was
    # unreachable and has been removed.)
    if not any(train_flag):
        missing = [i for i, gc in enumerate(best_GC_seg) if gc is None]
        if missing:
            raise RuntimeError(
                "no validation improvement recorded for segment(s) {0}; "
                "check for NaN losses".format(missing))
        array = torch.stack(best_GC_seg).cpu().numpy()  # [seg, p, n_inputs]
        with open("GC_seg_pre_sep.txt", 'w') as outfile:
            for slice_2d in array:
                np.savetxt(outfile, slice_2d, fmt='%f', delimiter=',')
        break
    print("----------------------------------")

    epoch = epoch + 1



