import warnings
import numpy as np
from munkres import Munkres
from collections import Counter
from sklearn.cluster import KMeans
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, normalized_mutual_info_score, adjusted_rand_score
import torch
import torch.nn as nn
import torch.nn.modules.loss

from unsup_model import LogReg
# NOTE(review): this runs at import time and silences *all* warnings for the
# whole process (not only for calls made from this module), including any
# emitted by the sklearn / torch calls used below.
warnings.filterwarnings("ignore")



def unsupervised_test_linear(data, embeds, n_classes, device, args, epochs=2000):
    """Evaluate frozen embeddings with a trained linear probe.

    A ``LogReg`` model is trained full-batch on ``embeds[data.train_mask]``
    with cross-entropy and Adam (``lr=args.lr2``, ``weight_decay=args.wd2``).
    After each step, validation and test accuracy are measured. Whenever the
    validation accuracy ties or beats the best seen so far, the reported test
    accuracy becomes the larger of its current value and this epoch's test
    accuracy.

    NOTE(review): the reported number is the *maximum* test accuracy over
    validation-improving epochs. That is optimistic compared with taking the
    test accuracy at the single best-validation epoch.

    Args:
        data: object exposing ``train_mask``, ``val_mask``, ``test_mask`` and
            ``y`` (presumably a PyG ``Data`` object -- confirm with callers).
        embeds: embedding tensor indexed by the masks. It is detached, so no
            gradient flows back into whatever produced it.
        n_classes: number of output classes of the probe.
        device: device the probe is moved to. Embeddings and labels are used
            on whatever device they already live on.
        args: namespace providing ``lr2`` and ``wd2``.
        epochs: number of optimisation steps (default 2000, the previously
            hard-coded value).

    Returns:
        The selected test accuracy as a Python float (0.0 if it never rises
        above zero).
    """
    # The probe only trains its own weights; without detach(), backward() would
    # also run through the encoder graph (and fail on the second epoch).
    embeds = embeds.detach()

    train_embs = embeds[data.train_mask]
    val_embs = embeds[data.val_mask]
    test_embs = embeds[data.test_mask]

    train_labels = data.y[data.train_mask]
    val_labels = data.y[data.val_mask]
    test_labels = data.y[data.test_mask]

    def _accuracy(logits, labels):
        # Fraction of argmax predictions equal to the labels, as a Python float.
        preds = torch.argmax(logits, dim=1)
        return (torch.sum(preds == labels).float() / labels.shape[0]).item()

    logreg = LogReg(hid_dim=train_embs.size(1), n_classes=n_classes).to(device)
    opt = torch.optim.Adam(logreg.parameters(), lr=args.lr2, weight_decay=args.wd2)
    loss_fn = nn.CrossEntropyLoss()

    # Plain floats: previously eval_acc stayed the int 0 when test accuracy
    # never exceeded 0, and the final `.cpu()` call raised AttributeError.
    best_val_acc = 0.0
    eval_acc = 0.0
    for _ in range(epochs):
        logreg.train()
        opt.zero_grad()
        loss = loss_fn(logreg(train_embs), train_labels)
        loss.backward()
        opt.step()

        logreg.eval()
        with torch.no_grad():
            val_acc = _accuracy(logreg(val_embs), val_labels)
            test_acc = _accuracy(logreg(test_embs), test_labels)

        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            eval_acc = max(eval_acc, test_acc)
    return eval_acc



def cluster_eval(y_true, y_pred):
    """Compute best-match clustering accuracy and macro-F1.

    code source: https://github.com/bdy9527/SDCN

    Cluster ids in ``y_pred`` are mapped onto the labels of ``y_true`` with the
    Hungarian algorithm (``Munkres``). The mapping maximises the number of
    samples whose mapped cluster id equals their true label. Accuracy and
    macro-F1 are then computed on the remapped predictions.

    Fix-up step: when ``y_pred`` contains a different number of distinct values
    than ``y_true``, every true label absent from ``y_pred`` is first written
    onto one sample of the most frequent predicted cluster (ties broken by
    first appearance). Each missing label takes the earliest remaining
    occurrence of that cluster, as in the SDCN code. Predicted clusters left
    unmatched by the assignment are mapped to 0.

    Args:
        y_true: ground-truth labels, 1-D array-like or torch tensor.
        y_pred: predicted cluster ids, 1-D array-like or torch tensor. The
            caller's object is never modified; the fix-up step works on a copy.

    Returns:
        Tuple ``(acc, f1_macro)``.

    Raises:
        ValueError: if the inputs are empty or differ in length, or if the
            most frequent cluster has fewer samples than there are missing
            true labels.
    """
    y_true = y_true.detach().cpu().numpy() if isinstance(y_true, torch.Tensor) else np.asarray(y_true)
    y_pred = y_pred.detach().cpu().numpy() if isinstance(y_pred, torch.Tensor) else y_pred
    # Copy: the fix-up below writes into y_pred. Tensor.numpy() (or an ndarray
    # argument) would otherwise share memory with the caller's predictions.
    y_pred = np.array(y_pred, copy=True)

    if y_true.size == 0:
        raise ValueError("cluster_eval: empty input")
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"cluster_eval: length mismatch ({len(y_true)} true vs {len(y_pred)} predicted)"
        )

    true_labels = np.unique(y_true)
    pred_labels = np.unique(y_pred)

    if len(true_labels) != len(pred_labels):
        missing = true_labels[~np.isin(true_labels, pred_labels)]
        if missing.size:
            # most_common(1) picks the first-seen label among equal counts, like
            # the stable descending sort used originally.
            largest = Counter(y_pred.tolist()).most_common(1)[0][0]
            donors = np.flatnonzero(y_pred == largest)
            if donors.size < missing.size:
                raise ValueError(
                    "cluster_eval: not enough samples in the largest cluster "
                    "to cover every missing true label"
                )
            # Same as repeatedly overwriting the first remaining occurrence.
            y_pred[donors[:missing.size]] = missing
            pred_labels = np.unique(y_pred)

    # overlap[i, j] = number of samples with true label i predicted as cluster j.
    overlap = np.zeros((len(true_labels), len(pred_labels)), dtype=int)
    for i, label in enumerate(true_labels):
        preds_for_label = y_pred[y_true == label]
        for j, cluster in enumerate(pred_labels):
            overlap[i, j] = np.count_nonzero(preds_for_label == cluster)

    # Munkres minimises total cost, so negate overlaps to maximise matches.
    indexes = Munkres().compute((-overlap).tolist())

    new_predict = np.zeros(len(y_pred))  # unmatched clusters stay 0
    for row, col in indexes:
        new_predict[y_pred == pred_labels[col]] = true_labels[row]

    acc = accuracy_score(y_true, new_predict)
    f1_macro = f1_score(y_true, new_predict, average='macro')
    return acc, f1_macro


def unsup_test_cluster(y_true, y_pred, verbose=False):
    """Score a clustering against ground-truth labels.

    Args:
        y_true: ground-truth labels (array-like or torch tensor).
        y_pred: predicted cluster ids (array-like or torch tensor).
        verbose: if True, print the four scores.

    Returns:
        Tuple ``(acc, nmi, ari, f1)``. Here ``acc`` and ``f1`` are the
        best-match accuracy and macro-F1 from ``cluster_eval``, ``nmi`` is
        arithmetic-mean normalised mutual information, and ``ari`` is the
        adjusted Rand index.
    """
    # sklearn metrics need host arrays; tensors (possibly on GPU) are converted.
    y_true = y_true.detach().cpu().numpy() if isinstance(y_true, torch.Tensor) else np.asarray(y_true)
    y_pred = y_pred.detach().cpu().numpy() if isinstance(y_pred, torch.Tensor) else np.asarray(y_pred)

    # cluster_eval may relabel its y_pred argument in place. Give it a copy so
    # NMI/ARI below are computed on the actual predictions.
    acc, f1 = cluster_eval(y_true, y_pred.copy())
    nmi = normalized_mutual_info_score(y_true, y_pred, average_method='arithmetic')
    ari = adjusted_rand_score(y_true, y_pred)
    if verbose:
        print(f'acc {acc:.4f}, nmi {nmi:.4f}, ari {ari:.4f}, f1 {f1:.4f}')
    return acc, nmi, ari, f1


def unsupervised_test_kmeans(X, y, n_clusters, repeat=10, random_state=None):
    """Cluster ``X`` with K-means ``repeat`` times and summarise the metrics.

    NaN and +/-inf entries of ``X`` are replaced by 1 before clustering. Each
    run is scored with ``unsup_test_cluster``.

    Args:
        X: feature matrix (array-like or torch tensor). It is not modified.
        y: ground-truth labels (array-like or torch tensor).
        n_clusters: number of K-means clusters.
        repeat: number of independent K-means runs.
        random_state: ``None`` (default, previous behavior: KMeans uses numpy's
            global RNG), an int seed, or a ``np.random.RandomState``. A seed is
            turned into one RandomState shared across runs, so runs differ from
            each other but the whole evaluation is reproducible.

    Returns:
        Dict with ``acc_mean, acc_std, nmi_mean, nmi_std, ari_mean, ari_std,
        f1_mean, f1_std``.
    """
    y = y.detach().cpu().numpy() if isinstance(y, torch.Tensor) else y
    X = X.detach().cpu().numpy() if isinstance(X, torch.Tensor) else np.asarray(X)

    # nan_to_num returns a copy. The previous in-place assignment overwrote the
    # caller's array (and a CPU tensor, via the memory shared by .numpy()).
    X = np.nan_to_num(X, nan=1.0, posinf=1.0, neginf=1.0)

    if random_state is not None and not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)

    metric_names = ('acc', 'nmi', 'ari', 'f1')
    scores = {name: [] for name in metric_names}
    for _ in range(repeat):
        y_pred = KMeans(n_clusters=n_clusters, random_state=random_state).fit_predict(X)
        run_scores = unsup_test_cluster(y_true=y, y_pred=y_pred, verbose=False)
        for name, value in zip(metric_names, run_scores):
            scores[name].append(value)

    res = {}
    for name in metric_names:
        res[f'{name}_mean'] = np.mean(scores[name])
        res[f'{name}_std'] = np.std(scores[name])
    return res


def unsupervised_test_knn(data, embeds, n_classes, device, args):
    """Evaluate embeddings with a k-nearest-neighbour classifier.

    Fits ``KNeighborsClassifier(n_neighbors=args.knn_k)`` on the train-mask
    embeddings and labels, then returns accuracy on the test mask. The
    validation split is not used.

    Args:
        data: object exposing ``train_mask``, ``test_mask`` and ``y``
            (presumably a PyG ``Data`` object -- confirm with callers).
        embeds: embedding tensor indexed by the masks.
        n_classes: unused. Kept so the signature matches
            ``unsupervised_test_linear``.
        device: unused. Kept for the same reason.
        args: namespace providing ``knn_k``.

    Returns:
        Test accuracy as returned by ``accuracy_score``.
    """
    def _split(mask):
        # Host numpy copies of the embeddings and labels selected by `mask`.
        return (embeds[mask].detach().cpu().numpy(),
                data.y[mask].detach().cpu().numpy())

    train_embs, train_labels = _split(data.train_mask)
    test_embs, test_labels = _split(data.test_mask)

    neigh = KNeighborsClassifier(n_neighbors=args.knn_k)
    neigh.fit(train_embs, train_labels)
    return accuracy_score(test_labels, neigh.predict(test_embs))
