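'''
This module defines the clustering evaluation functions (these evaluations are not presented in the paper):
a k-means model is fitted on the training representations, and its cluster assignments on the test
representations are scored against the test labels with the adjusted Rand index (ARI) and the
adjusted mutual information (AMI).
'''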

import numpy as np
from sklearn.cluster import KMeans
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score


def fit_kmeans(features, y, MAX_SAMPLES=10000):
    '''
    Fit a k-means model with one cluster per distinct label in ``y``.

    On small inputs (fewer than 50 samples, or fewer than 5 samples per class
    on average) a single KMeans is fitted on all of ``features``. Otherwise a
    5-fold grid search over the initialisation scheme and the k-means
    algorithm is run. No ``scoring`` is given to GridSearchCV, so it falls back
    to the estimator's own ``score`` method (for KMeans, the negative inertia
    on the held-out fold, per the scikit-learn docs). ``y`` is used here to
    count classes and to stratify the optional subsampling.

    Args:
        features: array of shape (n_samples, n_features).
        y: array of labels, one per row of ``features``.
        MAX_SAMPLES: when there are more samples than this, the grid search
            runs on a random subsample of exactly this size (stratified by
            ``y`` whenever stratification is feasible).

    Returns:
        A fitted ``KMeans`` estimator.

    Raises:
        ValueError: if ``y`` contains no labels.
    '''
    class_counts = np.unique(y, return_counts=True)[1]
    nb_classes = class_counts.shape[0]
    if nb_classes == 0:
        # Without this guard the ratio below raises a ZeroDivisionError.
        raise ValueError('fit_kmeans requires at least one label in y')
    train_size = features.shape[0]

    # Seeded like the grid search below so the small-input path is reproducible too.
    kmeans = KMeans(n_clusters=nb_classes, n_init='auto', random_state=131)
    if train_size // nb_classes < 5 or train_size < 50:
        # Too few samples for a meaningful 5-fold grid search.
        return kmeans.fit(features)

    grid_search = GridSearchCV(
        kmeans, {
            'n_clusters': [nb_classes],
            'init': ['k-means++', 'random'],
            'algorithm': ['lloyd', 'elkan'],
            'max_iter': [1000],
            'tol': [1e-4],
            'random_state': [131]
        },
        cv=5, n_jobs=5
    )

    # If the training set is too large, subsample MAX_SAMPLES examples.
    if train_size > MAX_SAMPLES:
        # A stratified split raises ValueError when some class has a single
        # member or when either side of the split is smaller than the number
        # of classes; fall back to a plain (seeded) random subsample then.
        n_held_out = train_size - MAX_SAMPLES
        can_stratify = (
            class_counts.min() >= 2
            and MAX_SAMPLES >= nb_classes
            and n_held_out >= nb_classes
        )
        features, _, y, _ = train_test_split(
            features, y,
            train_size=MAX_SAMPLES,
            stratify=y if can_stratify else None,
            random_state=131,
        )

    grid_search.fit(features, y)
    return grid_search.best_estimator_


def eval_clustering(model, train_data, train_labels, test_data, test_labels):
    '''
    Evaluate a model's representations with k-means clustering.

    The training set is encoded and clustered with ``fit_kmeans``. The test
    set is then encoded and assigned to the learned clusters, and the
    assignment is scored against the ground-truth test labels.

    Args:
        model: object exposing ``encode(data, **model.encode_args)``. The
            result must support ``.detach().cpu().numpy()`` (presumably a
            torch tensor — confirm against the model implementations).
        train_data, test_data: inputs passed unchanged to ``model.encode``.
        train_labels, test_labels: label arrays, either both 1-D or both 2-D.
            For 2-D labels, the first two axes of the labels and of the
            matching representations are flattened into one, so every
            (row, column) label is treated as a separate sample.

    Returns:
        tuple: (adjusted Rand index, adjusted mutual information) on the test set.

    Raises:
        ValueError: if a label array is neither 1-D nor 2-D, or if the train
            and test labels have different ranks.
    '''
    # Explicit checks instead of ``assert`` so validation survives ``python -O``;
    # test_labels is checked too because it is reshaped alongside train_labels.
    for name, labels in (('train_labels', train_labels), ('test_labels', test_labels)):
        if labels.ndim not in (1, 2):
            raise ValueError(f'{name} must be 1-D or 2-D, got {labels.ndim}-D')
    if train_labels.ndim != test_labels.ndim:
        raise ValueError(
            f'train_labels and test_labels must have the same rank, '
            f'got {train_labels.ndim}-D and {test_labels.ndim}-D'
        )

    train_repr = model.encode(train_data, **model.encode_args).detach().cpu().numpy()
    test_repr = model.encode(test_data, **model.encode_args).detach().cpu().numpy()

    def merge_dim01(array):
        # Collapse the first two axes into one: (A, B, ...) -> (A*B, ...).
        return array.reshape(array.shape[0] * array.shape[1], *array.shape[2:])

    if train_labels.ndim == 2:
        train_repr = merge_dim01(train_repr)
        train_labels = merge_dim01(train_labels)
        test_repr = merge_dim01(test_repr)
        test_labels = merge_dim01(test_labels)

    kmeans = fit_kmeans(train_repr, train_labels)
    test_pred = kmeans.predict(test_repr)

    test_ari = adjusted_rand_score(test_labels, test_pred)
    test_ami = adjusted_mutual_info_score(test_labels, test_pred)

    return test_ari, test_ami