import torch
import torch.nn as nn
import numpy as np
import pickle
import gc
from tqdm import *
# from mlp_mix_switch import MlpMixer, regularize,ridge_regularize
from GC_draw_manager import draw_GC_ROC_curve
from mlp_mix_best import MlpMixer, regularize, ridge_regularize
from synthetic import simulate_var, simulate_var2, simulate_lorenz_96
import datetime
import matplotlib.pyplot as plt

# Module-level statement: printed whenever this module is imported, not only when run as a script.
print("begin")

def _sliding_windows(series, lag):
    """Stack every length-``lag`` window of a ``[time, channel]`` tensor.

    Returns a tensor of shape ``[time - lag + 1, channel, lag]`` where
    ``result[i, c, k] == series[i + k, c]``, i.e. each window is transposed so the
    channel axis comes first.  ``unfold`` returns a strided view; ``.contiguous()``
    gives the same memory layout a ``torch.stack`` of transposed slices would.
    """
    return series.unfold(0, lag, 1).contiguous()


# Train the model once.
def train(sparsity=0.2, p=10):
    """Fit one MlpMixer on simulated VAR data and return Granger-causality (GC) estimates.

    Args:
        sparsity: passed to ``simulate_var`` when generating the ground-truth data.
        p: number of time series (channels); one mixer network is trained per series.

    Returns:
        ``(GC, best_GC, best_GC_original)``:
        the ground-truth GC matrix returned by ``simulate_var``;
        ``mlp_mixer.GC(threshold=threshold)`` from the epoch with the lowest
        validation loss; and ``mlp_mixer.GC(threshold=0)`` from that same epoch,
        converted to a NumPy array.

    Training runs for at most ``max_epochs`` epochs and stops early once the
    validation loss has failed to improve for more than ``patience`` consecutive
    epochs.
    """
    # -------------------- configuration --------------------
    device = torch.device('cpu')  # use torch.device('cuda') to train on a GPU
    lag = 5                # number of past steps fed to the model per prediction
    lag_data = 3           # lag order passed to simulate_var when generating the data
    lam = 0.025            # penalty weight on the causal weights (earlier tried: 0.002, 0.00002, 0.0037)
    lam_ridge = None       # ridge penalty on the other network weights (e.g. 0.01); None disables it
    threshold = 10         # threshold argument for mlp_mixer.GC when producing best_GC
    epoch = 0              # number of epochs completed so far (used for logging)
    train_data_len = 1000  # the first train_data_len time steps are used for training
    T = 1500               # total simulated length; steps [train_data_len, T) are validation targets
    max_epochs = 300
    patience = 30

    # -------------------- data and model --------------------
    X_np, beta, GC = simulate_var(p=p, T=T, lag=lag_data, seed=0, sparsity=sparsity)
    # X_np, GC = simulate_lorenz_96(p=p, F=10, T=T)
    mlp_mixer = MlpMixer(p, lag).to(device)
    # Indexed as [time, channel] throughout this function.
    X = torch.tensor(X_np, dtype=torch.float32, device=device)

    # -------------------- build inputs --------------------
    # Training: window X[i:i+lag] predicts X[i+lag]; the last target is X[train_data_len-1].
    train_x = _sliding_windows(X[:train_data_len - 1, :], lag)  # [train_data_len-lag, p, lag]
    train_y = X[lag:train_data_len, :]
    # Validation: for each target step t in [train_data_len, T), the input is X[t-lag:t].
    val_x = _sliding_windows(X[train_data_len - lag:T - 1, :], lag)  # [T-train_data_len, p, lag]
    val_y = X[train_data_len:T, :]

    # -------------------- training --------------------
    loss_fn = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(mlp_mixer.parameters(), lr=0.001)
    best_val_loss = torch.inf
    best_GC = None
    best_GC_original = None
    best_GC_est_lag = None  # per-lag estimate; kept for the (disabled) make_lag_config analysis
    GC1 = GC1_original = None
    no_best_count = 0
    for _ in range(max_epochs):
        # Reset each block's penalty_x before this epoch's forward pass.
        for block in mlp_mixer.mixer_networks:
            block.penalty_x = torch.Tensor()

        # ---------- forward pass: one dedicated network per series ----------
        losses = [
            loss_fn(mlp_mixer.mixer_networks[j](train_x).view(-1), train_y[:, j])
            for j in range(p)
        ]

        # Causal estimates from the weights as they are before this epoch's optimizer step.
        GC1 = mlp_mixer.GC(threshold=threshold)
        GC1_original = mlp_mixer.GC(threshold=0)
        GC_est_lag1 = mlp_mixer.GC(threshold=0, ignore_lag=False)

        # ---------- loss ----------
        predict_loss = sum(losses)
        sum_loss = predict_loss.clone()

        # 1. penalty on the causal weights
        regularize_loss = None
        if lam is not None and lam > 0:
            regularize_loss = sum(regularize(net, lam, "H") for net in mlp_mixer.mixer_networks)
            sum_loss += regularize_loss
        # 2. ridge penalty on the remaining weights
        ridge_regularize_loss = None
        if lam_ridge is not None and lam_ridge > 0:
            ridge_regularize_loss = sum(ridge_regularize(net, lam_ridge)
                                        for net in mlp_mixer.mixer_networks)
            sum_loss += ridge_regularize_loss

        # Log training progress.
        print("{0}:epoch:{1},predict_loss:{2}".format(datetime.datetime.now().strftime("%H:%M:%S"), epoch, predict_loss), end=' ')
        if regularize_loss is not None:
            print("r1:{0:.5f}".format(regularize_loss), end=' ')
        if ridge_regularize_loss is not None:
            print("r2:{0:.5f}".format(ridge_regularize_loss), end=' ')
        print("sum_loss:{0:.5f}".format(sum_loss))

        optimizer.zero_grad()
        sum_loss.backward()
        optimizer.step()

        # ---------- validation (no gradients needed) ----------
        with torch.no_grad():
            val_total = sum(
                loss_fn(mlp_mixer.mixer_networks[j](val_x).view(-1), val_y[:, j])
                for j in range(p)
            ).item()
        if best_val_loss > val_total:
            best_val_loss = val_total
            best_GC = GC1
            best_GC_original = GC1_original
            best_GC_est_lag = GC_est_lag1
            no_best_count = 0
        else:
            no_best_count += 1
            if no_best_count > patience:
                # Early stop; see make_fig / make_lag_config / draw_GC_ROC_curve for analysis.
                break
        print(str(val_total))
        print("----------------------------------")

        epoch = epoch + 1

    if best_GC_original is None:
        # The validation loss never compared below +inf (e.g. it was NaN every epoch).
        # Fall back to the most recent estimates rather than failing on an unbound name.
        best_GC, best_GC_original = GC1, GC1_original
    return GC, best_GC, best_GC_original.cpu().detach().numpy()

# test_best_sparsity_score is defined once, immediately below. A duplicate definition
# here was dead code: the later def rebinds the same name, and train(sparsity/10) is
# the same call as train(sparsity=sparsity/10, p=10) because p defaults to 10.

def test_best_sparsity_score():
    """Sweep sparsity 0.4 ... 0.8, training 200 models per level.

    For each level, whenever a run's ROC score beats the best so far, the
    ground-truth GC matrix and the raw (threshold=0) estimate are written to
    ``sparsity=0.<k>_true.txt`` and ``sparsity=0.<k>_pre.txt``.
    """
    for tenths in range(4, 9):
        top_score = 0
        stem = "sparsity=0." + str(tenths)
        # Train 200 times and keep the best result on disk.
        for run in range(200):
            if run % 20 == 0:
                print("第" + str(run) + "轮 " + stem)
            true_gc, _, est_gc = train(sparsity=tenths / 10, p=10)
            score = draw_GC_ROC_curve(true_gc, est_gc, draw=False)
            if not (top_score < score):
                continue
            print("######" + str(run) + "score=" + str(score) + "######")
            top_score = score
            np.savetxt(stem + "_true.txt", true_gc, fmt='%.5f')
            np.savetxt(stem + "_pre.txt", est_gc, fmt='%.5f')
    print("finish")

def test_best_p_score():
    """Sweep the number of series p = 10, 15, 20, 25, training 200 models per value.

    Every run's ROC score is printed. The ground-truth GC matrix and the raw
    (threshold=0) estimate are saved to ``p=<p>_true.txt`` and ``p=<p>_pre.txt``
    only when a run beats the best score seen so far for that ``p``. The saved
    files therefore hold the best run, matching test_best_sparsity_score.
    """
    for p in range(10, 30, 5):
        best_roc_score = 0
        # Train 200 times and keep the best result on disk.
        for i in range(200):
            if i % 2 == 0:
                print("第" + str(i) + "轮 p=" + str(p))
            GC, best_GC, best_GC_original = train(sparsity=0.2, p=p)
            score = draw_GC_ROC_curve(GC, best_GC_original, draw=False)
            print("######" + str(i) + "score=" + str(score) + "######")
            # Only overwrite the saved matrices when this run improves on the best so far.
            if best_roc_score < score:
                best_roc_score = score
                np.savetxt("p=" + str(p) + "_true.txt", GC, fmt='%.5f')
                np.savetxt("p=" + str(p) + "_pre.txt", best_GC_original, fmt='%.5f')
    print("finish")
if __name__ == '__main__':
    # Script entry point: runs the p sweep. Uncomment the next line (and comment out
    # the p sweep) to run the sparsity sweep instead.
    # test_best_sparsity_score()
    test_best_p_score()