import os
import numpy as np
import scipy.sparse as sp

import heapq
from sklearn.metrics.pairwise import euclidean_distances
import matplotlib.pyplot as plt
from torch_geometric.utils import to_dense_adj
import scipy.io as sio
# import tensorflow as  tf
# import torch
from torch_geometric.datasets import Amazon, Coauthor, Planetoid, Reddit


def getData(dataset):
    """Load a graph dataset and return ``(adj, gnd, k, feature)``.

    Two families of datasets are supported (names are matched
    case-insensitively):

    * ``pubmed``, ``citeseer``, ``cora``, ``wiki`` are read from
      ``'<dataset>.mat'`` relative to the current working directory. The
      file must contain the keys ``fea`` (features), ``W`` (adjacency) and
      ``gnd`` (labels). ``adj`` is returned as a ``scipy.sparse.coo_matrix``.
    * ``cs``, ``physics``, ``computers``, ``photo``, ``reddit`` are obtained
      through ``load_dataset``. ``adj`` is returned as a dense numpy array.

    Args:
        dataset: Name of the dataset.

    Returns:
        Tuple ``(adj, gnd, k, feature)``: adjacency, label vector, number of
        classes and feature matrix.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    name = dataset.lower()

    if name in ['pubmed', 'citeseer', 'cora', 'wiki']:
        # The file name keeps the caller's spelling so existing files on
        # case-sensitive file systems are still found.
        data = sio.loadmat('{}.mat'.format(dataset))
        feature = data['fea']
        if sp.issparse(feature):
            feature = np.asarray(feature.todense())

        adj = sp.coo_matrix(data['W'])
        # The stored labels are transposed, shifted down by one and reduced
        # to their first row to obtain a flat label vector.
        gnd = (data['gnd'].T - 1)[0, :]
        k = len(np.unique(gnd))

    elif name in ["cs", "physics", "computers", "photo", "reddit"]:
        dataset_data = load_dataset(dataset, transform=None)
        data = dataset_data[0]
        k = dataset_data.num_classes
        gnd = data.y.cpu().data.numpy()
        feature = data.x.cpu().data.numpy()
        adj = to_dense_adj(data.edge_index, edge_attr=data.edge_attr).squeeze().data.numpy()

    else:
        # Previously this fell through to the return statement and failed
        # with an UnboundLocalError that did not mention the dataset name.
        raise ValueError("Dataset not supported: {!r}".format(dataset))

    return adj, gnd, k, feature


def load_dataset(dataset, transform=None):
    """Instantiate a torch_geometric dataset by (case-insensitive) name.

    Args:
        dataset: One of ``cora``, ``citeseer``, ``pubmed`` (Planetoid),
            ``cs``, ``physics`` (Coauthor), ``computers``, ``photo``
            (Amazon) or ``reddit`` (Reddit).
        transform: Forwarded unchanged to the dataset constructor.

    Returns:
        The constructed torch_geometric dataset object.

    Raises:
        ValueError: If ``dataset`` is not a supported name. (This used to be
            ``assert False``, which is removed under ``python -O`` and then
            returned the input string as if it were a dataset.)
    """
    name = dataset.lower()

    # NOTE(review): the root paths start with a literal "~"; os.path.join
    # does not expand it, so expansion is left to the dataset classes --
    # confirm against the installed torch_geometric version.
    if name in ["cora", "citeseer", "pubmed"]:
        # "Plantoid" (sic) is kept as is: renaming the directory would
        # orphan any copy that has already been downloaded there.
        path = os.path.join("~/.datasets", "Plantoid")
        return Planetoid(path, name, transform=transform)
    if name in ["cs", "physics"]:
        path = os.path.join("~/.datasets", "Coauthor", name)
        return Coauthor(path, name, transform=transform)
    if name in ["computers", "photo"]:
        path = os.path.join("~/.datasets", "Amazon", name)
        return Amazon(path, name, transform=transform)
    if name in ["reddit"]:
        path = os.path.join("~/.datasets", "Reddit")
        return Reddit(path, transform=transform)

    raise ValueError("Dataset not supported: {!r}".format(dataset))


def make_dir(dirName):
    """Ensure that ``dirName`` exists, creating intermediate directories.

    A short ``[INFO]`` line is printed stating whether the directory was
    created or was already present.
    """
    if os.path.exists(dirName):
        print("[INFO] Directory ", dirName, " already exists")
        return

    os.makedirs(dirName, exist_ok=True)
    print("[INFO] Directory ", dirName, " created")


def _inverse_degree_power(adj, exponent):
    """Return the row sums of ``adj`` raised to ``exponent`` as a 1-D array.

    Entries that come out non-finite (zero-degree rows give ``inf``,
    negative row sums with a fractional exponent give ``nan``) are set to 0.
    """
    rowsum = np.asarray(adj.sum(1)).ravel()
    # Zero-degree rows are expected; silence the divide warning and zero the
    # resulting entries instead.
    with np.errstate(divide='ignore', invalid='ignore'):
        scaled = np.power(rowsum, exponent)
    scaled[~np.isfinite(scaled)] = 0.
    return scaled


def normalize_adj(adj, type='sym'):
    """Normalize an adjacency matrix by its row sums (degrees).

    Args:
        adj: Adjacency matrix (dense array or scipy sparse matrix).
        type: ``'sym'`` for the symmetric normalization
            ``D^-1/2 * A * D^-1/2`` or ``'rw'`` for the random-walk
            normalization ``D^-1 * A``, with ``D`` the diagonal matrix of
            row sums. Rows whose sum is zero are scaled by 0.

    Returns:
        For ``'sym'`` a ``scipy.sparse.coo_matrix``. For ``'rw'`` the result
        of multiplying the sparse diagonal scaling matrix with ``adj``.

    Raises:
        ValueError: If ``type`` is neither ``'sym'`` nor ``'rw'`` (this used
            to return ``None`` silently).
    """
    if type == 'sym':
        adj = sp.coo_matrix(adj)
        d_mat_inv_sqrt = sp.diags(_inverse_degree_power(adj, -0.5))
        return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo()
    if type == 'rw':
        d_mat_inv = sp.diags(_inverse_degree_power(adj, -1.0))
        return d_mat_inv.dot(adj)
    raise ValueError(
        "Unknown normalization type {!r}; expected 'sym' or 'rw'".format(type))


def preprocess_adj(adj, type='sym', loop=True):
    """Optionally add self-loops to ``adj`` and normalize it.

    When ``loop`` is true, an identity matrix of matching size is added
    before the matrix is handed to ``normalize_adj`` together with ``type``.
    The normalized matrix is returned as produced by ``normalize_adj``.
    """
    with_loops = adj + sp.eye(adj.shape[0]) if loop else adj
    return normalize_adj(with_loops, type=type)


def to_onehot(prelabel):
    """Return the one-hot encoding of ``prelabel`` with shape ``(k, n)``.

    ``k`` is the number of distinct labels and ``n`` the number of samples.
    Row ``j`` marks the samples carrying the ``j``-th smallest distinct
    label. For labels that are exactly ``0..k-1`` this is the label value
    itself, as before. Labels are ranked rather than used directly as
    indices, so label sets with gaps (for example ``{0, 2}``) no longer
    raise ``IndexError`` and negative labels no longer wrap around.

    Args:
        prelabel: Label per sample (array-like, flattened to 1-D).

    Returns:
        Float array of shape ``(k, n)`` containing zeros and ones.
    """
    labels = np.asarray(prelabel).ravel()
    distinct, rank = np.unique(labels, return_inverse=True)
    rank = np.reshape(rank, -1)

    onehot = np.zeros([len(distinct), labels.shape[0]])
    onehot[rank, np.arange(labels.shape[0])] = 1
    return onehot


def square_dist(prelabel, feature):
    """Return the average intra-cluster squared distance.

    For every cluster ``c`` the quantity ``2 * (E[|x|^2] - |E[x]|^2)`` is
    computed over its members, i.e. twice the summed per-feature variance.
    This equals the expected squared Euclidean distance between two members
    drawn independently from the cluster. The per-cluster values are
    averaged over the clusters.

    The previous implementation also computed an inter-cluster term that was
    never returned and divided by ``m * (m - 1)``, which is zero when there
    is a single cluster. That dead code has been removed.

    Args:
        prelabel: Cluster label per sample.
        feature: Feature matrix with one row per sample (dense or scipy
            sparse).

    Returns:
        The intra-cluster value as a numpy float.
    """
    if sp.issparse(feature):
        feature = feature.todense()
    feature = np.array(feature)

    onehot = to_onehot(prelabel)  # shape (clusters, samples)

    m = onehot.shape[0]
    count = onehot.sum(1).reshape(m, 1)
    count[count == 0] = 1  # avoid dividing by zero for an empty cluster

    mean = onehot.dot(feature) / count
    mean_sq_norm = (onehot.dot(feature * feature) / count).sum(1)

    # Only the diagonal of the former pairwise matrix was ever used.
    per_cluster = 2 * (mean_sq_norm - (mean * mean).sum(1))
    return per_cluster.sum() / m


def dist(prelabel, feature):
    """Return the mean intra-cluster Euclidean distance.

    For each cluster with at least two members, the pairwise distances
    between its members are summed and divided by ``n_i * (n_i - 1)``. The
    per-cluster values are then averaged over all ``k`` distinct labels, so
    clusters with fewer than two members contribute zero.

    Changes from the previous version:

    * The size check now happens before ``euclidean_distances`` is called.
      It used to run afterwards, so an empty cluster was handed to sklearn
      first.
    * Clusters are taken from the labels actually present instead of
      ``range(k)``, so label sets with gaps are handled.

    Args:
        prelabel: Cluster label per sample.
        feature: Feature matrix with one row per sample.

    Returns:
        The averaged distance (``0`` when no cluster has two members).
    """
    prelabel = np.asarray(prelabel)
    clusters = np.unique(prelabel)
    k = len(clusters)
    intra_dist = 0

    for label in clusters:
        members = feature[np.where(prelabel == label)]
        n_i = members.shape[0]
        if n_i < 2:
            # No pair of distinct members, so nothing to add.
            continue

        pairwise = euclidean_distances(members, members)
        intra_dist = intra_dist + pairwise.sum() / (k * n_i * (n_i - 1))

    return intra_dist


def sample_mask(idx, l):
    """Return a boolean mask of length ``l`` that is True at ``idx``.

    Uses the builtin ``bool`` as dtype; the ``np.bool`` alias used
    previously was removed from NumPy and raises ``AttributeError`` there.

    Args:
        idx: Index or indices to mark.
        l: Length of the mask.
    """
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask


def random_generate_masks(true_lab, node_list=None,
                          num_for_train=20, type_generate="guided",
                          num_train=None):
    """Create train/test masks with the requested sampling strategy.

    Args:
        true_lab: Label per node.
        node_list: Forwarded as ``node_list`` to the split function.
        num_for_train: Forwarded as ``num_lab`` to the split function.
        type_generate: ``"random"`` dispatches to ``random_split``,
            ``"guided"`` to ``guided_random_split``.
        num_train: Only forwarded in ``"random"`` mode; it is not passed on
            in ``"guided"`` mode.

    Returns:
        Tuple ``(train_mask, test_mask)``.

    Raises:
        ValueError: If ``type_generate`` is not a known strategy (this used
            to surface as an ``UnboundLocalError``).
    """
    if type_generate == "random":
        return random_split(true_lab, node_list=node_list,
                            num_lab=num_for_train,
                            num_train=num_train)
    if type_generate == "guided":
        return guided_random_split(true_lab, node_list=node_list,
                                   num_lab=num_for_train)

    raise ValueError(
        "Unknown type_generate {!r}; expected 'random' or 'guided'".format(type_generate))


def random_split(all_true_labels, node_list=None,
                 num_lab=20, num_train=None):
    """Draw a uniformly random training set and return train/test masks.

    Args:
        all_true_labels: Label per node; only the number of nodes and the
            number of distinct labels are used.
        node_list: Node indices that must NOT be drawn. A single index may
            be given as a scalar; callers that build the list with
            ``.squeeze().tolist()`` get a bare int for one node, which used
            to raise ``TypeError``. ``None`` allows every node.
        num_lab: Training nodes per class, used when ``num_train`` is None.
        num_train: If given, replaces ``num_lab`` as the per-class count.

    Returns:
        Tuple ``(train_mask, test_mask)`` of boolean arrays; the test mask
        is the complement of the train mask.
    """
    uniques = len(np.unique(all_true_labels))
    num = len(all_true_labels)

    if node_list is None:
        # An int makes np.random.choice draw from range(num).
        candidates = num
    else:
        is_zero_dim = isinstance(node_list, np.ndarray) and node_list.ndim == 0
        if np.isscalar(node_list) or is_zero_dim:
            node_list = [int(node_list)]
        # Asymmetric difference: nodes that are not excluded.
        candidates = list(set(range(0, num)) - set(node_list))

    per_class = num_lab if num_train is None else num_train
    num_train = int(per_class * uniques)

    train_idx = np.random.choice(candidates, num_train, replace=False)
    train_mask = np.zeros(num, dtype=bool)
    train_mask[train_idx] = True
    test_mask = ~train_mask

    print('checking number of training nodes: {}'.format(train_mask[train_mask].shape))
    print('checking number of testing nodes: {}'.format(test_mask[test_mask].shape))
    return train_mask, test_mask


def guided_random_split(all_true_labels, node_list=None,
                        num_lab=20):
    """Create a class-balanced random train/test split.

    ``num_lab`` nodes are drawn at random from every class. A class with
    fewer than ``num_lab`` members is sampled with replacement, so it may
    contribute fewer distinct training nodes than requested.

    Classes are taken from the labels actually present. The previous
    version iterated over ``range(number_of_classes)`` and failed in
    ``np.random.choice`` on an empty class whenever the labels were not
    exactly ``0..k-1``.

    Args:
        all_true_labels: Label per node.
        node_list: Accepted for signature compatibility; not used.
        num_lab: Number of training nodes to draw per class.

    Returns:
        Tuple ``(train_mask, test_mask)`` of boolean arrays; the test mask
        is the complement of the train mask.
    """
    labels = np.asarray(all_true_labels)
    classes = np.unique(labels)
    uniques = len(classes)
    num = len(labels)
    num_train = int(num_lab * uniques)
    num_train_per_class = int(num_lab)
    num_test = num - num_train

    print('num_train: {}, num_test: {}, all: {}'.format(num_train, num_test,
                                                        num_test + num_train_per_class * uniques))
    print('effective training rate is {}% with {}'
          ' nodes per class'.format(uniques * num_train_per_class / num * 100, num_train_per_class))

    train_mask = np.zeros(num, dtype=bool)
    for cls in classes:
        members = np.where(labels == cls)[0]
        # Fall back to sampling with replacement for undersized classes.
        undersized = len(members) < num_train_per_class
        idx = np.random.choice(members, size=num_train_per_class, replace=undersized)
        train_mask[idx] = True

    test_mask = ~train_mask
    print('checking number of training nodes: {}'.format(train_mask[train_mask].shape))
    print('checking number of testing nodes: {}'.format(test_mask[test_mask].shape))
    return train_mask, test_mask


def unsupervised_guided_split(kmeans, predict_labels, u, all_true_labels,
                              type_split='guided-mixed', num_lab=1,
                              min_num=None, type_random='guided'):
    """Pick training nodes from their distance to the k-means centroids.

    Every node is scored by the inverse of its Euclidean distance to each
    centroid (larger score = closer). Depending on ``type_split`` the
    training set is assembled from the closest nodes, the farthest nodes,
    randomly drawn nodes, or a mixture:

    * ``'guided-min'``: closest nodes per centroid.
    * ``'guided-min-random'``: closest nodes per centroid plus a random
      top-up.
    * ``'guided-min-max'``: farthest and closest nodes per centroid plus a
      random top-up when needed.
    * ``'guided-min-all'``: largest scores over all node/centroid pairs.
    * ``'guided-min-all-random'``: as above plus a random top-up.
    * ``'guided-min-max-all'``: largest distances and largest scores over
      all pairs plus a random top-up when needed.

    If fewer than ``num_lab * number_of_classes`` nodes were selected, the
    shortfall is filled with uniformly drawn unselected nodes.

    Changes from the previous version:

    * An unsupported ``type_split`` (including the signature default
      ``'guided-mixed'``, which no branch handles) raises ``ValueError`` up
      front instead of failing later with ``UnboundLocalError``.
    * Selected node indices are collected with ``np.flatnonzero``. The old
      ``argwhere(...).squeeze().tolist()`` collapsed to a bare int when
      exactly one node was selected, breaking the ``len()`` / ``set()``
      calls that follow.
    * Selecting more nodes than requested raises ``RuntimeError`` instead of
      calling ``exit(0)``, which ended the process with a success status.

    NOTE(review): ``'guided-min-random'`` and ``'guided-min-all-random'``
    need an explicit ``min_num`` when ``num_lab > 1``; with ``min_num=None``
    they fail with ``TypeError``. The intended default is not visible here.
    NOTE(review): the top-ups in ``'guided-min-max'`` and
    ``'guided-min-max-all'`` pass ``num_train``, which
    ``random_generate_masks`` only forwards in ``'random'`` mode. With the
    default ``type_random='guided'`` the per-class count falls back to that
    function's own default -- confirm this is intended.

    Args:
        kmeans: Fitted model exposing ``cluster_centers_``.
        predict_labels: Not used.
        u: Node embeddings, one row per node.
        all_true_labels: Label per node; supplies the number of nodes and
            the number of classes.
        type_split: Selection strategy, see above.
        num_lab: Requested number of training nodes per class.
        min_num: Number of "closest" nodes for the mixed strategies.
        type_random: Strategy forwarded to ``random_generate_masks``.

    Returns:
        Tuple ``(train_mask, test_mask)`` of boolean arrays; the test mask
        is the complement of the train mask.

    Raises:
        ValueError: If ``type_split`` is not supported.
        RuntimeError: If more training nodes were selected than requested.
    """
    supported_splits = ('guided-min-random', 'guided-min', 'guided-min-max',
                        'guided-min-all', 'guided-min-all-random',
                        'guided-min-max-all')
    if type_split not in supported_splits:
        raise ValueError(
            "Unsupported type_split {!r}; expected one of {}".format(type_split, supported_splits))

    centroids = kmeans.cluster_centers_
    # Inverse distance: the larger the value, the closer node and centroid.
    euc_res = 1 / euclidean_distances(u, centroids)

    # Create cluster guided splits
    uniques = len(np.unique(all_true_labels))
    num = len(all_true_labels)
    num_train_per_class = int(num_lab)

    if type_split == 'guided-min-random':
        if num_train_per_class <= 1:
            min_num = num_train_per_class

        min_num = Find_k(euc_res, kk=min_num, axis=0, type_of_k='max')
        minnodes = np.flatnonzero(min_num).tolist()

        random_num = num_train_per_class - len(minnodes) // uniques
        train_mask, _ = random_generate_masks(all_true_labels, node_list=minnodes,
                                              num_for_train=random_num,
                                              type_generate=type_random)
        train_mask = min_num | train_mask

    elif type_split == 'guided-min':
        train_mask = Find_k(euc_res, kk=num_train_per_class, axis=0, type_of_k='max')

    elif type_split == 'guided-min-max':
        if num_train_per_class <= 1:
            min_num = num_train_per_class
        elif min_num is None:
            min_num = num_train_per_class // 2
        max_num = num_train_per_class - min_num
        # Smallest inverse distance == farthest nodes.
        max_num = Find_k(euc_res, kk=max_num, axis=0, type_of_k='min')

        maxnodes = np.flatnonzero(max_num).tolist()
        min_num = num_train_per_class - len(maxnodes) // uniques

        # Zero the rows already taken so they are not selected again.
        euc_res[max_num > 0] = 0
        min_num = Find_k(euc_res, kk=min_num, axis=0, type_of_k='max')

        train_mask = min_num | max_num

        minnodes = np.flatnonzero(min_num).tolist()
        maxnondes = np.flatnonzero(max_num).tolist()
        nodesl = minnodes + maxnondes

        random_num = num_train_per_class - len(nodesl) // uniques
        if random_num:
            temp_train_mask, _ = random_generate_masks(all_true_labels, node_list=nodesl,
                                                       num_train=random_num,
                                                       type_generate=type_random)

            train_mask = temp_train_mask | train_mask

    elif type_split == 'guided-min-all':
        train_mask = Find_k_all(euc_res, kk=num_train_per_class * uniques)

    elif type_split == 'guided-min-all-random':
        if num_train_per_class <= 1:
            min_num = num_train_per_class

        min_num = Find_k_all(euc_res, kk=min_num * uniques)
        minnodes = np.flatnonzero(min_num).tolist()

        random_num = num_train_per_class - len(minnodes) // uniques
        train_mask, _ = random_generate_masks(all_true_labels,
                                              num_for_train=random_num,
                                              node_list=minnodes,
                                              type_generate=type_random)
        train_mask = train_mask | min_num

    else:  # 'guided-min-max-all'
        if num_train_per_class <= 1:
            min_num = num_train_per_class
        elif min_num is None:
            min_num = num_train_per_class // 2
        max_num = num_train_per_class - min_num
        print(f'min_num {min_num}, max_num {max_num}, num_train_per_cluster {num_train_per_class}')
        # 1 / euc_res recovers the distances, so the largest are the farthest.
        max_num = Find_k_all(1. / euc_res, kk=max_num * uniques)

        maxnodes = np.flatnonzero(max_num).tolist()
        min_num = num_train_per_class * uniques - len(maxnodes)

        # Zero the rows already taken so they are not selected again.
        euc_res[max_num > 0] = 0
        min_num = Find_k_all(euc_res, kk=min_num)

        train_mask = min_num | max_num

        minnodes = np.flatnonzero(min_num).tolist()
        maxnondes = np.flatnonzero(max_num).tolist()
        nodesl = minnodes + maxnondes

        random_num = num_train_per_class - len(nodesl) // uniques
        if random_num:
            temp_train_mask, _ = random_generate_masks(all_true_labels, node_list=nodesl,
                                                       num_train=random_num,
                                                       type_generate=type_random)
            train_mask = temp_train_mask | train_mask

    gen_train = int(np.count_nonzero(train_mask))
    request_train = num_train_per_class * uniques

    if gen_train != request_train:
        print(f"[ERROR] number of generated training labels is {gen_train},"
              f"while number of requested is {request_train},"
              f"which {num_train_per_class} nodes per class")
        if gen_train > request_train:
            raise RuntimeError(
                f"[CORRECTION] not implemented for this case where gen_train: {gen_train} >"
                f" request_train: {request_train}")

        # Too few nodes: top up with uniformly drawn unselected nodes.
        idx_correct = request_train - gen_train
        print(f"[CORRECTING] number of random addition is {idx_correct}")
        nodelist = list(set(range(0, num)) - set(np.flatnonzero(train_mask).tolist()))
        idx_correct = np.random.choice(nodelist, size=idx_correct, replace=False)
        train_mask[idx_correct] = True

    gen_train = int(np.count_nonzero(train_mask))

    if gen_train != request_train:
        print(f"[CORRECTION FAILED] number of generated training labels is {gen_train},"
              f"while number of requested is {request_train},"
              f"which {num_train_per_class} nodes per class")

    test_mask = np.logical_not(train_mask)

    print('checking number of training nodes: {}'.format(train_mask[train_mask == True].shape))
    print('checking number of testing nodes: {}'.format(test_mask[test_mask == True].shape))

    return train_mask, test_mask


def keep_top_k(arr, kk):
    """Zero, in place, every entry of ``arr`` below its ``kk``-th largest value.

    Entries equal to the cut-off are kept. The modified ``arr`` itself is
    returned.
    """
    largest_values = heapq.nlargest(kk, arr)
    cutoff = largest_values[-1]  # the kk-th largest value
    below_cutoff = arr < cutoff
    arr[below_cutoff] = 0
    return arr


def keep_min_k(arr, kk):
    """Zero, in place, every entry of ``arr`` above its ``kk``-th smallest value.

    Entries equal to the cut-off are kept. The modified ``arr`` itself is
    returned.
    """
    smallest_values = heapq.nsmallest(kk, arr)
    cutoff = smallest_values[-1]  # the kk-th smallest value
    above_cutoff = arr > cutoff
    arr[above_cutoff] = 0
    return arr


def k_largest_index_argpartition_v1(a, k):
    """Return the coordinates of the ``k`` largest entries of ``a``.

    The rows of the result are in no particular order.

    ``k`` is clamped to ``a.size`` and the partition pivot is ``k - 1``. The
    previous pivot ``k`` was out of bounds for ``k == a.size``, so asking
    for every entry raised ``ValueError``.

    Args:
        a: Input array.
        k: Number of entries to select.

    Returns:
        Integer array of shape ``(k, a.ndim)``, one coordinate per row.
    """
    a = np.asarray(a)
    k = min(int(k), a.size)
    if k <= 0:
        return np.empty((0, a.ndim), dtype=np.intp)

    # Partitioning the negated values around k - 1 puts the k largest
    # original values into the first k slots.
    idx = np.argpartition(-a.ravel(), k - 1)[:k]
    return np.column_stack(np.unravel_index(idx, a.shape))


def k_largest_index_argpartition_v2(a, k):
    """Return the coordinates of the ``k`` largest entries of ``a``.

    The rows of the result are in no particular order.

    ``k <= 0`` now yields an empty result. Previously ``k == 0`` made the
    partition pivot ``a.size`` (out of bounds), and the ``[-0:]`` slice
    would have selected every entry anyway. ``k`` is also clamped to
    ``a.size``.

    Args:
        a: Input array.
        k: Number of entries to select.

    Returns:
        Integer array of shape ``(k, a.ndim)``, one coordinate per row.
    """
    a = np.asarray(a)
    k = min(int(k), a.size)
    if k <= 0:
        return np.empty((0, a.ndim), dtype=np.intp)

    idx = np.argpartition(a.ravel(), a.size - k)[-k:]
    return np.column_stack(np.unravel_index(idx, a.shape))


def k_largest_index_argsort(a, k):
    """Return the coordinates of the ``k`` largest entries of ``a``.

    The result has one row per selected entry (one column per dimension of
    ``a``) and is ordered from the largest value downwards.
    """
    descending = np.argsort(a.ravel())[::-1]
    flat_top = descending[:k]
    coords = np.unravel_index(flat_top, a.shape)
    return np.column_stack(coords)


def Find_k_all(arr, kk, type_of_k='min'):
    """Mark the rows of ``arr`` that hold one of its ``kk`` largest entries.

    The ``kk`` largest entries are searched over the whole array. The
    returned boolean mask has one element per row and is True for every row
    containing at least one of them, so fewer than ``kk`` rows may be
    marked.

    ``type_of_k`` is accepted but not used: the largest entries are always
    selected, whatever its value.
    """
    top_coords = k_largest_index_argsort(arr, kk)
    row_ids = top_coords[:, 0]

    selected = np.zeros(arr.shape[0], dtype=bool)
    selected[row_ids] = True
    return selected


def Find_k(arr, kk, axis=-1, type_of_k='min'):
    """Return a boolean mask over the first axis of ``arr`` marking extreme entries.

    ``arr`` is arg-sorted along ``axis``. For ``type_of_k == 'max'`` the
    last ``kk`` slices of the sorted index array are kept, otherwise the
    first ``kk``. Every index found there is set to True in a mask of length
    ``arr.shape[0]``.

    NOTE(review): the slice is always taken along the first axis of the
    index array, whatever ``axis`` is. For a 2-D ``arr`` with ``axis=0``
    this marks the ``kk`` smallest/largest rows of every column. For other
    axes the stored indices refer to positions along ``axis`` but are still
    used to index the first-axis mask -- confirm that this is intended
    before relying on it.
    """
    order = arr.argsort(axis=axis)

    if type_of_k == 'max':
        chosen = order[arr.shape[axis] - kk:]
    else:
        chosen = order[:kk]

    selected = np.zeros(arr.shape[0], dtype=bool)
    selected[chosen] = True
    return selected


def distance_to_centroids(means_model, feats, predict_labels, dist_type=1, keep=2, return_prob=True):
    """Relate every sample to the cluster centroids of ``means_model``.

    NOTE(review): neither branch of the previous implementation could run.
    ``dist_type == 1`` summed a 1-D mask over axis 1, and ``dist_type == 2``
    iterated over an integer. The behaviour below is the reviewer's reading
    of the intent (from the inline comments and the commented-out masking
    code in ``Find_k``) and should be confirmed by the owner.

    ``dist_type == 1`` -- relation to all centroids:
        The inverse Euclidean distance to every centroid is used as a
        similarity. For each sample only the ``keep`` largest similarities
        are retained and the others are set to 0. With ``return_prob`` the
        rows are normalized to sum to 1; a sample lying exactly on a
        centroid puts all its mass on that centroid. Without
        ``return_prob`` the masked similarities are returned.

    ``dist_type == 2`` -- relation to the own centroid only:
        Returns a list with one 1-D array per centroid holding the
        Euclidean distance of each sample assigned to that centroid (by
        ``predict_labels``) to it. Centroids without samples give an empty
        array.

    Args:
        means_model: Fitted model exposing ``cluster_centers_``.
        feats: Dense feature matrix with one row per sample.
        predict_labels: Cluster index per sample (used for ``dist_type=2``).
        dist_type: ``1`` or ``2``, see above.
        keep: Number of centroids retained per sample for ``dist_type=1``.
        return_prob: Whether to return row-normalized values.

    Returns:
        A ``(samples, centroids)`` array for ``dist_type=1`` or a list of
        1-D arrays for ``dist_type=2``.

    Raises:
        ValueError: If ``dist_type`` is neither 1 nor 2.
    """
    centroids = means_model.cluster_centers_

    # Distance to all centers.
    if dist_type == 1:
        distances = euclidean_distances(feats, centroids)
        with np.errstate(divide='ignore'):
            similarity = 1 / distances  # inf where a sample sits on a centroid

        # Zero everything except the `keep` largest similarities per row.
        n_drop = max(similarity.shape[1] - keep, 0)
        weakest = np.argsort(similarity, axis=1)[:, :n_drop]
        np.put_along_axis(similarity, weakest, 0, axis=1)

        if not return_prob:
            return similarity

        row_sums = similarity.sum(axis=1)
        with np.errstate(invalid='ignore'):
            probabilities = similarity / row_sums[:, np.newaxis]

        # inf / inf is nan: give such rows all their mass on the
        # coinciding centroid(s) instead.
        on_centroid = np.isinf(similarity)
        affected = on_centroid.any(axis=1)
        if affected.any():
            hits = on_centroid[affected].astype(float)
            probabilities[affected] = hits / hits.sum(axis=1, keepdims=True)
        return probabilities

    # Distance to the current center only.
    if dist_type == 2:
        feats = np.asarray(feats)
        labels = np.asarray(predict_labels)
        return [np.linalg.norm(feats[labels == i] - centroids[i], axis=1)
                for i in range(len(centroids))]

    raise ValueError("Unknown dist_type {!r}; expected 1 or 2".format(dist_type))


def ProcessKNN(mtest, KG):
    """Restrict a KNN graph to the edges between the nodes in ``mtest``.

    The original and the restricted graph are shown side by side, with
    their non-zero counts in the titles, before the restricted graph is
    returned.

    The sub-block of ``mtest`` rows and columns is selected with ``np.ix_``.
    The previous ``PKG[mtest, mtest]`` paired the indices element-wise and
    so only touched the diagonal entries ``(i, i)``.
    NOTE(review): this assumes the intent was to keep edges among the
    ``mtest`` nodes -- confirm with the caller.

    Args:
        mtest: Boolean mask or integer indices of the nodes to keep.
        KG: Dense square adjacency matrix of the KNN graph.

    Returns:
        Array like ``KG`` with every entry outside the ``mtest`` x ``mtest``
        block set to zero.
    """
    keep = np.zeros_like(KG)
    keep[np.ix_(mtest, mtest)] = 1
    PKG = keep * KG

    plt.subplot(121)
    plt.imshow(KG)
    plt.title(f"Knn Graph has {np.count_nonzero(KG)} edges")

    plt.subplot(122)
    plt.imshow(PKG)
    plt.title(f"Modified has {np.count_nonzero(PKG)} edges")
    plt.show()

    return PKG


def plot_sparce_vs_dense_cluster_adj(prelabs, adjacency, KG=None,
                                     use_knn=False, fsize=12):
    """Plot the cluster graph, the adjacency and the optimized cluster graph.

    The cluster graph connects every pair of nodes that share a label. The
    optimized graph is ``KG`` when it is given and ``use_knn`` is true;
    otherwise it is built from ``agc_labels_to_optimized_graph``,
    symmetrized and given an added identity. Shapes and non-zero counts of
    the graphs are printed before the three panels are shown.

    ``adjacency`` must provide ``shape`` and ``toarray()``.
    """
    membership = to_onehot(prelabs).T
    cluster_graph = membership.dot(membership.T)
    cluster_nonzero = np.count_nonzero(cluster_graph)
    print('cluster_graph info: ', cluster_graph.shape,
          cluster_nonzero, cluster_nonzero)

    n_nodes = adjacency.shape[0]
    if KG is None or not use_knn:
        base_graph = agc_labels_to_optimized_graph(prelabs,
                                                   (n_nodes, adjacency.shape[1]))
        opt_cluster_graph = base_graph + base_graph.T + np.eye(n_nodes)
    else:
        opt_cluster_graph = KG

    print('symm opt_cluster_graph info: ', opt_cluster_graph.shape, np.count_nonzero(opt_cluster_graph),
          np.count_nonzero(preprocess_adj(opt_cluster_graph).toarray()))

    dense_adjacency = adjacency.toarray()
    print('adjacency info', adjacency.shape, np.count_nonzero(dense_adjacency),
          np.count_nonzero(preprocess_adj(adjacency).toarray()))

    panels = (
        (131, 'Cluster graph', cluster_graph),
        (132, 'Adjacency', dense_adjacency),
        (133, 'Optimized cluster graph', opt_cluster_graph),
    )
    for position, title, image in panels:
        plt.subplot(position)
        plt.title(title, fontsize=fsize)
        plt.imshow(image)
        plt.axis('off')
    plt.show()


def agc_labels_to_optimized_graph(prelabs, adj_shape):
    """Build a sparse star-shaped graph for every cluster.

    For each cluster, the first node carrying its label becomes the hub and
    is linked to every node of the cluster (itself included). The result is
    symmetrized and an identity matrix is added.

    Clusters are taken from the labels actually present. The previous loop
    over ``range(number_of_labels)`` raised ``IndexError`` on ``idx[0]``
    whenever a label in that range had no node.

    Args:
        prelabs: Cluster label per node.
        adj_shape: ``(rows, cols)`` of the graph to build.

    Returns:
        ``scipy.sparse.coo_matrix`` of shape ``adj_shape``.
    """
    prelabs = np.asarray(prelabs)
    opt_cluster_graph = np.zeros((adj_shape[0], adj_shape[1]))

    for label in np.unique(prelabs):
        idx = np.where(prelabs == label)[0]
        # Row of the hub node: connect it to all members of the cluster.
        opt_cluster_graph[idx[0], idx] = 1

        # Progress output after each cluster; it normalizes the full dense
        # matrix on every iteration.
        print('non-symm opt_cluster_graph info:', opt_cluster_graph.shape,
              np.count_nonzero(opt_cluster_graph),
              np.count_nonzero(preprocess_adj(opt_cluster_graph).toarray()))

    return sp.coo_matrix(opt_cluster_graph + opt_cluster_graph.T + np.eye(adj_shape[0]))