from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, accuracy_score
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader
import numpy as np
import torch

def cluster_acc(y_true, y_pred):
    """Return clustering accuracy under the best one-to-one label matching.

    A co-occurrence matrix between predicted cluster ids and true class ids
    is built, the assignment of clusters to classes that maximizes the
    number of agreeing samples is found with the Hungarian algorithm, and
    the fraction of samples covered by that assignment is returned.

    Args:
        y_true: Array-like of ground-truth labels; cast to ``int64``.
        y_pred: Array-like of predicted cluster ids; cast to ``int64``.
            Both label sets are used directly as matrix indices, so they
            should be non-negative integers.

    Returns:
        The matched accuracy, a float in ``[0, 1]``.

    Raises:
        AssertionError: If the two inputs differ in size.
        ValueError: If the inputs are empty.
    """
    y_true = np.asarray(y_true).astype(np.int64).ravel()
    # Cast predictions too: float-valued ids cannot be used as indices.
    y_pred = np.asarray(y_pred).astype(np.int64).ravel()
    assert y_pred.size == y_true.size
    if y_pred.size == 0:
        raise ValueError("cluster_acc requires at least one sample")

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Unbuffered add so repeated (pred, true) pairs are each counted.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimizes cost, so flip counts into costs.
    rows, cols = linear_sum_assignment(w.max() - w)
    return w[rows, cols].sum() / y_pred.size


def purity(y_true, y_pred):
    """Return the purity of a clustering with respect to reference labels.

    Each predicted cluster is credited with the count of its most frequent
    true label; purity is the sum of those counts divided by the number of
    samples. The inputs are not modified.

    Args:
        y_true: Array-like of ground-truth labels (any hashable-by-value
            dtype that ``np.unique`` can sort).
        y_pred: Array-like of predicted cluster ids, same length as
            ``y_true``.

    Returns:
        Purity as a float in ``(0, 1]``.

    Raises:
        ValueError: If the inputs are empty or differ in size.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError("y_true and y_pred must have the same number of samples")
    if y_true.size == 0:
        raise ValueError("purity requires at least one sample")

    # Map both label sets to contiguous 0..K-1 ids in one step. This avoids
    # relabeling in place, which would overwrite the caller's array and can
    # merge classes when a new id equals a label that is still to be mapped.
    _, true_idx = np.unique(y_true, return_inverse=True)
    _, pred_idx = np.unique(y_pred, return_inverse=True)
    true_idx = true_idx.ravel()
    pred_idx = pred_idx.ravel()

    # contingency[c, k] = number of samples in cluster c with true class k.
    contingency = np.zeros((pred_idx.max() + 1, true_idx.max() + 1), dtype=np.int64)
    np.add.at(contingency, (pred_idx, true_idx), 1)

    # Majority vote per cluster: keep the largest class count in each row.
    return float(contingency.max(axis=1).sum() / y_true.size)


def evaluate(label, pred):
    """Score a clustering against reference labels.

    Returns:
        Tuple ``(nmi, ari, acc, pur)``: normalized mutual information,
        adjusted Rand index, matched clustering accuracy and purity.
    """
    # Tuple elements are evaluated left to right, so purity runs last.
    scores = (
        normalized_mutual_info_score(label, pred),
        adjusted_rand_score(label, pred),
        cluster_acc(label, pred),
        purity(label, pred),
    )
    return scores

def activate_and_normalize(tensor):
    """Clamp negatives to zero, then divide by the sum along dim 1.

    Slices whose clamped sum is exactly zero are divided by one instead,
    so they come out as all zeros rather than NaN.
    """
    non_negative = tensor.clamp(min=0)
    totals = non_negative.sum(dim=1, keepdim=True)
    # Add 1.0 only where the total is zero to keep the division defined.
    safe_totals = totals + totals.eq(0).float()
    return non_negative / safe_totals

def inference(loader, model, device, view, data_size,
              class_num, max_view):
    """Run the model over `loader` and k-means cluster the `max_view` features.

    Only the features and labels of view `max_view` are collected; the
    other model outputs are not needed for the returned values.

    Args:
        loader: Iterable yielding ``(xs, y, _)`` batches, where ``xs`` and
            ``y`` are indexable per view.
        model: Module whose ``forward(xs, max_view)`` returns six values,
            the first being the per-view features ``hs``.
        device: Device the input views are moved to.
        view: Number of views in ``xs`` to move to `device`.
        data_size: Shape passed to ``reshape`` for the collected labels.
        class_num: Number of k-means clusters.
        max_view: Index of the view whose features and labels are used.

    Returns:
        Tuple ``(total_pred_h, labels_vector, H_avg)``: the k-means
        assignments, the labels of `max_view`, and the clustered features.
    """
    feature_batches = []
    label_batches = []
    for xs, y, _ in loader:
        for v in range(view):
            xs[v] = xs[v].to(device)
        with torch.no_grad():
            # NOTE(review): forward() is called directly, which skips any
            # registered module hooks; kept as in the original.
            hs, _, _, _, _, _ = model.forward(xs, max_view)
        feature_batches.append(hs[max_view].detach().cpu().numpy())
        label_batches.append(y[max_view].numpy())

    labels_vector = np.concatenate(label_batches, axis=0).reshape(data_size)
    H_avg = np.concatenate(feature_batches, axis=0)

    kmeans = KMeans(n_clusters=class_num, n_init=100)
    total_pred_h = kmeans.fit_predict(H_avg)

    return total_pred_h, labels_vector, H_avg

def valid(model, device, dataset, view,
          data_size, class_num, max_view, epoch, ts):
    """Cluster the model's features for `dataset`, print and return the metrics.

    `epoch` and `ts` are accepted but not used by this function.

    Returns:
        Tuple ``(acc, nmi, pur, ari)``.
    """
    # Batches of 256 samples, drawn in shuffled order.
    eval_loader = DataLoader(dataset, batch_size=256, shuffle=True)

    predictions, targets, _ = inference(
        eval_loader, model, device, view, data_size, class_num, max_view)

    print("Clustering results on H: " + str(targets.shape[0]))
    scores = evaluate(targets, predictions)
    nmi, ari, acc, pur = scores
    print('ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR={:.4f}'.format(acc, nmi, ari, pur))

    return acc, nmi, pur, ari
