import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import pandas as pd
import os
import numpy as np
from utils.dataloader import load_data
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mutual_info_score
from sklearn.neighbors import KernelDensity
from sklearn.feature_selection import mutual_info_regression

import warnings
from tqdm import tqdm

# Silence every Python warning for the whole process. NOTE(review): this also
# hides library deprecation warnings that would flag API breakage.
warnings.filterwarnings('ignore')


# Absolute project root; the data and figure paths this script builds itself
# are all joined onto it. NOTE(review): machine-specific Windows path — adjust
# for your checkout.
ROOT_PATH = 'D:/WorkSpace_Python/MultiPatchTST/MultiPatchTST_supervised'

class GraphConv(nn.Module):
    """Learn a sparse 0/1 adjacency matrix between ``node_num`` variables.

    Two independent node embeddings are projected, squashed with
    ``tanh(tanh_alpha * .)`` and multiplied pairwise. The strictly upper
    triangle of the score matrix comes from the first embedding and the
    strictly lower triangle from the second. The row-normalized scores are
    added to a row-normalized prior correlation matrix. Each row then keeps
    at most ``min(count - 1, top_k)`` of its strongest positive links.

    Args:
        corr: square prior correlation matrix (NumPy). It is copied, never
            modified in place.
        high_correlated_count: per-node counts, indexed by row; the module
            uses ``max(count - 1, 0)`` as each row's link budget.
        node_num: number of nodes (embedding table size).
        d_node: embedding dimension.
        top_k: upper bound on the number of links kept per row.
        tanh_alpha: saturation factor applied before each ``tanh``.
        device: accepted for API compatibility; unused. Move the module
            with ``.to()`` instead (the prior tensors are buffers).
    """

    def __init__(self, corr, high_correlated_count, node_num:int=21, d_node:int=48, top_k:int=8,
                 tanh_alpha:float=3, device:str="cuda:0") -> None:
        super(GraphConv, self).__init__()
        self.node_num = node_num
        self.d_node = d_node
        self.top_k = top_k
        self.tanh_alpha = tanh_alpha
        self.device = device

        # Two independent embedding sets: one per triangle of the score matrix.
        self.g_embed1 = nn.Embedding(node_num, d_node)
        self.g_embed2 = nn.Embedding(node_num, d_node)
        self.linear1 = nn.Linear(d_node, d_node)
        self.linear2 = nn.Linear(d_node, d_node)

        # External prior. torch.tensor copies (torch.from_numpy would share
        # memory, and zeroing the diagonal would then overwrite the caller's
        # array). float32 matches the embedding parameters.
        correlated_matrix = torch.tensor(corr, dtype=torch.float32)
        correlated_count = torch.tensor(high_correlated_count)
        diag_indices = torch.arange(correlated_count.size(0))
        correlated_matrix[diag_indices, diag_indices] = 0

        # Non-persistent buffers follow .to(device) without altering state_dict.
        self.register_buffer('diag_indices', diag_indices, persistent=False)
        self.register_buffer('correlated_count', (correlated_count - 1).clamp(min=0), persistent=False)
        self.register_buffer('correlated_matrix', F.normalize(correlated_matrix, p=2, dim=-1),
                             persistent=False)

    def forward(self, id_list):
        """Return ``(A, A0)``: the sparse 0/1 adjacency and the dense scores.

        NOTE(review): rows are matched to the prior by position, so
        ``id_list`` is presumably all node ids in order (0..N-1) — confirm.
        """
        node_embed1 = torch.tanh(self.tanh_alpha * self.linear1(self.g_embed1(id_list)))
        node_embed2 = torch.tanh(self.tanh_alpha * self.linear2(self.g_embed2(id_list)))
        # Strict upper triangle from embedding 1, strict lower from embedding 2.
        # triu needs diagonal=1: with diagonal=-1 the first sub-diagonal
        # overlapped with tril(..., -1) and received both products. The main
        # diagonal is excluded by both terms, so there are no self loops.
        A0 = torch.triu(torch.mm(node_embed1, node_embed1.transpose(1, 0)), diagonal=1) + \
            torch.tril(torch.mm(node_embed2, node_embed2.transpose(1, 0)), diagonal=-1)
        A0 = F.normalize(A0, p=2, dim=-1)
        A0 = A0 + self.correlated_matrix
        A = F.relu(torch.tanh(self.tanh_alpha * A0))

        # Keep, per row, the top-k links, where k is capped by the prior count.
        nonzero_mask = A != 0
        nonzero_counts = nonzero_mask.sum(dim=1).tolist()
        mask = torch.zeros_like(A)
        for i, count in enumerate(nonzero_counts):
            k = min(int(self.correlated_count[i]), self.top_k)
            if count >= k:
                _, topk_indices = A[i].topk(k)
                mask[i, topk_indices] = 1
            else:
                mask[i, nonzero_mask[i]] = 1
        A = A * mask
        A[A != 0] = 1
        return A, A0


class SimpleLearn(nn.Module):
    """Wrapper that learns a sparse variable graph with ``GraphConv``.

    Only the graph learner is wired up so far. ``seq_len``, ``pred_len``,
    ``d_model`` and ``device`` are accepted but currently unused.
    """

    def __init__(self, coor, high_correlated_count, seq_len:int=336, pred_len:int=96, d_model:int=128,
                 top_k=8, device:str="cuda:0") -> None:
        super(SimpleLearn, self).__init__()
        # Forward the caller's top_k (it used to be silently pinned to 8).
        self.gc = GraphConv(coor, high_correlated_count, node_num=coor.shape[0], top_k=top_k)
        # Node ids 0..N-1. As a non-persistent buffer it follows .to(device)
        # but stays out of state_dict.
        self.register_buffer('id_list', torch.arange(coor.shape[0]), persistent=False)

    def forward(self, x):
        """Compute the learned graph.

        Returns:
            ``(output, A0, A)``, where ``A0`` is the dense score matrix and
            ``A`` is the sparse 0/1 adjacency produced by ``GraphConv``.
            TODO: there is no prediction head yet, so ``output`` is ``x``
            passed through unchanged.
        """
        # Use the owned sub-module (the old code called the global `model`,
        # i.e. this module again, recursing forever).
        A, A0 = self.gc(self.id_list)
        return x, A0, A

def MutualInformation(args, segment_length: int = 336, max_segments=None):
    """Estimate pairwise mutual information between variables per time window.

    Reads ``ROOT_PATH/<args.root_path>/<args.data_path>``, drops the ``date``
    column and keeps the first 70% of rows. Each variable is standardized over
    time. The series is then cut into non-overlapping windows of
    ``segment_length`` steps; a trailing partial window is dropped.

    For window ``i``, ``mi_matrix[i, j, k]`` is the ``mutual_info_regression``
    estimate with variable ``j`` as the feature and variable ``k`` as the
    target. A heatmap of each window is saved to
    ``ROOT_PATH/storage/MI{i}.png``.

    Args:
        args: namespace with ``root_path`` and ``data_path`` attributes.
        segment_length: time steps per window.
        max_segments: optional cap on the number of windows processed;
            ``None`` processes all of them.

    Returns:
        ``np.ndarray`` of shape ``(num_segments, num_vars, num_vars)``.
    """
    df = pd.read_csv(os.path.join(ROOT_PATH, args.root_path, args.data_path))
    columns = list(df.columns)
    columns.remove('date')
    df = df[columns][:int(0.7*df.shape[0])]
    # StandardScaler works column-wise: fit on the (time, vars) layout so each
    # variable is standardized over time, then switch to (vars, time).
    data = StandardScaler().fit_transform(df.values).T
    num_vars, num_steps = data.shape

    num_segments = num_steps // segment_length
    if max_segments is not None:
        num_segments = min(num_segments, max_segments)

    os.makedirs(os.path.join(ROOT_PATH, 'storage'), exist_ok=True)
    mi_matrix = np.zeros((num_segments, num_vars, num_vars))
    for i in range(num_segments):
        segment = data[:, i * segment_length:(i + 1) * segment_length]
        features = segment.T  # (segment_length, num_vars): one feature per variable
        for k in tqdm(range(num_vars)):
            # A single call scores every variable j against target k (column k).
            # The fixed random_state makes the estimator's jitter reproducible.
            mi_matrix[i, :, k] = mutual_info_regression(features, segment[k], random_state=0)
        mi_value = mi_matrix[i]

        # Draw the heatmap for this window.
        fig = plt.figure(figsize=(7.2, 7.2))
        plt.imshow(mi_value, cmap='coolwarm', interpolation='nearest')
        plt.colorbar()
        # Set axis ticks.
        xticks = np.arange(0, mi_value.shape[1], 100)
        yticks = np.arange(0, mi_value.shape[0], 100)
        plt.xticks(xticks)
        plt.yticks(yticks)
        plt.title('Correlation Heatmap of Time Series Data')
        plt.savefig(ROOT_PATH+f'/storage/MI{i}.png')
        # Close (not just clear) so figures do not accumulate across windows.
        plt.close(fig)
    return mi_matrix


def draw_origin(args, num_rows: int = 336, threshold: float = 0.8):
    """Plot Pearson-correlation statistics of the first ``num_rows`` rows.

    Saves two figures under ``ROOT_PATH/<args.store_path>/``:
    ``weather_counts_0_<num_rows>.png`` (a bar chart of per-variable
    high-correlation counts) and ``weather_corr_0_<num_rows>.png`` (a
    correlation heatmap).

    Args:
        args: namespace with ``root_path``, ``data_path`` and ``store_path``.
        num_rows: number of leading rows used.
        threshold: correlation above which a pair counts as highly correlated.

    Returns:
        ``(corr, high_correlated_count)``. ``corr`` is the square correlation
        matrix as an ndarray. ``high_correlated_count`` gives, per variable,
        the number of *other* variables whose correlation with it exceeds
        ``threshold``.
    """
    df = pd.read_csv(os.path.join(ROOT_PATH, args.root_path, args.data_path))
    columns = list(df.columns)
    columns.remove('date')
    df = df[columns][:num_rows]
    corr = df.corr()  # Pearson correlation (pandas default method)
    # Use an explicit column-wise sum: np.sum(DataFrame) goes through
    # DataFrame.sum(axis=None), whose column-wise behavior is deprecated.
    # Subtract 1 for the self-correlation. Clip at 0 because a constant column
    # has NaN correlations (even with itself) and would otherwise yield -1.
    high_correlated_count = np.clip((corr > threshold).sum(axis=0).to_numpy() - 1, 0, None)

    # Bar chart of the counts.
    fig = plt.figure(figsize=(10.8, 7.2))
    positions = np.arange(high_correlated_count.shape[0])
    plt.bar(positions, high_correlated_count)
    plt.savefig(ROOT_PATH+f'/{args.store_path}/weather_counts_0_{num_rows}.png')
    plt.close(fig)

    # Correlation heatmap.
    fig = plt.figure(figsize=(7.2, 7.2))
    plt.imshow(corr, cmap='coolwarm', interpolation='nearest')
    plt.colorbar()
    # Set axis ticks.
    xticks = np.arange(0, corr.shape[1], 100)
    yticks = np.arange(0, corr.shape[0], 100)
    plt.xticks(xticks)
    plt.yticks(yticks)
    plt.title('Correlation Heatmap of Time Series Data')
    plt.savefig(ROOT_PATH+f'/{args.store_path}/weather_corr_0_{num_rows}.png')
    plt.close(fig)
    return corr.to_numpy(), high_correlated_count


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Autoformer & Transformer family for Time Series Forecasting')
    parser.add_argument('--store_path',         type=str,   default='storage/lab1', help='output directory for figures, relative to ROOT_PATH')
    parser.add_argument('--dataset_name',       type=str,   default='weather', help='dataset name')
    parser.add_argument('--root_path',          type=str,   default='data/weather', help='root path of the data file')
    parser.add_argument('--data_path',          type=str,   default='weather.csv', help='data file')
    parser.add_argument('--dataset_type',       type=str,   default='custom', help='dataset type')

    args = parser.parse_args()
    # Fixed experiment settings (not exposed on the command line).
    args.target='OT'
    args.enc_in=21
    args.seq_len, args.label_len, args.pred_len = 336, 48, 96
    args.num_workers = 16
    args.use_gcn=False

    # Every figure is written to ROOT_PATH/<store_path>, so create that
    # directory (the previous check tested the literal string 'args.store_path').
    os.makedirs(os.path.join(ROOT_PATH, args.store_path), exist_ok=True)

    coor, high_correlated_count = draw_origin(args)
    data, _, _ = load_data(args)
    model = SimpleLearn(coor, high_correlated_count)

    # NOTE(review): the loop previously referenced an undefined `x`; `data`
    # from load_data is passed instead — confirm the intended input.
    # SimpleLearn.forward does not read its input to build the graph. With no
    # optimizer yet, each pass recomputes the same graph and overwrites the
    # same figures.
    for epoch in range(5):
        output, A0, A = model(data)

        # A0 is saved as weather_A0_*.png and A as weather_A_*.png.
        for suffix, a in zip(('0', ''), (A0, A)):
            a = a.detach().cpu().numpy()
            fig = plt.figure(figsize=(7.2, 7.2))
            plt.imshow(a, cmap='coolwarm', interpolation='nearest')
            plt.colorbar()
            # Set axis ticks.
            xticks = np.arange(0, a.shape[1], 100)
            yticks = np.arange(0, a.shape[0], 100)
            plt.xticks(xticks)
            plt.yticks(yticks)
            plt.title('Original Embedding A Heatmap')
            plt.savefig(ROOT_PATH+f'/{args.store_path}/weather_A{suffix}_5000_5336.png')
            # Close (not just clear) so figures do not pile up across epochs.
            plt.close(fig)