import sys
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder
import numpy as np
import torch
import torch.distributions as dist


class Logger(object):
    """Tee-style text writer: mirrors every write to the terminal and a log file.

    The terminal stream is whatever ``sys.stdout`` is at construction time, so
    an instance can stand in for ``sys.stdout`` while still echoing output.

    Args:
        filename: Path of the log file to open.
        mode: File mode passed to ``open`` (default ``"a"`` appends).
    """

    def __init__(self, filename, mode="a"):
        self.terminal = sys.stdout
        self.log = open(filename, mode)

    def write(self, message):
        """Write *message* to the terminal and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both the terminal and the log file.

        A no-op flush would leave log output in the file buffer until the
        interpreter exits, losing it if the process crashes.
        """
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file (idempotent); the terminal stream is left open."""
        if not self.log.closed:
            self.log.close()


def save_model_light(model, filepath):
    """Write a lightweight checkpoint of *model* to *filepath*.

    Only ``model.state_dict()`` is serialized (not the module object itself),
    so restoring it requires building the model first and then calling
    ``load_state_dict`` on it.
    """
    weights = model.state_dict()
    torch.save(weights, filepath)


def cluster_acc(Y_pred, Y):
    """Compute unsupervised clustering accuracy under the best label matching.

    Builds the contingency matrix between predicted cluster ids and true
    labels, then uses the Hungarian algorithm
    (``scipy.optimize.linear_sum_assignment``) to find the one-to-one mapping
    of clusters to labels that maximizes the number of matched samples.

    Args:
        Y_pred: Non-negative integer cluster assignments (array-like).
        Y: Non-negative integer ground-truth labels (array-like) with the same
            number of elements as ``Y_pred``.

    Returns:
        Tuple ``(acc, w)``. ``acc`` is the fraction of samples matched under
        the optimal mapping. ``w`` is the ``(D, D)`` int64 contingency matrix
        where ``w[p, t]`` counts samples predicted as ``p`` with true label
        ``t``, and ``D`` is one more than the largest id in either input.

    Raises:
        ValueError: if the inputs differ in size, are empty, or contain
            negative ids.
    """
    Y_pred = np.asarray(Y_pred).reshape(-1)
    Y = np.asarray(Y).reshape(-1)
    # Explicit check instead of ``assert``: asserts are stripped under -O,
    # and a size mismatch would otherwise silently count only a prefix.
    if Y_pred.size != Y.size:
        raise ValueError(
            f"Y_pred and Y must have the same size, got {Y_pred.size} and {Y.size}"
        )
    if Y_pred.size == 0:
        raise ValueError("cluster_acc requires at least one sample")
    # Negative ids would wrap around when used as matrix indices.
    if Y_pred.min() < 0 or Y.min() < 0:
        raise ValueError("cluster ids and labels must be non-negative integers")

    D = int(max(Y_pred.max(), Y.max())) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Vectorized form of ``w[Y_pred[i], Y[i]] += 1`` for every i; np.add.at
    # accumulates correctly when the same (pred, true) pair repeats.
    np.add.at(w, (Y_pred, Y), 1)
    # linear_sum_assignment minimizes cost, so subtracting from the max turns
    # the count-maximization problem into an equivalent minimization.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() / Y_pred.size, w


def purity(labels_true, labels_pred):
    """Return the purity of the clustering *labels_pred* w.r.t. *labels_true*.

    Each predicted cluster is credited with the count of its most frequent
    true label. Purity is the sum of those counts divided by the number of
    samples.

    Args:
        labels_true: Ground-truth labels. This is array-like, and any values
            ``np.unique`` can sort are accepted, e.g. negative ints or strings.
        labels_pred: Predicted cluster ids with the same number of elements.

    Returns:
        Purity in ``[0, 1]`` as a NumPy float. Empty input yields ``nan``
        (0/0).

    Raises:
        ValueError: if the two inputs have different numbers of elements.
    """
    labels_true = np.asarray(labels_true).reshape(-1)
    labels_pred = np.asarray(labels_pred).reshape(-1)
    if labels_true.size != labels_pred.size:
        raise ValueError(
            "labels_true and labels_pred must have the same size, got "
            f"{labels_true.size} and {labels_pred.size}"
        )
    # Re-code true labels as 0..K-1 so np.bincount works for any label
    # values, not only non-negative integers.
    _, true_codes = np.unique(labels_true, return_inverse=True)
    majority_counts = [
        np.bincount(true_codes[labels_pred == cluster]).max()
        for cluster in np.unique(labels_pred)
    ]
    return np.sum(majority_counts) / labels_true.size


def _cluster_metrics(features, y, n):
    """Run KMeans on *features* and score the assignment against labels *y*.

    Args:
        features: 2-D array of shape (n_samples, n_features).
        y: Integer-encoded ground-truth labels, one per row of *features*.
        n: Number of KMeans clusters.

    Returns:
        ``[acc, NMI, ARI, purity]`` for the KMeans assignment.
    """
    # Fixed random_state keeps the KMeans result reproducible across calls.
    cls_index = KMeans(n_clusters=n, random_state=0).fit(features).labels_
    acc, _ = cluster_acc(cls_index, y)
    return [
        acc,
        normalized_mutual_info_score(y, cls_index),
        adjusted_rand_score(y, cls_index),
        purity(y, cls_index),
    ]


def clustering(dl, model, n):
    """Encode a paired dataset and evaluate KMeans clusterings of each embedding.

    Processing steps:

    1. Encode each ``(batch_a, batch_b, y)`` from *dl*, the a-view with
       ``model.a_to_z`` and the b-view with ``model.b_to_z``.
    2. Draw a sample ``u`` from the resulting ``torch.distributions.Normal``.
    3. Split ``u`` along the last axis into ``z`` (the first ``model.z_dim``
       features) and ``w`` (the next ``model.w_dim`` features).
    4. Run KMeans with *n* clusters on every embedding and score it against
       the label-encoded ``y``.

    Args:
        dl: Iterable yielding ``(batch_a, batch_b, y)`` tensor triples.
        model: Must provide ``device``, ``z_dim``, ``w_dim``, ``a_to_z`` and
            ``b_to_z``. Each encoder is presumably returning the
            ``(loc, scale)`` arguments of ``Normal``; TODO confirm.
        n: Number of KMeans clusters.

    Returns:
        Dict whose keys are ``'batch_a_u'``, ``'batch_a_z'``, ``'batch_a_w'``,
        ``'batch_b_u'``, ``'batch_b_z'``, ``'batch_b_w'`` and
        ``'all_batch_a_z'``, in that order. Each value is
        ``[acc, NMI, ARI, purity]``.

    Note:
        Embeddings are drawn with ``.sample()`` rather than taken at the mean,
        so results depend on the torch RNG state.
    """
    parts = {
        key: []
        for key in (
            "batch_a_u", "batch_a_z", "batch_a_w",
            "batch_b_u", "batch_b_z", "batch_b_w",
        )
    }
    labels = []

    with torch.no_grad():
        for batch_a, batch_b, y in dl:
            # The a-view is sampled before the b-view; keep this order so RNG
            # consumption (and therefore the sampled embeddings) is unchanged.
            for prefix, encoder, x in (
                ("batch_a", model.a_to_z, batch_a),
                ("batch_b", model.b_to_z, batch_b),
            ):
                u = dist.Normal(*encoder(x.to(model.device))).sample()
                z, w = torch.split(u, [model.z_dim, model.w_dim], dim=-1)
                parts[prefix + "_u"].append(u.cpu())
                parts[prefix + "_z"].append(z.cpu())
                parts[prefix + "_w"].append(w.cpu())
            labels.append(y)

    embeddings = {
        key: torch.cat(chunks, dim=0).numpy() for key, chunks in parts.items()
    }
    # Despite its name, this key holds the a-side AND b-side z embeddings
    # concatenated along the feature axis. The key is kept unchanged for
    # consumers of the returned dict.
    embeddings["all_batch_a_z"] = np.concatenate(
        [embeddings["batch_a_z"], embeddings["batch_b_z"]], axis=1
    )
    y = LabelEncoder().fit_transform(torch.cat(labels, dim=0).numpy())

    return {key: _cluster_metrics(feats, y, n) for key, feats in embeddings.items()}


