from time import time
from sklearn import metrics
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

import torch

def calculate_centroids(x, prob):
    """Return the probability-weighted mean of ``x`` for every cluster.

    ``prob.T @ x`` gives, per cluster, the sum of samples weighted by their
    membership; dividing by the column sums of ``prob`` turns it into a mean.
    A cluster whose column of ``prob`` sums to zero yields inf/nan entries.

    x    : tensor (n_samples, n_features) -- presumably; confirm with callers
    prob : tensor (n_samples, n_classes) of membership weights
    returns tensor (n_classes, n_features)
    """
    weighted_sum = torch.matmul(prob.T, x)
    cluster_mass = prob.sum(dim=0).reshape(-1, 1)
    return weighted_sum / cluster_mass

def kmeans_loss(x, prob, n_classes):
    """Soft k-means objective averaged over clusters.

    For each of the first ``n_classes`` clusters, sums the (non-squared)
    Euclidean distance of every sample to that cluster's centroid, weighted
    by the sample's membership ``prob[:, k]``; returns the mean of those
    per-cluster sums.

    The loop handles one cluster at a time, so only an (n_samples, n_features)
    difference tensor is alive at once rather than a broadcast
    (n_samples, n_classes, n_features) one.
    """
    centroids = calculate_centroids(x, prob)
    per_class = []
    for k in range(n_classes):
        distances = (x - centroids[k]).norm(dim=1)
        per_class.append(distances.dot(prob[:, k]).view(1, -1))
    return torch.cat(per_class, dim=0).mean()

def cosine_similarity_loss(x, prob, n_classes):
    """Mean cosine similarity between every sample and every centroid.

    Samples and centroids are L2-normalised row-wise, then the full
    (n_samples, n_classes) similarity matrix is averaged. Rows with zero
    norm produce nan.

    NOTE(review): ``n_classes`` is unused, and the average is not weighted by
    ``prob`` (which only enters through the centroids). The value is a
    similarity, so larger means closer -- confirm this is the intended
    objective and sign before minimising it.
    """
    centroids = calculate_centroids(x, prob)
    x_unit = x / x.norm(dim=1).unsqueeze(1)
    centroid_unit = centroids / centroids.norm(dim=1).unsqueeze(1)
    similarity = x_unit.matmul(centroid_unit.T)
    return similarity.mean()

def bench_clustering(name, data, predictions, labels, loss, execution_time,
                     *, silhouette_sample_size=None):
    """Benchmark to evaluate the clustering methods and print one result row.

    Parameters
    ----------
    name : str
        Name given to the strategy. It will be used to show the results in a
        table.
    data : ndarray of shape (n_samples, n_features)
        The data to cluster. Only used when ``silhouette_sample_size`` is
        given.
    predictions : ndarray of shape (n_samples,)
        The predicted labels used to compute the clustering metrics.
    labels : ndarray of shape (n_samples,)
        The labels used to compute the clustering metrics which requires some
        supervision.
    loss : float-like
        Final value of the clustering objective. Any value accepted by the
        ``{:.3f}`` format spec (e.g. a Python float or a 0-dim tensor).
    execution_time : float-like
        Execution time of the run; printed with three decimals.
    silhouette_sample_size : int, optional
        When given, also compute the Euclidean silhouette score on a random
        subsample of this many points. It is printed after the label-based
        metrics. Off by default because the silhouette score requires the full
        dataset.

    Returns
    -------
    list
        The printed row: ``[name, homogeneity, completeness, v_measure,
        adjusted_rand, adjusted_mutual_info, (silhouette,) loss,
        execution_time]``.
    """
    results = [name]

    # Define the metrics which require only the true labels and estimator
    # labels
    clustering_metrics = [
        metrics.homogeneity_score,
        metrics.completeness_score,
        metrics.v_measure_score,
        metrics.adjusted_rand_score,
        metrics.adjusted_mutual_info_score,
    ]
    results += [m(labels, predictions) for m in clustering_metrics]

    # The silhouette score requires the full dataset, so it is opt-in.
    if silhouette_sample_size is not None:
        results.append(
            metrics.silhouette_score(
                data,
                predictions,
                metric="euclidean",
                sample_size=silhouette_sample_size,
            )
        )

    # loss and time
    results.append(loss)
    results.append(execution_time)

    # Show the results: the name column, then one 3-decimal column per value.
    # Deriving the format from the row length keeps it in sync with the
    # optional silhouette column; with the default arguments it is exactly
    # "{:4s}" followed by seven "\t{:.3f}" fields.
    formatter_result = "{:4s}" + "\t{:.3f}" * (len(results) - 1)
    print(formatter_result.format(*results))
    return results
    
def calculate_matmul_n_times(n_components, mat_a, mat_b):
    """
    Calculate the product mat_a @ mat_b one component at a time.

    Gives the same values as ``torch.matmul(mat_a, mat_b)`` with mat_b
    broadcast over the first dimension. It computes one (n, d) @ (d, d)
    product per component in a loop, bypassing torch.matmul to reduce memory
    footprint.
    args:
        n_components:   int, number of components k to process
        mat_a:      torch.Tensor (n, k, 1, d)
        mat_b:      torch.Tensor (1, k, d, d)
    returns:
        torch.Tensor (n, k, 1, d), same dtype and device as mat_a
    """
    # zeros_like keeps mat_a's dtype and device. torch.zeros(shape) would use
    # the default dtype (silently downcasting float64 results) and would first
    # allocate on the CPU.
    res = torch.zeros_like(mat_a)

    for i in range(n_components):
        # Index the singleton dims explicitly instead of calling squeeze().
        # A bare squeeze() turns a (1, 1) matrix into a scalar when d == 1,
        # which makes mm() fail.
        mat_a_i = mat_a[:, i, 0, :]  # (n, d)
        mat_b_i = mat_b[0, i]        # (d, d)
        res[:, i, 0, :] = mat_a_i.mm(mat_b_i)

    return res


def calculate_matmul(mat_a, mat_b):
    """
    Calculate batched row-vector times column-vector products.

    Computes the same values as ``torch.matmul(mat_a, mat_b)`` with an
    element-wise product and a sum over the last axis, bypassing
    torch.matmul to reduce memory footprint.
    args:
        mat_a:      torch.Tensor (..., 1, d), e.g. (n, k, 1, d)
        mat_b:      torch.Tensor (..., d, 1), e.g. (n, k, d, 1)
                    leading dims must be equal or broadcastable
    returns:
        torch.Tensor (..., 1). Note this has one dim fewer than
        torch.matmul's (..., 1, 1) result.
    """
    assert mat_a.shape[-2] == 1 and mat_b.shape[-1] == 1, (
        "expected mat_a of shape (..., 1, d) and mat_b of shape (..., d, 1)"
    )
    # Reduce over the last axis. The previous dim=2 was the last axis only
    # for 4-D inputs: it raised for 3-D inputs and summed the wrong axis for
    # inputs with 5 or more dims.
    return torch.sum(mat_a.squeeze(-2) * mat_b.squeeze(-1), dim=-1, keepdim=True)