import scipy.io as sio
import time
import os
import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np
import scipy.sparse as sp
from sklearn.cluster import KMeans, k_means
from metrics import clustering_metrics
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.preprocessing import normalize
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score
from sklearn.neighbors import kneighbors_graph
from utils import (preprocess_adj, ProcessKNN, square_dist,
                   to_onehot, agc_labels_to_optimized_graph,
                   distance_to_centroids)


def AGC(adjacency, feats_AGC, gnd_lab, tol=0.0, pow=12,
        k=7, num_runs=1, norm='l2', num_NN=20):
    """Adaptive graph convolution (AGC) clustering with automatic filter order.

    Repeatedly applies the graph filter ``0.5 * I + 0.5 * preprocess_adj(adjacency)``
    to the node features, clusters the top-``k`` left singular vectors of the
    filtered features with k-means, and stops as soon as the score
    ``square_dist(labels, feats)`` (stored in ``intraD``; presumably an
    intra-cluster distance) fails to decrease by more than ``tol``. The
    results of the last iteration that still improved are returned.

    Args:
        adjacency: graph adjacency, normalised with ``preprocess_adj``.
        feats_AGC: node feature matrix (one row per node).
        gnd_lab: ground-truth labels, used only for evaluation.
        tol: minimum decrease of the score required to keep iterating.
        pow: maximum number of filtering iterations; must be >= 1.
        k: number of clusters and of singular vectors.
        num_runs: only shown in the log line; not otherwise used here.
        norm: row normalisation ('l1', 'l2' or 'max'); any other value skips it.
        num_NN: number of neighbours for the returned k-NN graph.

    Returns:
        (accuracy, nmi, f1, kmeans_model, embedding_u, labels, best_p, knn_graph).
        accuracy/nmi/f1 are 0 when evaluation raises ``TypeError``.

    Raises:
        ValueError: if ``pow`` < 1 (no iteration would run, so there would be
            no embedding to build the k-NN graph from).
    """
    if pow < 1:
        raise ValueError(f'pow must be >= 1, got {pow}')

    print(f'--- AGC, power: {pow}, k: {k}, num_runs: {num_runs} ---')

    intraD = [np.inf]
    # Results of the previous iteration: the answer if the current one fails
    # to improve the score.
    predict_labels = 0
    tmp_means = 0
    tmp_u = 0
    # Results of the current iteration.
    u = 0
    tmp_lab = 0
    kmeans_agc = 0
    best_p = pow

    # Filter c * [(1 - a) I + a A] + d * b * (I - A); with d = 0 only the
    # low-pass part (1 - a) I + a A contributes.
    graph_filter = preprocess_adj(adj=adjacency)
    a = 0.5
    b = 0.5
    c = 1
    d = 0
    filter_a = (1 - a) * sp.eye(graph_filter.shape[0]) + a * graph_filter
    filter_b = b * (sp.eye(graph_filter.shape[0]) - graph_filter)
    graph_filter = c * filter_a + d * filter_b

    for p in range(1, pow + 1):
        feats_AGC = graph_filter.dot(feats_AGC)
        predict_labels = tmp_lab
        tmp_means = kmeans_agc
        tmp_u = u

        if norm in ['l1', 'l2', 'max']:
            feats_AGC = normalize(feats_AGC, norm=norm, axis=1)

        u, _, _ = sp.linalg.svds(feats_AGC, k=k, which='LM')

        kmeans_agc = KMeans(n_clusters=k, random_state=0, init='k-means++', n_init=10).fit(u)
        tmp_lab = kmeans_agc.predict(u)

        intraD.append(square_dist(tmp_lab, feats_AGC))

        if intraD[p - 1] - intraD[p] <= tol:
            best_p = p - 1
            if p == 1:
                # There is no earlier iteration: use this one's labels, model
                # AND embedding so the k-NN graph below has real input.
                predict_labels = tmp_lab
                tmp_means = kmeans_agc
                tmp_u = u
            break
    else:
        # Never stopped early, so iteration `pow` (== best_p) is the best one.
        # Return its results instead of those from iteration pow - 1.
        predict_labels = tmp_lab
        tmp_means = kmeans_agc
        tmp_u = u

    Knn_G = kneighbors_graph(tmp_u, num_NN, mode='connectivity', include_self=True).toarray()

    try:
        cm = clustering_metrics(gnd_lab, predict_labels)
        accuracy_AGC, nmi_AGC, f1_sc_AGC = cm.evaluationClusterModelFromLabel()
    except TypeError:
        accuracy_AGC, nmi_AGC, f1_sc_AGC = 0, 0, 0
    print(f'--- AGC, best_power: {best_p}, number of non empty clusters: {len(np.unique(predict_labels))}  ---')
    return accuracy_AGC, nmi_AGC, f1_sc_AGC, tmp_means, tmp_u, predict_labels, best_p, Knn_G


def lp(adjacency, train_mask, mask_test, gnd_lab, num_propagations, p=0.6, alpha=0.4):
    """Label propagation baseline over the input graph.

    Test-node labels are hidden. The label scores are then propagated
    ``num_propagations`` times, and the training rows are clamped back to
    their true labels after every step. Each node gets the arg-max class of
    its clipped scores.

    Args:
        adjacency: graph adjacency, normalised with ``preprocess_adj``.
        train_mask: rows whose labels are re-clamped after each step.
        mask_test: rows whose labels are hidden and then evaluated.
        gnd_lab: ground-truth labels. ``to_onehot(gnd_lab).T`` is assumed to
            be (num_nodes, num_classes) — TODO confirm against utils.to_onehot.
        num_propagations: number of propagation steps.
        p: element-wise exponent applied to the scores before each step.
        alpha: weight of the propagated term; ``1 - alpha`` weights the seed.

    Returns:
        (accuracy, micro-F1, micro-precision) on the test rows.
    """
    print(f' --- LabelProp  --- ')
    propagator = preprocess_adj(adjacency)

    # Seed matrix: label encoding with every test row zeroed out.
    seed = to_onehot(gnd_lab).T
    seed[mask_test, :] = 0

    scores = seed
    for _ in range(num_propagations):
        scores = (1 - alpha) * seed + alpha * propagator @ (scores ** p)
        # Known labels stay fixed.
        scores[train_mask, :] = seed[train_mask, :]

    predicted = np.argmax(np.clip(scores, 0, 1), axis=1)

    truth = gnd_lab[mask_test]
    predicted = predicted[mask_test]

    accuracy_lp = accuracy_score(truth, predicted, normalize=True)
    f1_sc_lp = f1_score(truth, predicted, zero_division=1, average='micro')
    precision_lp = precision_score(truth, predicted, average='micro')
    print(f' acc_LP: {accuracy_lp}')

    return accuracy_lp, f1_sc_lp, precision_lp


def SGC(feats, adjacency, pow, mask_test, mask_train, gnd_lab):
    """Simplified graph convolution (SGC) baseline.

    Smooths the features ``pow`` times with
    ``0.5 * I + 0.5 * preprocess_adj(adjacency)``. An MLP is then trained on
    the ``mask_train`` rows and scored on the ``mask_test`` rows.

    Returns:
        (accuracy, micro-F1, micro-precision, final training loss ``clf.loss_``).
    """
    print(f' --- SGC Labels --- ')

    normalized = preprocess_adj(adjacency)
    smoother = (1 - 0.5) * sp.eye(normalized.shape[0]) + 0.5 * normalized

    smoothed = feats
    for _ in range(pow):
        smoothed = smoother.dot(smoothed)

    # Row selection below is done on a plain ndarray, never on np.matrix.
    if isinstance(smoothed, np.matrix):
        smoothed = np.array(smoothed)

    clf = MLPClassifier(solver='adam', alpha=1e-3, hidden_layer_sizes=(16,), activation='logistic',
                        random_state=0, learning_rate_init=0.01)
    clf.fit(smoothed[mask_train, :], gnd_lab[mask_train])

    y_test = gnd_lab[mask_test]
    y_pred = clf.predict(smoothed[mask_test, :])

    accuracy_sgc = accuracy_score(y_test, y_pred, normalize=True)
    f1_sc_sgc = f1_score(y_test, y_pred, zero_division=1, average='micro')
    precision_sgc = precision_score(y_test, y_pred, average='micro')
    loss_sgc = clf.loss_

    print(f'loss_SGC : {loss_sgc}, acc_SGC: {accuracy_sgc}')
    return accuracy_sgc, f1_sc_sgc, precision_sgc, loss_sgc


def lp_agc(kmeans_lp_agc, predict_labels, u, adjacency, gnd_lab, train_mask, mask_test,
           lp_prop_num=50, lp_p=0.5, lp_alpha=0.5, alpha=None, keep=2,
           knngraph=None, modify_knn=False, use_aug_graph=True):
    """Label propagation over a mixture of three graphs.

    The propagation operator is
    ``alpha[0] * aug + alpha[1] * label_graph + alpha[2] * input_graph``,
    each term normalised with ``preprocess_adj``:

    * ``aug``: the k-NN graph (optionally passed through ``ProcessKNN``), or
      the graph built from AGC labels when ``keep == 1``. It is dropped
      (contributes 0) when ``use_aug_graph`` is False.
    * ``label_graph``: ``y @ y.T + I``, built from the training labels with
      the test rows hidden.
    * ``input_graph``: ``adjacency``.

    Args:
        kmeans_lp_agc, u: only needed by the unimplemented ``keep > 1`` path.
        predict_labels: AGC cluster labels, used when ``keep == 1``.
        lp_prop_num, lp_p, lp_alpha: number of steps, score exponent, and
            weight of the propagated term.
        alpha: the three mixing weights; defaults to ``[0.5, 0.3, 0.2]``.

    Returns:
        (accuracy, micro-F1, micro-precision) on the test rows, plus the
        predicted class of every node.

    Raises:
        NotImplementedError: if an augmented graph is requested without
            ``knngraph`` and ``keep != 1``. Previously this path called
            ``exit(0)``, which silently ended the process with a success status.
    """
    if alpha is None:
        alpha = [0.5, 0.3, 0.2]
    print(f' --- LP_AGC, power:{lp_prop_num}, alphas: {alpha}  --- ')

    if not use_aug_graph:
        aug_adj = None
    elif knngraph is not None:
        aug_adj = ProcessKNN(mask_test, knngraph) if modify_knn else knngraph
    elif keep == 1:
        # Sparse graph built from the AGC cluster labels.
        aug_adj = agc_labels_to_optimized_graph(predict_labels, adjacency.shape)
    else:
        # TODO: selecting the `keep` nearest centroids was sketched as
        # distance_to_centroids(kmeans_lp_agc, u, predict_labels, dist_type=1, keep=keep).
        raise NotImplementedError(
            f'keep={keep} is not supported without knngraph: '
            'modify for selecting more than 1 smallest')

    # LP seed: label encoding with the test rows hidden.
    y = to_onehot(gnd_lab).T
    y[mask_test, :] = 0

    aug_term = preprocess_adj(aug_adj) if aug_adj is not None else 0
    label_graph = preprocess_adj(y.dot(y.T) + sp.eye(adjacency.shape[0]))
    input_graph = preprocess_adj(adjacency)
    adjnormal = alpha[0] * aug_term + alpha[1] * label_graph + alpha[2] * input_graph

    final = y
    for _ in range(lp_prop_num):
        final = (1 - lp_alpha) * y + lp_alpha * adjnormal @ (final ** lp_p)
        # Known labels stay fixed.
        final[train_mask, :] = y[train_mask, :]
    final = np.clip(final, 0, 1)

    tmp_out = np.argmax(final, axis=1)

    test_truth = gnd_lab[mask_test]
    test_pred = tmp_out[mask_test]

    accuracy_lp_agc = accuracy_score(test_truth, test_pred, normalize=True)
    f1_sc_lp_agc = f1_score(test_truth, test_pred, zero_division=1, average='micro')
    precision_lp_agc = precision_score(test_truth, test_pred, average='micro')
    print(' acc_LP_AGC: {}'.format(accuracy_lp_agc))
    return accuracy_lp_agc, f1_sc_lp_agc, precision_lp_agc, tmp_out


def sgc_agc(kmeans_sgc_agc, predict_labels, u, feats, adjacency, gnd_lab, mask_test,
            mask_train, prop_num=50, alpha=None, keep=2, message='only',
            knngraph=None, modify_knn=False, use_aug_graph=True):
    """SGC-style feature smoothing over a mixture of three graphs, then an MLP.

    The mixed graph is
    ``alpha[0] * aug + alpha[1] * label_graph + alpha[2] * input_graph``,
    each term normalised with ``preprocess_adj``:

    * ``aug``: the k-NN graph (optionally passed through ``ProcessKNN``), or
      the graph built from AGC labels when ``keep == 1``. It is dropped
      (contributes 0) when ``use_aug_graph`` is False.
    * ``label_graph``: ``y @ y.T + I``, built from the training labels with
      the test rows hidden.
    * ``input_graph``: ``adjacency``.

    The features are smoothed ``prop_num`` times with ``0.5 * I + 0.5 * mixed``.
    An MLP is then trained on the ``mask_train`` rows and scored on the
    ``mask_test`` rows.

    Args:
        kmeans_sgc_agc, u: only needed by the unimplemented ``keep > 1`` path.
        alpha: the three mixing weights; defaults to ``[0.5, 0.2, 0.3]``.
        message: tag used in the log lines.

    Returns:
        (accuracy, micro-F1, micro-precision, final training loss ``clf.loss_``).

    Raises:
        NotImplementedError: if an augmented graph is requested without
            ``knngraph`` and ``keep != 1``. Previously this path called
            ``exit(0)``, which silently ended the process with a success status.
    """
    if alpha is None:
        alpha = [0.5, 0.2, 0.3]
    print(f' --- SGC_AGC_{message}, power:{prop_num}, alphas: {alpha}  --- ')

    if not use_aug_graph:
        aug_adj = None
    elif knngraph is not None:
        aug_adj = ProcessKNN(mask_test, knngraph) if modify_knn else knngraph
    elif keep == 1:
        # Sparse graph built from the AGC cluster labels.
        aug_adj = agc_labels_to_optimized_graph(predict_labels, adjacency.shape)
    else:
        # TODO: selecting the `keep` nearest centroids was sketched as
        # distance_to_centroids(kmeans_sgc_agc, u, predict_labels, dist_type=1, keep=keep).
        raise NotImplementedError(
            f'keep={keep} is not supported without knngraph: '
            'modify for selecting more than 1 smallest')

    # Label encoding with the test rows hidden, used for the label graph.
    y = to_onehot(gnd_lab).T
    y[mask_test, :] = 0

    aug_term = preprocess_adj(aug_adj) if aug_adj is not None else 0
    label_graph = preprocess_adj(y.dot(y.T) + sp.eye(adjacency.shape[0]))
    input_graph = preprocess_adj(adjacency)
    mixed = alpha[0] * aug_term + alpha[1] * label_graph + alpha[2] * input_graph
    smoother = (1 - 0.5) * sp.eye(feats.shape[0]) + 0.5 * mixed

    for _ in range(prop_num):
        feats = smoother.dot(feats)

    # Row selection below is done on a plain ndarray, never on np.matrix.
    if isinstance(feats, np.matrix):
        feats = np.array(feats)

    clf = MLPClassifier(solver='adam', alpha=1e-3, hidden_layer_sizes=(16,), activation='logistic',
                        random_state=0, learning_rate_init=0.01)

    X_train = feats[mask_train, :]
    X_test = feats[mask_test, :]
    y_train = gnd_lab[mask_train]
    y_test = gnd_lab[mask_test]

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    accuracy_sgc_agc = accuracy_score(y_test, y_pred, normalize=True)
    f1_sc_sgc_agc = f1_score(y_test, y_pred, zero_division=1, average='micro')
    precision_sgc_agc = precision_score(y_test, y_pred, average='micro')
    loss_sgc_agc = clf.loss_

    print('loss_SGC_AGC_{} : {}, acc_SGC_AGC_{}: {}'.format(message, loss_sgc_agc,
                                                            message, accuracy_sgc_agc))
    return accuracy_sgc_agc, f1_sc_sgc_agc, precision_sgc_agc, loss_sgc_agc


def kmeans_clustering(feats_kmeans_, gnd_lab,
                      k=7, use_svd=False, num_NN=20):
    """Plain k-means clustering baseline.

    Clusters either the raw features or, when ``use_svd`` is True, their
    top-``k`` left singular vectors. It also builds a ``num_NN``-nearest-
    neighbour connectivity graph (self-loops included) over the same space.

    Returns:
        (accuracy, nmi, f1, labels, kmeans_model, clustered_space, knn_graph).
        accuracy/nmi/f1 are 0 when evaluation raises ``TypeError``.
    """
    print("==========================KMEANS==============================")

    space = sp.linalg.svds(feats_kmeans_, k=k, which='LM')[0] if use_svd else feats_kmeans_

    model = KMeans(n_clusters=k, random_state=0, init='k-means++', n_init=10).fit(space)
    labels = model.predict(space)

    knn_graph = kneighbors_graph(space, num_NN, mode='connectivity', include_self=True).toarray()

    try:
        evaluator = clustering_metrics(gnd_lab, labels)
        acc, nmi, f1 = evaluator.evaluationClusterModelFromLabel()
    except TypeError:
        acc, nmi, f1 = 0, 0, 0
    print(f'--- Kmeans, number of non empty clusters: {len(np.unique(labels))}  ---')

    return acc, nmi, f1, labels, model, space, knn_graph


def run_DGI(dataset, mask_train, mask_test, gnd_lab):
    """Evaluate pre-computed DGI node embeddings with an MLP classifier.

    Loads array ``'x'`` from ``../DGI/results/best_dgi_embed_<dataset>0.npz``
    (path relative to the working directory) and squeezes it. An MLP is then
    trained on the ``mask_train`` rows and scored on the ``mask_test`` rows.

    Returns:
        (accuracy, micro-F1, micro-precision) on the test rows.
    """
    path = "../DGI/results/best_dgi_embed_" + dataset + "0.npz"
    print(f'DGI: loading: {path}')
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the array has been read.
    with np.load(path) as archive:
        feats = archive['x'].squeeze()

    clf = MLPClassifier(solver='adam', alpha=1e-3, hidden_layer_sizes=(16,), activation='logistic',
                        random_state=0, learning_rate_init=0.01)

    X_train = feats[mask_train, :]
    X_test = feats[mask_test, :]
    y_train = gnd_lab[mask_train]
    y_test = gnd_lab[mask_test]

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    accuracy_DGI = accuracy_score(y_test, y_pred, normalize=True)
    f1_sc_DGI = f1_score(y_test, y_pred, zero_division=1, average='micro')
    precision_DGI = precision_score(y_test, y_pred, average='micro')

    print('acc_DGI: {}'.format(accuracy_DGI))
    return accuracy_DGI, f1_sc_DGI, precision_DGI


def run_GMI(dataset, mask_train, mask_test, gnd_lab):
    """Evaluate pre-computed GMI node embeddings with an MLP classifier.

    Loads array ``'x'`` from ``../GMI/results/best_gmi_embed_<dataset>0.npz``
    (path relative to the working directory) and squeezes it. An MLP is then
    trained on the ``mask_train`` rows and scored on the ``mask_test`` rows.

    Returns:
        (accuracy, micro-F1, micro-precision) on the test rows.
    """
    path = "../GMI/results/best_gmi_embed_" + dataset + "0.npz"
    print(f'GMI: loading: {path}')
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the array has been read.
    with np.load(path) as archive:
        feats = archive['x'].squeeze()

    clf = MLPClassifier(solver='adam', alpha=1e-3, hidden_layer_sizes=(16,), activation='logistic',
                        random_state=0, learning_rate_init=0.01)

    X_train = feats[mask_train, :]
    X_test = feats[mask_test, :]
    y_train = gnd_lab[mask_train]
    y_test = gnd_lab[mask_test]

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    accuracy_GMI = accuracy_score(y_test, y_pred, normalize=True)
    f1_sc_GMI = f1_score(y_test, y_pred, zero_division=1, average='micro')
    precision_GMI = precision_score(y_test, y_pred, average='micro')

    print('acc_GMI: {}'.format(accuracy_GMI))
    return accuracy_GMI, f1_sc_GMI, precision_GMI
