import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from  sklearn import metrics
from sklearn.cluster import KMeans, SpectralClustering
import pandas as pd
import networkx as nx
import seaborn as sns
from brokenaxes import brokenaxes
from sklearn import preprocessing
import matplotlib
matplotlib.rc("font", family='Microsoft YaHei', weight='bold')
class StandardScaler():
    """Z-score scaler driven by a caller-supplied mean and standard deviation.

    Nothing is fitted: ``mean`` and ``std`` are stored as given and used
    directly by :meth:`transform` and :meth:`inverse_transform`.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, data):
        """Return ``data`` with ``mean`` subtracted, then divided by ``std``."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo :meth:`transform`: multiply by ``std``, then add ``mean``."""
        rescaled = data * self.std
        return rescaled + self.mean

def draw_sensor(k_data, sensor_num):
    """Save one tick-less line plot per sensor under ``./visualize/temp``.

    For every ``i`` in ``range(sensor_num)`` the values ``k_data[24:524, i]``
    are flattened and drawn as a black line, then written to
    ``./visualize/temp/sensor of<i>``.

    Args:
        k_data: Array indexed as ``k_data[row, sensor]``.
        sensor_num: Number of leading sensor columns to draw.
    """
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs('./visualize/temp', exist_ok=True)
    for i in range(sensor_num):
        fig = plt.figure(i, figsize=(45, 5), dpi=300)
        plt.plot(k_data[24:524, i].reshape(-1), c='k')
        plt.xticks([])  # hide x-axis ticks
        plt.yticks([])  # hide y-axis ticks
        plt.savefig('./visualize/temp/sensor of' + str(i))
        # pyplot keeps every figure alive until it is closed; release it so
        # memory does not grow with sensor_num and so figure number `i` is
        # not reused, already populated, by later plotting code.
        plt.close(fig)

def _display_sensor(sensors, position):
    """Return the sensor shown at ``position``; 25, 23 and 11 are replaced by the cluster's last sensor."""
    sensor = sensors[position]
    if sensor in (25, 23, 11):
        return sensors[-1]
    return sensor


def draw_cluster(k_data, cluster_scale=5, predefined_A=None):
    """Cluster sensors with spectral clustering and save per-cluster figures.

    Sensors are grouped by ``SpectralClustering(affinity='precomputed')``
    fitted on ``predefined_A``. For every non-empty cluster ``i`` two images
    are written under ``./visualize/temp``:

    * ``Adjacent matrix of cluster of <i>``: heat map of the symmetrised
      matrix (element-wise maximum of ``predefined_A`` and its transpose)
      restricted to the displayed sensors.
    * ``cluster of<i>``: a 2x4 grid whose first seven panels hold the
      displayed sensors' series ``k_data[24:324, sensor]`` and whose last
      panel holds the mean series of all sensors in the cluster.

    At most seven sensors per cluster are displayed. Sensors 25, 23 and 11
    are replaced by the cluster's last sensor in both figures.

    Args:
        k_data: Array indexed as ``k_data[row, sensor]``.
        cluster_scale: Number of clusters to form.
        predefined_A: Square precomputed affinity matrix (must support
            ``.T`` and integer-list indexing).

    Raises:
        ValueError: If ``predefined_A`` is ``None``.
    """
    if predefined_A is None:
        raise ValueError("predefined_A (precomputed affinity matrix) is required")

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs('./visualize/temp', exist_ok=True)
    max_shown = 7  # the 2x4 grid keeps its 8th panel for the cluster mean

    y_pred = SpectralClustering(n_clusters=cluster_scale, affinity='precomputed', random_state=0).fit_predict(predefined_A)

    # Group sensor indices by their assigned cluster label.
    sensor_list = [[] for _ in range(cluster_scale)]
    for sensor_idx, cluster_idx in enumerate(y_pred):
        sensor_list[cluster_idx].append(sensor_idx)

    # Loop-invariant: symmetrise the affinity matrix once.
    adj_mx = np.maximum.reduce([predefined_A, predefined_A.T])

    for i, sensors in enumerate(sensor_list):
        if not sensors:
            # Nothing to average or draw for an empty cluster.
            continue

        # Mean series of the cluster (sequential sum divided by member count).
        center = sum(k_data[24:324, s].reshape(-1) for s in sensors) / len(sensors)

        # The same displayed sensors feed both the heat map and the panels.
        shown = [_display_sensor(sensors, p) for p in range(min(max_shown, len(sensors)))]

        # --- heat map of the adjacency sub-matrix -------------------------
        heat_fig = plt.figure(1600 + i, dpi=300)
        A_map = adj_mx[shown, :]
        A_map = A_map[:, shown]
        sns.heatmap(A_map, label='A', xticklabels=shown, yticklabels=shown, cmap='YlGnBu', cbar=False, linecolor='k')
        plt.title("Adjacent matrix of cluster " + str(i), fontsize=15)
        plt.xticks(fontsize=15)
        plt.yticks(fontsize=15)
        plt.savefig('./visualize/temp/Adjacent matrix of cluster of ' + str(i), dpi=300, bbox_inches='tight')
        plt.close(heat_fig)

        # --- 2x4 grid: member series plus the cluster mean ----------------
        grid_fig = plt.figure(i, figsize=(45, 5), dpi=300)
        for j in range(1, 8):
            plt.subplot(2, 4, j)  # panel is created even when left empty
            if j > len(shown):
                continue
            sensor_id = shown[j - 1]
            plt.plot(k_data[24:324, sensor_id].T, c='k', linewidth=4.0)
            plt.title("sensor" + str(sensor_id), fontsize=45)
            plt.xticks([])  # hide x-axis ticks
            plt.yticks([])  # hide y-axis ticks

        plt.subplot(2, 4, 8)
        plt.plot(center, c='k', linewidth=4.0)
        plt.title("representation of cluster" + str(i), fontsize=45)
        plt.xticks([])  # hide x-axis ticks
        plt.yticks([])  # hide y-axis ticks
        plt.subplots_adjust(hspace=0.5)

        # Save only after the grid is complete so the file named after
        # cluster i really contains cluster i.
        plt.savefig('./visualize/temp/cluster of' + str(i), dpi=300, bbox_inches='tight')
        plt.close(grid_fig)


def draw_score(k_data, predefined_A, cluster_size=20):
    """Plot the Calinski-Harabasz score for each cluster count in ``range(2, cluster_size)``.

    For every candidate ``k`` the sensors are clustered with
    ``SpectralClustering(affinity='precomputed')`` on ``predefined_A`` and the
    labels are scored against ``k_data.T``. The scores are shown as a line
    chart via ``plt.show()``.
    """
    print(k_data.shape)
    candidate_ks = range(2, cluster_size)

    def _score(n_clusters):
        # Cluster on the precomputed affinity, score on the transposed data.
        labels = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', random_state=0).fit_predict(predefined_A)
        return metrics.calinski_harabasz_score(k_data.T, labels)

    scores = [_score(n_clusters) for n_clusters in candidate_ks]

    # Draw the score curve.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(candidate_ks, scores, '*-', linewidth=1)
    ax.set_xlabel('k')
    ax.set_ylabel("calinski_harabasz_score")
    ax.set_title('score evaluate')
    plt.show()

def draw_results(y_list, gap=3, draw_length=432, channel_dim=7):
    """Plot ground truth against three forecasts for channel 6 over one window.

    ``y_list`` supplies, in order: ground truth, our prediction, the
    Transformer prediction and the Transformer+HVAE prediction. Each is
    indexed as ``y[row, channel]``. The window covers rows
    ``draw_length * gap`` up to ``draw_length * gap + draw_length``. Nothing
    is drawn unless channel 6 lies inside ``range(channel_dim)``.
    """
    truth, ours, transformer, hvae = (y_list[k] for k in range(4))

    target = 6  # the only channel that gets plotted
    if target not in range(channel_dim):
        return

    x = np.linspace(0, draw_length, draw_length)
    start = draw_length * gap
    window = slice(start, start + draw_length)

    # Slice every series first, then draw them in legend order.
    curves = [
        (truth[window, target], {'label': 'ground truth', 'c': 'r', 'linewidth': 4}),
        (ours[window, target], {'label': 'ours', 'c': 'black', 'linewidth': 4}),
        (transformer[window, target], {'label': 'Transformer', 'linewidth': 4}),
        (hvae[window, target], {'label': 'Transformer+HVAE', 'linewidth': 4}),
    ]

    plt.figure(target, figsize=(12, 6))
    plt.xlabel('time', fontsize=54)
    plt.ylabel('value', fontsize=54)
    plt.grid(True, linestyle="--", alpha=1)
    plt.xticks(fontsize=54)
    plt.yticks(fontsize=54)
    for values, style in curves:
        plt.plot(x, values.T, **style)
    plt.legend(fontsize=30)
    plt.show()


from scipy import io

if __name__ == "__main__":
    # Script entry point. Only the ETTh2 grouped bar chart below is active;
    # the commented-out sections are alternative figures kept for reference.

    # --- Plot the input data ---
    # df_raw = pd.read_csv('./data/ETT/ETTh2.csv')
    # cols = list(df_raw.columns)
    # cols.remove('date')
    # df_raw = df_raw[cols]
    # length = 900
    # x = np.linspace(0, length, length)
    #
    # y_true_plot = df_raw.values[1500:1500+length, 6]
    # plt.figure(figsize=(100, 5))
    # plt.grid(True, linestyle="--", alpha=1)
    # plt.xticks([])  # hide x-axis ticks
    # plt.yticks([])  # hide y-axis ticks
    # plt.plot(x, y_true_plot, c='k', linewidth=4)
    # plt.show()

    # --- Plot the prediction results ---
    # y_true = np.load('./results/ground_truth.npy')
    # y_ours = np.load('./results/our_sota.npy')
    # y_trans = np.load('./results/Transformer.npy')
    # y_vae = np.load('./results/HVAE.npy')
    # # y_pm = np.load('./results/pm-memnet.npy')
    #
    #
    # y_list = [y_true, y_ours, y_trans, y_vae]
    # for k in range(20):
    #     draw_results(y_list, gap=k, channel_dim=7) #5
    # io.savemat('./results/ground_truth.mat', {'results': y_true})
    # io.savemat('./results/our_sota.mat', {'results': y_ours})
    # io.savemat('./results/mtgnn.mat', {'results': y_mtgnn})
    # io.savemat('./results/graphwavenet.mat', {'results': y_gwn})
    # io.savemat('./results/pm-memnet.mat', {'results': y_pm})


    # --- Gamma line chart ---
    # plt.figure(figsize=(8, 6))
    # gamma = ['0', '0.25', '0.5', '0.75', '1']
    # n4 = [0.483, 0.474, 0.476, 0.477, 0.480]
    # n3 = [0.441, 0.436, 0.440, 0.444, 0.442]
    # n2 = [0.410, 0.404, 0.405, 0.407, 0.408]
    # n1 = [0.372, 0.371, 0.366, 0.371, 0.373]
    # # n3 = [14.26, 14.10, 14.32, 14.50, 15.50, 17.78]
    # # n2 = [14.37, 14.25, 14.42, 14.65, 15.67, 17.94]
    # # n1 = [14.57, 14.42, 14.63, 14.68, 15.77, 18.02]
    # plt.xlabel('Gamma', fontsize=18)
    # plt.ylabel('MSE', fontsize=18)
    # plt.xticks(fontsize=18)
    # plt.yticks(fontsize=18)
    # plt.ylim(0.36, 0.49)
    # plt.grid(True, linestyle="--", alpha=1)
    # plt.plot(gamma, n1, label='predict_len 96', c='r', marker='o', markersize=8, linewidth=2.5)
    # plt.plot(gamma, n2, label='predict_len 192', c='b', marker='x', markersize=8, linewidth=2.5)
    # plt.plot(gamma, n3, label='predict_len 336', c='orange', marker='v', markersize=8, linewidth=2.5)
    # # plt.plot(gamma, n4, label='predict_len 720', c='g', marker='*', markersize=15, linewidth=4.0)
    # # plt.plot(gamma, y2, label='NAVER', linestyle='--', c='black', marker='x')
    # plt.legend(fontsize=18)
    # plt.show()


    # --- N1 line chart ---
    # plt.figure(figsize=(8, 6))
    # ax = plt.gca()
    # gamma = ['32', '64', '128', '256', '512']
    #
    # n3 = [0.452, 0.448, 0.444, 0.443, 0.446]
    # n2 = [0.416, 0.412, 0.411, 0.410, 0.412]
    # n1 = [0.382, 0.376, 0.374, 0.372, 0.374]
    #
    # # n3 = [14.29, 14.26, 14.25, 14.30, 14.50]
    # # n2 = [14.48, 14.44, 14.37, 14.43, 14.62]
    # # n1 = [14.58, 14.51, 14.49, 14.62, 14.73]
    # plt.xlabel('knowledge memory size', fontsize=18)
    # plt.ylabel('MSE', fontsize=18)
    # plt.xticks(fontsize=18)
    # plt.yticks(fontsize=18)
    # plt.grid(True, linestyle="--", alpha=1)
    # plt.ylim(0.365, 0.495)
    # # plt.plot(gamma, y1, label='METR-LA', c='r',marker='o')
    # plt.plot(gamma, n1, label='predict_len 96', c='r', marker='o', markersize=8, linewidth=2.5)
    # plt.plot(gamma, n2, label='predict_len 192', c='b', marker='x', markersize=8, linewidth=2.5)
    # plt.plot(gamma, n3, label='predict_len 336', c='orange', marker='v', markersize=8, linewidth=2.5)
    # plt.legend(fontsize=18)
    # plt.show()

    # --- Layer line chart ---
    # plt.figure(figsize=(8, 6))
    # gamma = ['0', '7', '14', '21', '28']
    # n3 = [0.484, 0.476, 0.474, 0.473, 0.471]
    # n2 = [0.398, 0.392, 0.390, 0.390, 0.390]
    # n1 = [0.294, 0.289, 0.288, 0.287, 0.288]
    # plt.xlabel('episodic memory size', fontsize=18)
    # plt.ylabel('MSE', fontsize=18)
    # plt.xticks(fontsize=18)
    # plt.yticks(fontsize=18)
    # plt.grid(True, linestyle="--", alpha=1)
    # plt.ylim(0.20, 0.495)
    # # plt.plot(gamma, y1, label='METR-LA', c='r',marker='o')
    # plt.plot(gamma, n1, label='predict_len 96', c='r', marker='o', markersize=8, linewidth=2.5)
    # plt.plot(gamma, n2, label='predict_len 192', c='b', marker='x', markersize=8, linewidth=2.5)
    # plt.plot(gamma, n3, label='predict_len 336', c='orange', marker='v', markersize=8, linewidth=2.5)
    # plt.legend(fontsize=18)
    # plt.show()

    # --- Bar chart, ETTm2 (three models) ---
    # plt.figure(figsize=(8, 6))
    # gamma = ['96', '192', '336', '720']
    # x_tick = np.arange(4)
    # n3 = [0.360, 0.456, 0.700, 1.390]
    # n2 = [0.338, 0.363, 0.453, 0.582]
    # n1 = [0.257, 0.302, 0.343, 0.407]
    # plt.xticks(x_tick + 0.30, gamma)
    # plt.xlabel('predict length', fontsize=18)
    # plt.ylabel('MAE', fontsize=18)
    # plt.xticks(fontsize=18)
    # plt.yticks(fontsize=18)
    #
    # # plt.plot(gamma, y1, label='METR-LA', c='r',marker='o')
    #
    # plt.bar(x_tick+0.0, n1, label='SPM-Net', width=0.3, edgecolor='#33a3dc', color='#f6f5ec', hatch='/', linewidth=2)
    # plt.bar(x_tick+0.3, n2, label='Linear+GCN', width=0.3, edgecolor='#f26522', color='#f6f5ec', hatch='-', linewidth=2)
    # plt.bar(x_tick+0.6, n3, label='Linear+CNN', width=0.3, edgecolor='#FFB90F', color='#f6f5ec', hatch='+', linewidth=2)
    # # for x1, y1 in enumerate(n1):
    # #     plt.text(x1, y1, y1, ha='center', fontsize=16)
    # # for x2, y2 in enumerate(n2):
    # #     plt.text(x2 + 0.3, y2 , y2, ha='center', fontsize=16)
    # # for x3, y3 in enumerate(n3):
    # #     plt.text(x3 + 0.6, y3 , y3, ha='center', fontsize=16)
    # plt.grid(linestyle='--', linewidth=0.3, color='gray', alpha=0.7)
    # plt.legend(fontsize=18, loc="best")
    # plt.savefig('./bar_example.pdf', bbox_inches='tight', dpi=1200)
    # plt.show()

    # --- Bar chart, ETTm2 (Transformer vs. Transformer+ours) ---
    # plt.figure(figsize=(8, 6))
    # gamma = ['96', '192', '336', '720']
    # x_tick = np.arange(4)
    # n2 = [1.160, 1.492, 1.991, 2.119]
    # n1 = [0.637, 1.352, 1.661, 1.920]
    # plt.xticks(x_tick + 0.30, gamma)
    # plt.xlabel('predict length', fontsize=18)
    # plt.ylabel('MAE', fontsize=18)
    # plt.xticks(fontsize=18)
    # plt.yticks(fontsize=18)
    #
    # # plt.plot(gamma, y1, label='METR-LA', c='r',marker='o')
    #
    # plt.bar(x_tick+0.15, n1, label='Transformer+ours', width=0.3, edgecolor='#009ad6', color='white', hatch='/', linewidth=2)
    # plt.bar(x_tick+0.45, n2, label='Transformer', width=0.3, edgecolor='#00ae9d', color='white', hatch='-', linewidth=2)
    # plt.grid(linestyle='--', linewidth=0.3, color='gray', alpha=0.7)
    # plt.legend(fontsize=18, loc="best")
    # plt.savefig('./bar_example.pdf', bbox_inches='tight', dpi=1200)
    # plt.show()

    # --- Bar chart, ETTh2 (three models) -- ACTIVE ---
    horizons = ['96', '192', '336', '720']
    positions = np.arange(4)
    # One entry per model: (x offset, legend label, MAE per horizon, edge colour, hatch).
    bar_series = [
        (0.0, 'SPM-Net', [0.351, 0.410, 0.454, 0.560], '#33a3dc', '/'),
        (0.3, 'Linear+GCN', [0.451, 0.591, 0.627, 0.697], '#f26522', '-'),
        (0.6, 'Linear+CNN', [0.609, 1.15, 1.08, 1.], '#FFB90F', '+'),
    ]

    plt.figure(figsize=(8, 6))
    # Tick labels sit under the middle bar of each group.
    plt.xticks(positions + 0.30, horizons)
    plt.xlabel('predict length', fontsize=18)
    plt.ylabel('MAE', fontsize=18)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)

    # plt.plot(gamma, y1, label='METR-LA', c='r',marker='o')

    for offset, model_name, mae_values, edge_colour, hatch_style in bar_series:
        plt.bar(positions + offset, mae_values, label=model_name, width=0.3,
                edgecolor=edge_colour, color='#f6f5ec', hatch=hatch_style, linewidth=2)
    # for x1, y1 in enumerate(n1):
    #     plt.text(x1, y1, y1, ha='center', fontsize=16)
    # for x2, y2 in enumerate(n2):
    #     plt.text(x2 + 0.3, y2 , y2, ha='center', fontsize=16)
    # for x3, y3 in enumerate(n3):
    #     plt.text(x3 + 0.6, y3 , y3, ha='center', fontsize=16)
    plt.grid(linestyle='--', linewidth=0.3, color='gray', alpha=0.7)
    plt.legend(fontsize=18, loc="best")
    plt.savefig('./bar_example.pdf', bbox_inches='tight', dpi=1200)
    plt.show()

    # --- Bar chart, ETTh2 (Transformer vs. Transformer+ours) ---
    # plt.figure(figsize=(8, 6))
    # gamma = ['96', '192', '336', '720']
    # x_tick = np.arange(4)
    # n2 = [1.026, 1.338, 1.599, 1.35]
    # n1 = [0.978, 1.225, 1.209, 1.221]
    # plt.xticks(x_tick + 0.30, gamma)
    # plt.xlabel('predict length', fontsize=18)
    # plt.ylabel('MAE', fontsize=18)
    # plt.xticks(fontsize=18)
    # plt.yticks(fontsize=18)
    #
    # # plt.plot(gamma, y1, label='METR-LA', c='r',marker='o')
    #
    # plt.bar(x_tick+0.15, n1, label='Transformer+ours', width=0.3, edgecolor='#009ad6', color='white', hatch='/', linewidth=2)
    # plt.bar(x_tick+0.45, n2, label='Transformer', width=0.3, edgecolor='#00ae9d', color='white', hatch='-', linewidth=2)
    # plt.grid(linestyle='--', linewidth=0.3, color='gray', alpha=0.7)
    # plt.legend(fontsize=18, loc="best")
    # plt.savefig('./bar_example.pdf', bbox_inches='tight', dpi=1200)
    # plt.show()


    # --- Horizontal bar chart ---
    # plt.figure(figsize=(8, 6))
    # gamma = ['0.05', '0.3', '0.6', '0.9', '1']
    # y_tick = np.arange(5)
    # n1 = [0.449, 0.446, 0.448, 0.449, 0.449]
    # plt.yticks(y_tick, gamma)
    # plt.ylabel('predict length', fontsize=18)
    # plt.xlabel('MAE', fontsize=18)
    # plt.xticks(fontsize=18)
    # plt.yticks(fontsize=18)
    # # plt.plot(gamma, y1, label='METR-LA', c='r',marker='o')
    #
    # plt.barh(y_tick, n1, label='SPM-Net', edgecolor='#F0F8FF', color='#87CEFA', linewidth=2, height=0.4)
    #
    # # for x1, y1 in enumerate(n1):
    # #     plt.text(x1, y1, y1, ha='center', fontsize=16)
    # # for x2, y2 in enumerate(n2):
    # #     plt.text(x2 + 0.3, y2 , y2, ha='center', fontsize=16)
    # # for x3, y3 in enumerate(n3):
    # #     plt.text(x3 + 0.6, y3 , y3, ha='center', fontsize=16)
    # plt.grid(linestyle='--', linewidth=0.3, color='gray', alpha=0.4)
    # plt.legend(fontsize=18, loc="best")
    # plt.savefig('./bar_example.pdf', bbox_inches='tight', dpi=1200)
    # plt.show()

    # --- Candidate queue size line chart ---
    # plt.figure(figsize=(8, 6))
    # ax = plt.gca()
    # gamma = ['0.05', '0.3', '0.6', '0.9', '1']
    #
    # # n2 = [0.495, 0.492, 0.493, 0.494, 0.496]
    # # n1 = [0.449, 0.446, 0.448, 0.449, 0.449]
    #
    # n2 = [0.456, 0.450, 0.453, 0.454, 0.454]
    # n1 = [0.372, 0.367, 0.370, 0.372, 0.373]
    # plt.xlabel('Candidate Queue size', fontsize=18)
    # plt.ylabel('MAE', fontsize=18)
    # plt.xticks(fontsize=18)
    # plt.yticks(fontsize=18)
    # plt.grid(True, linestyle="--", alpha=1)
    # plt.ylim(0.36, 0.47)
    # # plt.plot(gamma, y1, label='METR-LA', c='r',marker='o')
    # plt.plot(gamma, n1, label='predict_len 96', c='r', marker='o', markersize=8, linewidth=2.5)
    # plt.plot(gamma, n2, label='predict_len 192', c='b', marker='x', markersize=8, linewidth=2.5)
    # plt.legend(fontsize=18)
    # plt.show()