import os
import pickle

import numpy as np

from sklearn.model_selection import train_test_split
from scipy.io import loadmat

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from NeuroGraph.datasets import NeuroGraphDynamic

from data_utils import GraphDataset


# Create a contrastive loss function
def contrastive_loss(embeddings, labels, temperature=1.0):
    """Binary NCE-style contrastive loss over every pair of embeddings in a batch.

    Pair similarities are dot products (optionally divided by ``temperature``).
    Same-label pairs (including each embedding with itself) are pushed towards a
    high sigmoid score, different-label pairs towards a low one.

    Args:
        embeddings: 2-D tensor, one embedding per row.
        labels: 1-D tensor with one label per row of ``embeddings``.
        temperature: divisor applied to the similarities (1.0 = no scaling).

    Returns:
        Scalar tensor: ``-(mean logsigmoid(pos) + mean logsigmoid(-neg))``. A term
        whose pair set is empty (e.g. a batch where all labels are equal has no
        negative pairs) is omitted instead of contributing NaN.
    """
    similarities = torch.matmul(embeddings, embeddings.T) / temperature

    same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
    positive_similarities = similarities[same_label]
    negative_similarities = similarities[~same_label]

    # mean() of an empty tensor is NaN, so each term is only added when it has pairs.
    loss = similarities.new_zeros(())
    if positive_similarities.numel() > 0:
        loss = loss - F.logsigmoid(positive_similarities).mean()
    if negative_similarities.numel() > 0:
        loss = loss - F.logsigmoid(-negative_similarities).mean()
    return loss

def read_graph_matrices(data_name, sub_type=None):
    """Load one dataset and return its matrices and labels as numpy arrays.

    Args:
        data_name: one of ``'cog_state'``, ``'slim'`` or ``'DynHCP'``.
        sub_type: required when ``data_name == 'DynHCP'``; it is appended to
            ``data_name`` to form the dataset name handed to ``read_hcp_data``.

    Returns:
        ``(adj_matrices, feat_matrices, nets_index, ind_labels, obj_labels)`` where
        the adjacency and feature stacks are float32 arrays, ``nets_index`` is the
        int32 position ``0..N-1`` of every sample, and both label vectors are int32.

    Raises:
        ValueError: for an unknown ``data_name``, or ``'DynHCP'`` without ``sub_type``.
    """
    data_root = 'F:/CGSL_series/CGSL_uni/data/'
    # Validate first and pick the reader; slim and cog_state share the same folder layout.
    if data_name == 'DynHCP':
        if sub_type is None:
            raise ValueError("sub_type is required when data_name is 'DynHCP'")
        reader, source = read_hcp_data, data_name + sub_type
    elif data_name == 'slim':
        reader, source = read_slim_data, '{}{}/'.format(data_root, data_name)
    elif data_name == 'cog_state':
        reader, source = read_cog_data, '{}{}/'.format(data_root, data_name)
    else:
        raise ValueError('Invalid data name')

    adj_matrices, feat_matrices, ind_labels, obj_labels = reader(source)
    adj_matrices = np.array(adj_matrices, dtype=np.float32)
    feat_matrices = np.array(feat_matrices, dtype=np.float32)
    ind_labels = np.array(ind_labels, dtype=np.int32)
    obj_labels = np.array(obj_labels, dtype=np.int32)
    nets_index = np.arange(len(feat_matrices), dtype=np.int32)
    return adj_matrices, feat_matrices, nets_index, ind_labels, obj_labels

def read_hcp_data(data_name, data_root='F:/CGSL_series/CGSL_uni/data/', samples_per_item=10):
    """Load a NeuroGraph dynamic dataset as dense per-graph matrices, with on-disk caching.

    Results are cached as four pickles in ``<data_root>processed/<data_name>/``. When
    all four files exist they are loaded and returned. Otherwise the dataset is read
    through ``NeuroGraphDynamic(data_root, name=data_name)``, converted by
    ``_sample_hcp_graphs`` and written to that cache.

    Args:
        data_name: NeuroGraph dataset name (e.g. ``'DynHCPActivity'``), not a path.
        data_root: root data directory; expected to end with ``/``.
        samples_per_item: graphs drawn per item of ``dataset.dataset`` when building.

    Returns:
        ``(adj_matrices, feat_matrices, ind_labels, obj_labels)``.
    """
    processed_dir = '{}processed/{}'.format(data_root, data_name)
    cache_paths = [
        '{}/{}.pkl'.format(processed_dir, key)
        for key in ('adj_matrices', 'feat_matrices', 'ind_labels', 'obj_labels')
    ]

    # Only use the cache when every file is present; a partial cache is rebuilt.
    if all(os.path.isfile(path) for path in cache_paths):
        loaded = []
        for path in cache_paths:
            # NOTE: pickle.load can execute arbitrary code - only load caches this script wrote.
            with open(path, 'rb') as handle:
                loaded.append(pickle.load(handle))
        return tuple(loaded)

    os.makedirs(processed_dir, exist_ok=True)
    dataset = NeuroGraphDynamic(data_root, name=data_name)
    results = _sample_hcp_graphs(dataset, samples_per_item)
    for path, obj in zip(cache_paths, results):
        with open(path, 'wb') as handle:
            pickle.dump(obj, handle)
    return results


def _sample_hcp_graphs(dataset, samples_per_item):
    """Draw graphs from each batch in ``dataset.dataset`` and densify their edges.

    For item ``i`` of ``dataset.dataset``, ``samples_per_item`` graph indices are drawn
    uniformly with replacement. Each drawn graph contributes its node features, a
    symmetric 0/1 adjacency matrix, the individual label ``i`` and ``batch.y[j]``.
    """
    adj_matrices, feat_matrices, ind_labels, obj_labels = [], [], [], []
    for i, batch in enumerate(dataset.dataset):
        num_graphs = batch.num_graphs
        # NOTE(review): assumes every graph has exactly num_features nodes laid out in
        # contiguous blocks of that size - confirm against the NeuroGraph data format.
        max_nodes = batch.num_features
        choices = np.random.choice(range(num_graphs), samples_per_item)
        for j in choices:
            # Mask selecting the nodes that belong to graph j.
            node_mask = batch.batch == j
            node_features = batch.x[node_mask]

            # Keep only edges whose endpoints both lie in graph j.
            edge_mask = node_mask[batch.edge_index[0]] & node_mask[batch.edge_index[1]]
            edge_index = batch.edge_index[:, edge_mask]

            # Shift batch-global node ids to graph-local ids and fill both directions.
            offset = j * max_nodes
            src = edge_index[0] - offset
            dst = edge_index[1] - offset
            adj_matrix = torch.zeros((max_nodes, max_nodes))
            adj_matrix[src, dst] = 1
            adj_matrix[dst, src] = 1

            feat_matrices.append(node_features)
            ind_labels.append(i)
            obj_labels.append(batch.y[j])
            adj_matrices.append(adj_matrix)
    return adj_matrices, feat_matrices, ind_labels, obj_labels

def read_cog_data(data_folder, num_windows=100, samples_per_file=10):
    """Load randomly sampled feature/adjacency pairs for the cognitive-state dataset.

    Each ``*.pkl`` file in ``<data_folder>feat/`` is paired with the same-named file
    in ``<data_folder>graph/``. For every file, ``samples_per_file`` indices in
    ``[0, num_windows)`` are drawn with replacement. The same indices select entries
    from both pickles, so features and graphs stay aligned.

    File names are split on ``'_'``. The first field must contain ``'b'`` followed by
    a 1-based subject number (e.g. ``sub01``), which gives the 0-based individual
    label. The third field (before the extension) is a state code in
    ``EO/EC/Ma/Me/Mu``, mapped to 0..4.

    Args:
        data_folder: dataset directory, expected to end with ``/``.
        num_windows: size of the index range sampled from in each pickle.
        samples_per_file: number of indices drawn per file.

    Returns:
        ``(adj_matrices, feat_matrices, ind_labels, obj_labels)`` lists. Adjacency
        matrices are int32 0/1 arrays (1 wherever the stored weight is non-zero).
    """
    state_dict = {'EO': 0, 'EC': 1, 'Ma': 2, 'Me': 3, 'Mu': 4}
    feat_dir = data_folder + 'feat/'
    graph_dir = data_folder + 'graph/'
    ind_labels, obj_labels = [], []
    adj_matrices, feat_matrices = [], []

    # Sorted so the file order (and thus the seeded sampling) is platform-independent.
    for file_name in sorted(os.listdir(feat_dir)):
        if not file_name.endswith('.pkl'):
            continue
        choice = np.random.choice(num_windows, samples_per_file)

        # Labels depend only on the file name, so parse them once per file.
        flags = file_name.split('_')
        subject_label = int(flags[0].split('b')[1]) - 1
        state_label = int(state_dict[flags[2].split('.')[0]])

        # NOTE: pickle.load executes arbitrary code - only read trusted dataset files.
        with open(os.path.join(feat_dir, file_name), 'rb') as file:
            feat_data = pickle.load(file)
        with open(os.path.join(graph_dir, file_name), 'rb') as file:
            graph_data = pickle.load(file)

        for idx in choice:
            feat_matrices.append(feat_data[idx])
            ind_labels.append(subject_label)
            obj_labels.append(state_label)
            weights = np.asarray(graph_data[idx], dtype=np.float32)
            # Binarize: any non-zero weight becomes an edge.
            adj_matrices.append((weights != 0).astype(np.int32))
    return adj_matrices, feat_matrices, ind_labels, obj_labels

def read_slim_data(data_folder):
    """Load every feature/graph pickle pair of the SLIM dataset.

    Each ``*.pkl`` file in ``<data_folder>feat/`` is paired with the same-named file
    in ``<data_folder>graph/``. File names are parsed as ``<id>_<gender>.pkl``: the
    integer id becomes the individual label, and ``male``/``female`` becomes object
    label 0/1. Every entry of a graph pickle is binarized with ``keep_topalpha``
    (default alpha).

    Returns:
        ``(adj_matrices, feat_matrices, ind_labels, obj_labels)`` lists.
    """
    gender_to_label = {'male': 0, 'female': 1}
    feat_dir = data_folder + 'feat/'
    graph_dir = data_folder + 'graph/'
    adj_matrices, feat_matrices = [], []
    ind_labels, obj_labels = [], []

    for file_name in os.listdir(feat_dir):
        if not file_name.endswith('.pkl'):
            continue

        # NOTE: pickle.load executes arbitrary code - only read trusted dataset files.
        with open(os.path.join(feat_dir, file_name), 'rb') as handle:
            features = pickle.load(handle)
        n_frames = len(features)
        if n_frames:
            # Labels come from the file name and are only parsed when there is a frame to label.
            parts = file_name.split('_')
            subject = int(parts[0])
            gender = gender_to_label[parts[1].split('.')[0]]
            feat_matrices.extend(features[k] for k in range(n_frames))
            ind_labels.extend([subject] * n_frames)
            obj_labels.extend([gender] * n_frames)

        with open(os.path.join(graph_dir, file_name), 'rb') as handle:
            graphs = pickle.load(handle)
        adj_matrices.extend(keep_topalpha(graphs[k]) for k in range(len(graphs)))

    return adj_matrices, feat_matrices, ind_labels, obj_labels

def graph_dataset_loader(feat_matrices, net_index, ind_labels, state_labels, batch_size):
    """Wrap features and labels in a shuffling, memory-pinned ``DataLoader``.

    Every feature matrix is copied into a float32 tensor before being handed to
    ``GraphDataset`` together with the index and label sequences.
    """
    feature_tensors = [torch.tensor(m, dtype=torch.float32) for m in feat_matrices]
    return DataLoader(
        GraphDataset(feature_tensors, net_index, ind_labels, state_labels),
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
    )

def keep_topalpha(matrix, alpha=0.1):
    """Binarize ``matrix`` by keeping roughly its top ``alpha`` fraction of entries.

    The threshold is the value at sorted position ``int(n * (1 - alpha))`` of the
    flattened matrix, clamped to a valid index. Every entry ``>=`` the threshold
    becomes 1 and all others 0. All entries tied at the threshold are kept, so
    slightly more than ``alpha`` of the entries can be 1.

    Args:
        matrix: array-like of numbers (converted with ``np.asarray``).
        alpha: fraction of entries to keep. ``alpha <= 0`` keeps only the maximum
            value(s); ``alpha >= 1`` keeps every entry.

    Returns:
        Integer 0/1 ndarray with the same shape as ``matrix``.
    """
    matrix = np.asarray(matrix)
    flattened = matrix.ravel()
    n = flattened.size
    if n == 0:
        return np.zeros(matrix.shape, dtype=np.int_)

    # Clamp: alpha=0 would give index n (IndexError), and alpha>1 a negative index
    # that silently wraps around to the wrong end.
    k = min(max(int(n * (1 - alpha)), 0), n - 1)
    # np.partition puts the k-th smallest value at position k in O(n), without a full sort.
    threshold = np.partition(flattened, k)[k]
    return np.where(matrix >= threshold, 1, 0)

def compute_corr(matrices):
    """Return ``np.corrcoef`` of each input matrix (rows as variables), preserving order."""
    return [np.corrcoef(mat) for mat in matrices]

def data_set_split_by_ind(X, ind, y, rate, random_state=42):
    """Split samples within each individual, then carve a validation set from train.

    Each individual's samples are split into train/test with ``test_size=rate``.
    The per-individual pieces are concatenated, and ``int(len(train) / 10)`` samples
    of the pooled training set are moved to validation. Both steps use
    ``random_state``.

    Returns:
        ``(X_train, X_val, X_test, y_train1, y_val1, y_test1, y_train2, y_val2, y_test2)``
        where ``y*1`` are individual ids (taken from ``ind``) and ``y*2`` come from ``y``.
    """
    per_individual = []
    for individual in np.unique(ind):
        mask = ind == individual
        per_individual.append(
            train_test_split(X[mask], ind[mask], y[mask], test_size=rate, random_state=random_state)
        )

    # Transpose the per-individual 6-tuples into six lists and concatenate each one.
    X_train, X_test, y_train1, y_test1, y_train2, y_test2 = (
        np.concatenate(pieces) for pieces in zip(*per_individual)
    )

    val_size = int(len(X_train) / 10)
    X_train, X_val, y_train1, y_val1, y_train2, y_val2 = train_test_split(
        X_train, y_train1, y_train2, test_size=val_size, random_state=random_state
    )
    return X_train, X_val, X_test, y_train1, y_val1, y_test1, y_train2, y_val2, y_test2

def data_set_split_across_ind(X, ind, y, rate, random_state=42):
    """Split samples so that each individual appears in exactly one of train/val/test.

    ``int(n_ind * rate)`` individuals are drawn for test, and
    ``int(n_ind * rate / 2)`` of the remaining ones for validation; the rest form the
    training set. Draws use a ``np.random.RandomState(random_state)``, so a fixed
    ``random_state`` gives a reproducible split. Samples are grouped individual by
    individual, in the order the individuals were selected.

    Args:
        X, ind, y: ndarrays with one row/entry per sample; ``ind`` holds individual ids.
        rate: fraction of individuals held out for test (half of it for validation).
        random_state: seed (or ``None``) for the individual selection.

    Returns:
        ``(X_train, X_val, X_test, y_train1, y_val1, y_test1, y_train2, y_val2, y_test2)``
        where ``y*1`` are individual ids and ``y*2`` come from ``y``. A split with no
        individuals is returned as an empty array with the right trailing shape.
    """
    rng = np.random.RandomState(random_state)
    ind_list = np.unique(ind)
    ind_test = rng.choice(ind_list, size=int(len(ind_list) * rate), replace=False)
    remaining = np.setdiff1d(ind_list, ind_test)
    ind_val = rng.choice(remaining, size=int(len(ind_list) * rate / 2), replace=False)
    ind_train = np.setdiff1d(remaining, ind_val)

    def _rows(individuals):
        # Row indices of all samples of `individuals`, grouped in that order;
        # empty index array (not a concatenate error) when no individuals are given.
        if len(individuals) == 0:
            return np.array([], dtype=np.intp)
        return np.concatenate([np.flatnonzero(ind == i) for i in individuals])

    train_rows, val_rows, test_rows = _rows(ind_train), _rows(ind_val), _rows(ind_test)
    return (
        X[train_rows], X[val_rows], X[test_rows],
        ind[train_rows], ind[val_rows], ind[test_rows],
        y[train_rows], y[val_rows], y[test_rows],
    )

if __name__ == '__main__':
    # read_hcp_data takes a NeuroGraph dataset *name* and builds
    # 'F:/CGSL_series/CGSL_uni/data/processed/<name>' itself; passing a full path
    # here would produce an invalid nested path.
    test = read_hcp_data('DynHCPActivity')