import torch
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader
from sklearn.metrics import normalized_mutual_info_score

def DBILoss(output, num_class, alpha = 50.):
    """Soft Davies-Bouldin-style clustering loss on a batch of embeddings.

    K-means (scikit-learn) is fit on a detached copy of ``output`` to obtain
    hard cluster assignments and centroids; both are treated as constants.
    For every cluster ``i`` a "radius" ``r_i`` is derived from distances of
    the embeddings to centroid ``i``. For every pair of clusters (including
    ``i == j``) the score ``r_i + r_j + ||c_i - c_j||`` is formed. Each row
    is reduced with a softmax-weighted mean (a smooth maximum whose sharpness
    is ``alpha``), and the row results are averaged.

    Args:
        output: embedding tensor; KMeans requires it to be 2-D, presumably
            (num_samples, feature_dim). Gradients reach it through the radii.
        num_class: number of k-means clusters.
        alpha: sharpness of the smooth maximum (0 gives a plain row mean).

    Returns:
        A scalar tensor on ``output.device``.
    """
    device = output.device
    # K-means runs on a detached NumPy copy: assignments and centroids are
    # constants, so gradients reach ``output`` only through the radii below.
    kmeans = KMeans(n_clusters=num_class).fit(output.detach().cpu().numpy())
    cluster_ids = torch.as_tensor(kmeans.labels_, device=device)
    cluster_centers = torch.as_tensor(
        kmeans.cluster_centers_, dtype=output.dtype, device=device
    )

    # (N, K): Euclidean distance from every sample to every centroid.
    # vector_norm uses a zero subgradient at distance 0. sqrt() of a squared
    # sum would yield NaN gradients there, e.g. for a singleton cluster whose
    # centroid coincides with its only member.
    point_to_center = torch.linalg.vector_norm(
        output.unsqueeze(1) - cluster_centers.unsqueeze(0), dim=-1
    )
    # NOTE(review): the radius of cluster i averages the distances of samples
    # NOT assigned to cluster i, divided by all N samples (members contribute
    # 0). A Davies-Bouldin scatter would average over the members of cluster
    # i instead. Kept as in the original; confirm that this is intended.
    outside = cluster_ids.unsqueeze(1) != torch.arange(num_class, device=device)
    radius = (point_to_center * outside.to(point_to_center.dtype)).mean(0)

    # (K, K) pair scores r_i + r_j + ||c_i - c_j||. Built with tensor ops
    # (not torch.tensor(list)) so the radii keep their autograd history.
    center_dist = torch.cdist(cluster_centers, cluster_centers)
    pair_score = radius.unsqueeze(1) + radius.unsqueeze(0) + center_dist
    # sum_j x_j * exp(a x_j) / sum_j exp(a x_j), computed stably. The raw
    # exp() overflows float32 once alpha * score exceeds about 88.
    weights = torch.softmax(alpha * pair_score, dim=-1)
    return (pair_score * weights).sum(-1).mean()

def get_dist_unsuper(dataset, model, device):
    """Run ``model`` over ``dataset`` and collect per-sample minimum outputs.

    The dataset is iterated through a shuffled DataLoader in batches of 32,
    with gradients disabled. Each input batch gets a singleton dimension
    inserted at position 1 and is moved to ``device``. ``model`` must return
    a 2-tuple; only its first element is used, reduced by ``min`` over the
    last dimension. The second element is discarded.

    Returns:
        ``(dists, gts)``: the reduced outputs stacked with ``torch.vstack``,
        and the labels concatenated along dim 0. Row order follows the
        shuffle, but ``dists[k]`` and ``gts[k]`` always belong to the same
        sample.
    """
    batches = DataLoader(dataset, 32, shuffle=True)
    min_outputs, labels = [], []
    with torch.no_grad():
        for inputs, targets in batches:
            labels.append(targets)
            batch_out, _unused = model(inputs.unsqueeze(1).to(device))
            min_outputs.append(batch_out.min(dim=-1).values)
    return torch.vstack(min_outputs), torch.cat(labels)

def NMI(model, num_class, train_dataset, test_dataset, device):
    """Score the model's outputs as cluster features using normalized mutual information.

    Features are produced by ``get_dist_unsuper`` for both datasets. A
    ``KMeans`` with ``num_class`` clusters is fit on the training features
    only. Its assignments for the train and test features are then compared
    against the true labels.

    Returns:
        ``(train_nmi, test_nmi)`` as computed by
        ``sklearn.metrics.normalized_mutual_info_score``.
    """
    train_dists, train_true = get_dist_unsuper(train_dataset, model, device)
    test_dists, test_true = get_dist_unsuper(test_dataset, model, device)
    # Convert each feature matrix to NumPy once and reuse it for fit/predict.
    # KMeans is the module-level import; no local re-import is needed.
    train_feats = train_dists.detach().cpu().numpy()
    test_feats = test_dists.detach().cpu().numpy()
    kmeans = KMeans(n_clusters=num_class).fit(train_feats)
    train_pred = kmeans.predict(train_feats)
    test_pred = kmeans.predict(test_feats)
    return (
        normalized_mutual_info_score(train_true.cpu().numpy(), train_pred),
        normalized_mutual_info_score(test_true.cpu().numpy(), test_pred),
    )