"""Clustering models (k-means, Gaussian mixture, overclustering) used by GEORGE."""

# https://github.com/HazyResearch/hidden-stratification/blob/master/stratification/cluster/models/cluster.py

try:
    from libKMCUDA import kmeans_cuda

    _LIBKMCUDA_FOUND = True
except ModuleNotFoundError:
    _LIBKMCUDA_FOUND = False

from functools import partial
import functools
import logging
import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from collections import Counter


# Public API of this module. ``KMeans`` and ``GaussianMixture`` are re-exported from
# scikit-learn. The helper functions ``get_cluster_sils``, ``compute_group_sizes``,
# ``get_cluster_mean_loss`` and ``get_cluster_composition`` are not listed, so a
# star-import of this module will not bring them in.
__all__ = [
    "KMeans",
    "GaussianMixture",
    "FastKMeans",
    "AutoKMixtureModel",
    "OverclusterModel",
    "DummyClusterer",
    "get_k_from_model",
]


def get_cluster_sils(data, pred_labels, compute_sil=True, cuda=False):
    """Compute per-cluster and global mean silhouette scores.

    Args:
        data: samples to score, passed to scikit-learn's ``silhouette_samples``.
        pred_labels: per-sample cluster labels (array-like, same length as ``data``).
        compute_sil: if False, skip the silhouette computation and report 0.0 for
            every cluster and for the global score.
        cuda: the CUDA silhouette implementation in this module is commented out,
            so when True a warning is logged and the CPU implementation is used.

    Returns:
        ``(sils_by_cluster, sil_global)``: a dict mapping each ``int`` label (in
        ascending order) to the mean silhouette of its samples, and the mean
        silhouette over all samples as a ``float``.
    """
    # With a plain list, ``pred_labels == label`` would be a single bool and the
    # mask below would select the wrong rows, so normalise to an array first.
    pred_labels = np.asarray(pred_labels)
    if compute_sil:
        # No module-level ``silhouette_samples`` exists (its CUDA version is
        # commented out), so use scikit-learn's implementation. The import is lazy
        # so the compute_sil=False path does not need it.
        from sklearn.metrics import silhouette_samples as _sk_silhouette_samples

        if cuda:
            logging.getLogger("harness.cluster").warning(
                "CUDA silhouette computation is unavailable; using scikit-learn on CPU"
            )
        sil_samples = np.asarray(_sk_silhouette_samples(data, pred_labels))
    else:
        sil_samples = np.zeros(len(data))

    sils_by_cluster = {
        int(label): float(np.mean(sil_samples[pred_labels == label]))
        for label in np.unique(pred_labels)  # np.unique returns sorted labels
    }
    sil_global = float(np.mean(sil_samples))
    return sils_by_cluster, sil_global


def compute_group_sizes(labels):
    """Count how many samples carry each label.

    Returns:
        dict mapping each label to its count, both as plain ``int``, in ascending
        label order.
    """
    unique_labels, label_counts = np.unique(labels, return_counts=True)
    return {
        int(label): int(count) for label, count in zip(unique_labels, label_counts)
    }


class DummyClusterer:
    """Trivial clusterer that assigns every sample to cluster 0."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted (and ignored) so this can stand in for
        # real clusterers constructed with arbitrary options.
        self.n_components = 1

    def fit(self, X):
        """No-op fit; returns ``self``."""
        return self

    def predict(self, X):
        """Return an ``int32`` array of zeros with one entry per sample in ``X``."""
        return np.full(len(X), 0, dtype=np.int32)


class FastKMeans:
    """K-means whose centroids are computed on the GPU by libKMCUDA.

    ``kmeans_cuda`` is run once per ``fit``; its centroids are copied into an
    internal scikit-learn ``KMeans`` object, which then serves ``predict`` and
    ``transform``. ``n_init`` is not honoured: values above 1 only log a warning.
    """

    def __init__(
        self, n_clusters, random_state=0, init="k-means++", n_init=10, verbose=False
    ):
        if n_init > 1:
            logging.warning("n_init unsupported for GPU K-Means")
        self.k = n_clusters
        self.init = init
        self.seed = random_state
        self.verbose = verbose
        self.kmeans_obj = KMeans(n_clusters=n_clusters)

    def fit(self, X):
        """Run GPU k-means on ``X`` and load the centroids into the sklearn model.

        Returns the internal ``KMeans`` object (not ``self``).
        """
        logging.info("Using GPU-accelerated K-Means...")
        gpu_result = kmeans_cuda(
            X.astype(np.float32), clusters=self.k, seed=self.seed, init=self.init
        )
        centroids = gpu_result[0].astype(np.float32)
        self.cluster_centers_ = centroids
        self.kmeans_obj.cluster_centers_ = centroids
        # Private, version-dependent sklearn hook; presumably initialises internal
        # state that predict()/transform() rely on. Only called when present.
        check_params = getattr(self.kmeans_obj, "_check_params", None)
        if check_params is not None:
            check_params(np.zeros_like(X))  # properly initialize
        return self.kmeans_obj

    def fit_predict(self, X):
        """Fit on ``X`` and return its cluster labels."""
        self.fit(X)
        return self.predict(X)

    def predict(self, X):
        """Return cluster labels from the internal ``KMeans`` for ``X`` as float32."""
        return self.kmeans_obj.predict(self._as_float32(X))

    def transform(self, X):
        """Delegate to the internal ``KMeans.transform`` with ``X`` cast to float32."""
        return self.kmeans_obj.transform(self._as_float32(X))

    @staticmethod
    def _as_float32(X):
        # Centroids are stored as float32 (see fit), so inputs are cast to match.
        return X.astype(np.float32)


class AutoKMixtureModel:
    """K-means or Gaussian-mixture clustering with optional automatic choice of k.

    With ``search=True``, every k from 2 to ``max_k`` is tried. The model with the
    highest mean per-cluster silhouette score is kept; ties go to the larger k.
    Without search, or when ``max_k < 2`` leaves nothing to search over, a single
    model with ``k = max_k`` is fit.
    """

    def __init__(
        self,
        cluster_method,
        max_k,
        n_init=3,
        seed=None,
        sil_cuda=False,
        verbose=0,
        search=True,
    ):
        """
        Args:
            cluster_method: ``"kmeans"`` or ``"gmm"``.
            max_k: largest number of clusters to consider.
            n_init: forwarded to the estimator as ``n_init``.
            seed: forwarded to the estimator as ``random_state``.
            sil_cuda: request GPU work. It is passed to ``get_cluster_sils`` as
                ``cuda``, and for ``"kmeans"`` it selects :class:`FastKMeans` when
                libKMCUDA was importable.
            verbose: forwarded to the estimator as ``verbose``.
            search: whether to search over k (see the class docstring).

        Raises:
            ValueError: for an unknown ``cluster_method``.
        """
        if cluster_method == "kmeans":
            cluster_cls = FastKMeans if (sil_cuda and _LIBKMCUDA_FOUND) else KMeans
            k_name = "n_clusters"
        elif cluster_method == "gmm":
            cluster_cls = GaussianMixture
            k_name = "n_components"
        else:
            raise ValueError("Unsupported clustering method")

        self.cluster_cls = cluster_cls
        self.k_name = k_name  # constructor keyword that sets k on cluster_cls
        self.search = search
        self.max_k = max_k
        self.n_init = n_init
        self.seed = seed
        self.sil_cuda = sil_cuda
        self.verbose = verbose

    def gen_inner_cluster_obj(self, k):
        """Return a new, unfitted estimator of ``cluster_cls`` with ``k`` clusters."""
        return self.cluster_cls(
            **{self.k_name: k},
            n_init=self.n_init,
            random_state=self.seed,
            verbose=self.verbose,
        )

    def fit(self, activ):
        """Fit the clustering model(s) on ``activ`` and keep the best one.

        Sets ``best_k``, ``n_clusters``, ``best_score`` (0 when no search is done)
        and ``cluster_obj``. Returns ``self``.

        Raises:
            ValueError: if ``max_k < 1``.
        """
        logger = logging.getLogger("harness.cluster")
        if self.max_k < 1:
            raise ValueError(f"max_k must be at least 1, got {self.max_k}")
        # The search starts at k = 2. For max_k < 2 that range would be empty and
        # best_k would never be assigned (UnboundLocalError), so fit k = max_k.
        k_min = min(2, self.max_k) if self.search else self.max_k
        search = self.search and k_min != self.max_k

        best_score, best_model, best_k = -2, None, None
        for k in range(k_min, self.max_k + 1):
            logger.info(f"Clustering into {k} groups...")
            cluster_obj = self.gen_inner_cluster_obj(k)
            pred_labels = cluster_obj.fit_predict(activ)
            logger.info("Clustering done, computing score...")
            if not search:
                # Only one candidate (k = max_k); keep it without scoring.
                best_score, best_model, best_k = 0, cluster_obj, self.max_k
                continue
            local_sils, _ = get_cluster_sils(
                activ, pred_labels, compute_sil=True, cuda=self.sil_cuda
            )
            # Score = mean of the per-cluster mean silhouettes.
            clustering_score = np.mean(list(local_sils.values()))
            logger.info(f"k = {k} score: {clustering_score}")
            if clustering_score >= best_score:  # ">=" lets larger k win ties
                logger.info(
                    f"Best model found at k = {k} with score {clustering_score:.3f}"
                )
                best_score = clustering_score
                best_model = cluster_obj
                best_k = k

        self.best_k = best_k
        self.n_clusters = best_k
        self.best_score = best_score
        self.cluster_obj = best_model
        return self

    def predict(self, activ):
        """Return cluster labels from the selected model."""
        return self.cluster_obj.predict(activ)

    def fit_predict(self, activ):
        """Fit on ``activ`` and return its cluster labels."""
        self.fit(activ)
        return self.predict(activ)

    def predict_proba(self, activ):
        """Delegate to the selected model's ``predict_proba``."""
        return self.cluster_obj.predict_proba(activ)

    def score(self, activ):
        """Delegate to the selected model's ``score``."""
        return self.cluster_obj.score(activ)


class OverclusterModel:
    """Split each base cluster into sub-clusters and keep only the useful ones.

    A base :class:`AutoKMixtureModel` is fit first. Each of its clusters is then
    split into ``oc_fac`` sub-clusters ("overclusters"). An overcluster is kept
    only if its silhouette improves, its mean loss is at least that of its parent
    cluster, and it holds enough training and validation points. The remaining
    overclusters of a parent are merged back into one cluster.

    Overcluster ids are ``oc_fac * pos + sub_label``, where ``pos`` is the parent's
    position in ``self.pred_vals`` (the sorted base labels seen during ``fit``).
    """

    def __init__(
        self,
        cluster_method,
        max_k,
        oc_fac,
        n_init=3,
        search=True,
        sil_threshold=0.0,
        seed=None,
        sil_cuda=False,
        verbose=0,
        sz_threshold_pct=0.005,
        sz_threshold_abs=25,
    ):
        """
        Args:
            cluster_method, max_k, n_init, search, seed, sil_cuda, verbose:
                forwarded to the base :class:`AutoKMixtureModel`.
            oc_fac: number of sub-clusters each base cluster is split into.
            sil_threshold: minimum mean silhouette an overcluster must exceed.
            sz_threshold_pct: minimum training size as a fraction of all
                training points. The effective threshold is the max of this and
                ``sz_threshold_abs``.
            sz_threshold_abs: minimum absolute training and validation size.
        """
        self.base_model = AutoKMixtureModel(
            cluster_method, max_k, n_init, seed, sil_cuda, verbose, search
        )
        self.oc_fac = oc_fac
        self.sil_threshold = sil_threshold
        self.sz_threshold_pct = sz_threshold_pct
        self.sz_threshold_abs = sz_threshold_abs
        # fit() requires val_activ and losses; this flag presumably advertises
        # that to callers (confirm against call sites).
        self.requires_extra_info = True

    def _sample_silhouettes(self, activ, labels):
        """Return per-sample silhouette scores of ``activ`` under ``labels``.

        The module's CUDA ``silhouette_samples`` is commented out, so
        scikit-learn's CPU implementation is always used. A warning is logged when
        ``sil_cuda`` is set.
        """
        from sklearn.metrics import silhouette_samples as _sk_silhouette_samples

        if self.sil_cuda:
            logging.getLogger("harness.cluster").warning(
                "CUDA silhouette computation is unavailable; using scikit-learn on CPU"
            )
        return np.asarray(_sk_silhouette_samples(activ, labels))

    def get_oc_predictions(self, activ, val_activ, orig_preds, val_orig_preds):
        """Fit one sub-clusterer per base cluster and return overcluster ids.

        The fitted sub-clusterers are stored in ``self.cluster_objs``, in the same
        order as ``self.pred_vals``.

        Returns:
            ``(oc_preds, val_oc_preds)``: integer overcluster ids for the training
            and validation points.
        """
        self.cluster_objs = []
        oc_preds = np.zeros(len(activ), dtype=int)
        val_oc_preds = np.zeros(len(val_activ), dtype=int)

        for pos, label in enumerate(self.pred_vals):
            offset = self.oc_fac * pos
            mask = orig_preds == label
            sub_activ = activ[mask]
            cluster_obj = self.base_model.gen_inner_cluster_obj(self.oc_fac).fit(
                sub_activ
            )
            self.cluster_objs.append(cluster_obj)
            oc_preds[mask] = cluster_obj.predict(sub_activ) + offset

            # A base cluster may have no validation points; estimators reject
            # 0-sample input, so only predict when there is something to predict.
            val_mask = val_orig_preds == label
            if np.any(val_mask):
                val_oc_preds[val_mask] = (
                    cluster_obj.predict(val_activ[val_mask]) + offset
                )
        return oc_preds, val_oc_preds

    def filter_overclusters(self, activ, losses, orig_preds, oc_preds, val_oc_preds):
        """Return the ids of the overclusters worth keeping.

        An overcluster is kept if all of the following hold:
        - its points' mean silhouette under the overclustering is higher than under
          the base clustering, and higher than ``sil_threshold``;
        - its mean loss is at least its parent cluster's mean loss;
        - it has enough training and validation points.
        """
        losses = np.asarray(losses)
        num_oc = np.amax(oc_preds) + 1
        # Per-sample silhouettes before and after overclustering.
        orig_sample_sils = self._sample_silhouettes(activ, orig_preds)
        new_sample_sils = self._sample_silhouettes(activ, oc_preds)
        # Indexed by parent position, matching the ``oc_fac * pos`` id offsets.
        orig_losses = [np.mean(losses[orig_preds == label]) for label in self.pred_vals]

        oc_orig_sils = [np.mean(orig_sample_sils[oc_preds == i]) for i in range(num_oc)]
        oc_new_sils = [np.mean(new_sample_sils[oc_preds == i]) for i in range(num_oc)]
        new_losses = [np.mean(losses[oc_preds == i]) for i in range(num_oc)]

        # Drop tiny clusters: they lead to unreliable optimization and unreliable
        # validation estimates. minlength keeps one entry per overcluster even when
        # the highest ids never occur (otherwise indexing below raises IndexError).
        oc_counts = np.bincount(oc_preds, minlength=num_oc)
        val_oc_counts = np.bincount(val_oc_preds, minlength=num_oc)
        tr_sz_threshold = max(len(activ) * self.sz_threshold_pct, self.sz_threshold_abs)
        val_sz_threshold = self.sz_threshold_abs

        return [
            i
            for i in range(num_oc)
            if oc_new_sils[i] > max(oc_orig_sils[i], self.sil_threshold)
            and new_losses[i] >= orig_losses[i // self.oc_fac]
            and oc_counts[i] >= tr_sz_threshold
            and val_oc_counts[i] >= val_sz_threshold
        ]

    def create_label_map(self, num_orig_preds, oc_to_keep, oc_preds):
        """Map every overcluster id to a final, contiguous cluster label.

        Within each base cluster, every kept overcluster gets its own new label.
        Non-kept overclusters of the same base cluster share one merged label,
        which is allocated only if at least one overcluster was not kept.
        ``oc_preds`` is unused; it is kept for interface compatibility.
        """
        kept = set(oc_to_keep)
        label_map = {}
        next_label = 0
        for i in range(num_orig_preds):
            oc_ids = range(i * self.oc_fac, (i + 1) * self.oc_fac)
            # If every overcluster is kept, the original cluster disappears.
            if not all(idx in kept for idx in oc_ids):
                base_label = next_label
                next_label += 1
            for idx in oc_ids:
                if idx in kept:
                    label_map[idx] = next_label
                    next_label += 1
                else:
                    # Reachable only when base_label was allocated above.
                    label_map[idx] = base_label
        return label_map

    def fit(self, activ, val_activ=None, losses=None):
        """Fit the base model, overcluster it, and decide which overclusters to keep.

        Sets ``pred_vals``, ``cluster_objs``, ``label_map`` and ``n_clusters``.
        Returns ``self``.

        Raises:
            ValueError: if ``val_activ`` or ``losses`` is missing.
        """
        if val_activ is None or losses is None:
            raise ValueError("Must provide losses and val set activations")
        logger = logging.getLogger("harness.cluster")
        logger.info("Fitting base model...")
        orig_preds = self.base_model.fit_predict(activ)
        self.pred_vals = sorted(np.unique(orig_preds))
        num_orig_preds = len(self.pred_vals)
        losses = np.array(losses)
        val_orig_preds = self.base_model.predict(val_activ)

        logger.info("Fitting overclustering model...")
        oc_preds, val_oc_preds = self.get_oc_predictions(
            activ, val_activ, orig_preds, val_orig_preds
        )
        oc_to_keep = self.filter_overclusters(
            activ, losses, orig_preds, oc_preds, val_oc_preds
        )
        self.label_map = self.create_label_map(num_orig_preds, oc_to_keep, oc_preds)

        # Final number of output clusters after merging.
        self.n_clusters = max(self.label_map.values()) + 1
        logger.info(f"Final number of clusters: {self.n_clusters}")
        return self

    def predict(self, activ):
        """Return final (merged) cluster labels for ``activ``."""
        base_preds = self.base_model.predict(activ)
        # Overcluster ids; points whose base label was unseen in fit stay at 0.
        oc_preds = np.zeros(len(activ), dtype=int)
        for pos, (label, cluster_obj) in enumerate(
            zip(self.pred_vals, self.cluster_objs)
        ):
            mask = base_preds == label
            if np.any(mask):  # estimators reject 0-sample input
                oc_preds[mask] = cluster_obj.predict(activ[mask]) + self.oc_fac * pos

        # Merge overclusters according to the label map.
        new_preds = np.zeros(len(activ), dtype=int)
        for oc_id in range(len(self.pred_vals) * self.oc_fac):
            new_preds[oc_preds == oc_id] = self.label_map[oc_id]
        return new_preds

    @property
    def sil_cuda(self):
        return self.base_model.sil_cuda

    @property
    def n_init(self):
        return self.base_model.n_init

    @property
    def seed(self):
        return self.base_model.seed


##############################################################
# https://github.com/HazyResearch/hidden-stratification/blob/master/stratification/cluster/utils.py
def get_k_from_model(model):
    """Return the number of clusters/components reported by a clustering model.

    ``n_clusters`` is checked first, then ``n_components``.

    Raises:
        NotImplementedError: if the model exposes neither attribute. The message
            lists the model's instance attributes to aid debugging.
    """
    if hasattr(model, "n_clusters"):
        return model.n_clusters
    if hasattr(model, "n_components"):
        return model.n_components
    # Objects without a __dict__ (e.g. __slots__ classes) must not turn this into
    # an AttributeError that masks the real problem.
    attributes = list(getattr(model, "__dict__", {}).keys())
    raise NotImplementedError(
        f"model {type(model)} K not found. " f"model attributes:\n{attributes}"
    )


def get_cluster_mean_loss(sample_losses, assignments):
    """Return the mean loss of each cluster.

    Args:
        sample_losses: per-sample losses (array-like, same length as ``assignments``).
        assignments: per-sample cluster ids (array-like).

    Returns:
        dict mapping ``str(cluster_id)`` to the ``float`` mean loss of that
        cluster's samples, in ascending cluster-id order.
    """
    # With plain lists, ``assignments == c`` is a single bool and the mask would
    # select the wrong elements, so convert both inputs to arrays.
    sample_losses = np.asarray(sample_losses)
    assignments = np.asarray(assignments)
    return {
        str(c): float(np.mean(sample_losses[assignments == c]))
        for c in np.unique(assignments)
    }


def get_cluster_composition(superclasses, assignments):
    """Count how many samples of each superclass fall into each cluster.

    Args:
        superclasses: per-sample superclass labels (array-like).
        assignments: per-sample cluster ids (array-like, same length).

    Returns:
        dict mapping ``str(cluster_id)`` to a dict that maps ``str(superclass)``
        to its sample count in that cluster. Every superclass appears, with 0 when
        absent.
    """
    # With plain lists, ``assignments == c`` is a single bool and indexing would
    # pick a single element (and Counter would then fail), so use arrays.
    superclasses = np.asarray(superclasses)
    assignments = np.asarray(assignments)
    all_superclasses = np.unique(superclasses)
    compositions = {}
    for c in np.unique(assignments):
        counts = Counter(superclasses[assignments == c])
        compositions[str(c)] = {str(s): counts.get(s, 0) for s in all_superclasses}
    return compositions


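# ##############################################################
# NOTE: everything below is commented out because it depends on torch and on private
# scikit-learn internals (sklearn.metrics.cluster._unsupervised). As a result, this
# module does not define a module-level ``silhouette_samples``.
# """The functions in this file are adapted from scikit-learn
# (https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/cluster/_unsupervised.py)
# to use CUDA for Silhouette score computation."""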

# import numpy as np
# from sklearn.utils import gen_batches, get_chunk_n_rows
# from sklearn.metrics.cluster._unsupervised import *
# from sklearn.metrics import silhouette_samples as s_sil
# import torch


# def silhouette_samples(X, labels, verbose=False, cuda=False):
#     if not cuda:
#         return s_sil(X, labels)
#     X, labels = check_X_y(X, labels, accept_sparse=["csc", "csr"])

#     le = LabelEncoder()
#     labels = le.fit_transform(labels)
#     n_samples = len(labels)
#     label_freqs = np.bincount(labels)
#     check_number_of_labels(len(le.classes_), n_samples)

#     reduce_func = functools.partial(
#         _silhouette_reduce, labels=labels, label_freqs=label_freqs
#     )
#     results = zip(
#         *pairwise_distances_chunked_cuda(X, reduce_func=reduce_func, verbose=verbose)
#     )
#     intra_clust_dists, inter_clust_dists = results
#     intra_clust_dists = np.concatenate(intra_clust_dists)
#     inter_clust_dists = np.concatenate(inter_clust_dists)

#     denom = (label_freqs - 1).take(labels, mode="clip")
#     with np.errstate(divide="ignore", invalid="ignore"):
#         intra_clust_dists /= denom

#     sil_samples = inter_clust_dists - intra_clust_dists
#     with np.errstate(divide="ignore", invalid="ignore"):
#         sil_samples /= np.maximum(intra_clust_dists, inter_clust_dists)
#     # nan values are for clusters of size 1, and should be 0
#     return np.nan_to_num(sil_samples)


# def _silhouette_reduce(D_chunk, start, labels, label_freqs):
#     """Accumulate silhouette statistics for vertical chunk of X
#     Parameters
#     ----------
#     D_chunk : shape (n_chunk_samples, n_samples)
#         precomputed distances for a chunk
#     start : int
#         first index in chunk
#     labels : array, shape (n_samples,)
#         corresponding cluster labels, encoded as {0, ..., n_clusters-1}
#     label_freqs : array
#         distribution of cluster labels in ``labels``
#     """
#     # accumulate distances from each sample to each cluster
#     clust_dists = np.zeros((len(D_chunk), len(label_freqs)), dtype=D_chunk.dtype)
#     for i in range(len(D_chunk)):
#         clust_dists[i] += np.bincount(
#             labels, weights=D_chunk[i], minlength=len(label_freqs)
#         )

#     # intra_index selects intra-cluster distances within clust_dists
#     intra_index = (np.arange(len(D_chunk)), labels[start : start + len(D_chunk)])
#     # intra_clust_dists are averaged over cluster size outside this function
#     intra_clust_dists = clust_dists[intra_index]
#     # of the remaining distances we normalise and extract the minimum
#     clust_dists[intra_index] = np.inf
#     clust_dists /= label_freqs
#     inter_clust_dists = clust_dists.min(axis=1)
#     return intra_clust_dists, inter_clust_dists


# def _check_chunk_size(reduced, chunk_size):
#     """Checks chunk is a sequence of expected size or a tuple of same"""
#     if reduced is None:
#         return
#     is_tuple = isinstance(reduced, tuple)
#     if not is_tuple:
#         reduced = (reduced,)
#     if any(isinstance(r, tuple) or not hasattr(r, "__iter__") for r in reduced):
#         raise TypeError(
#             "reduce_func returned %r. "
#             "Expected sequence(s) of length %d."
#             % (reduced if is_tuple else reduced[0], chunk_size)
#         )
#     if any(len(r) != chunk_size for r in reduced):
#         actual_size = tuple(len(r) for r in reduced)
#         raise ValueError(
#             "reduce_func returned object of length %s. "
#             "Expected same length as input: %d."
#             % (actual_size if is_tuple else actual_size[0], chunk_size)
#         )


# def pairwise_distances_chunked_cuda(X, reduce_func=None, verbose=False):
#     """Generate a distance matrix chunk by chunk with optional reduction
#     In cases where not all of a pairwise distance matrix needs to be stored at
#     once, this is used to calculate pairwise distances in
#     ``working_memory``-sized chunks.  If ``reduce_func`` is given, it is run
#     on each chunk and its return values are concatenated into lists, arrays
#     or sparse matrices.
#     Parameters
#     ----------
#     X : array [n_samples_a, n_samples_a] if metric == "precomputed", or,
#         [n_samples_a, n_features] otherwise
#         Array of pairwise distances between samples, or a feature array.
#     Y : array [n_samples_b, n_features], optional
#         An optional second feature array. Only allowed if
#         metric != "precomputed".
#     reduce_func : callable, optional
#         The function which is applied on each chunk of the distance matrix,
#         reducing it to needed values.  ``reduce_func(D_chunk, start)``
#         is called repeatedly, where ``D_chunk`` is a contiguous vertical
#         slice of the pairwise distance matrix, starting at row ``start``.
#         It should return one of: None; an array, a list, or a sparse matrix
#         of length ``D_chunk.shape[0]``; or a tuple of such objects. Returning
#         None is useful for in-place operations, rather than reductions.
#         If None, pairwise_distances_chunked returns a generator of vertical
#         chunks of the distance matrix.
#     metric : string, or callable
#         The metric to use when calculating distance between instances in a
#         feature array. If metric is a string, it must be one of the options
#         allowed by scipy.spatial.distance.pdist for its metric parameter, or
#         a metric listed in pairwise.PAIRWISE_DISTANCE_FUNCTIONS.
#         If metric is "precomputed", X is assumed to be a distance matrix.
#         Alternatively, if metric is a callable function, it is called on each
#         pair of instances (rows) and the resulting value recorded. The callable
#         should take two arrays from X as input and return a value indicating
#         the distance between them.
#     n_jobs : int or None, optional (default=None)
#         The number of jobs to use for the computation. This works by breaking
#         down the pairwise matrix into n_jobs even slices and computing them in
#         parallel.
#         ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
#         ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
#         for more details.
#     working_memory : int, optional
#         The sought maximum memory for temporary distance matrix chunks.
#         When None (default), the value of
#         ``sklearn.get_config()['working_memory']`` is used.
#     `**kwds` : optional keyword parameters
#         Any further parameters are passed directly to the distance function.
#         If using a scipy.spatial.distance metric, the parameters are still
#         metric dependent. See the scipy docs for usage examples.
#     Yields
#     ------
#     D_chunk : array or sparse matrix
#         A contiguous slice of distance matrix, optionally processed by
#         ``reduce_func``.
#     Examples
#     --------
#     Without reduce_func:
#     >>> import numpy as np
#     >>> from sklearn.metrics import pairwise_distances_chunked
#     >>> X = np.random.RandomState(0).rand(5, 3)
#     >>> D_chunk = next(pairwise_distances_chunked(X))
#     >>> D_chunk
#     array([[0.  ..., 0.29..., 0.41..., 0.19..., 0.57...],
#            [0.29..., 0.  ..., 0.57..., 0.41..., 0.76...],
#            [0.41..., 0.57..., 0.  ..., 0.44..., 0.90...],
#            [0.19..., 0.41..., 0.44..., 0.  ..., 0.51...],
#            [0.57..., 0.76..., 0.90..., 0.51..., 0.  ...]])
#     Retrieve all neighbors and average distance within radius r:
#     >>> r = .2
#     >>> def reduce_func(D_chunk, start):
#     ...     neigh = [np.flatnonzero(d < r) for d in D_chunk]
#     ...     avg_dist = (D_chunk * (D_chunk < r)).mean(axis=1)
#     ...     return neigh, avg_dist
#     >>> gen = pairwise_distances_chunked(X, reduce_func=reduce_func)
#     >>> neigh, avg_dist = next(gen)
#     >>> neigh
#     [array([0, 3]), array([1]), array([2]), array([0, 3]), array([4])]
#     >>> avg_dist
#     array([0.039..., 0.        , 0.        , 0.039..., 0.        ])
#     Where r is defined per sample, we need to make use of ``start``:
#     >>> r = [.2, .4, .4, .3, .1]
#     >>> def reduce_func(D_chunk, start):
#     ...     neigh = [np.flatnonzero(d < r[i])
#     ...              for i, d in enumerate(D_chunk, start)]
#     ...     return neigh
#     >>> neigh = next(pairwise_distances_chunked(X, reduce_func=reduce_func))
#     >>> neigh
#     [array([0, 3]), array([0, 1]), array([2]), array([0, 3]), array([4])]
#     Force row-by-row generation by reducing ``working_memory``:
#     >>> gen = pairwise_distances_chunked(X, reduce_func=reduce_func,
#     ...                                  working_memory=0)
#     >>> next(gen)
#     [array([0, 3])]
#     >>> next(gen)
#     [array([0, 1])]
#     """
#     X = X.astype(np.float32)
#     n_samples_X = len(X)
#     Y = X
#     # We get as many rows as possible within our working_memory budget to
#     # store len(Y) distances in each row of output.
#     #
#     # Note:
#     #  - this will get at least 1 row, even if 1 row of distances will
#     #    exceed working_memory.
#     #  - this does not account for any temporary memory usage while
#     #    calculating distances (e.g. difference of vectors in manhattan
#     #    distance.
#     chunk_n_rows = get_chunk_n_rows(
#         row_bytes=8 * len(Y), max_n_rows=n_samples_X, working_memory=None
#     )
#     slices = gen_batches(n_samples_X, chunk_n_rows)

#     X_full = torch.tensor(X).cuda()
#     Xnorms = torch.norm(X_full, dim=1, keepdim=True) ** 2
#     for sl in slices:
#         if verbose:
#             print(sl)
#         if sl.start == 0 and sl.stop == n_samples_X:
#             X_chunk = X  # enable optimised paths for X is Y
#         else:
#             X_chunk = X[sl]
#         pX = torch.tensor(X_chunk).cuda()
#         d2 = Xnorms[sl] - 2 * torch.matmul(pX, X_full.t()) + Xnorms.t()
#         d2 = torch.sqrt(torch.nn.functional.relu(d2)).cpu().numpy()
#         d2.flat[sl.start :: len(X) + 1] = 0
#         D_chunk = d2
#         if reduce_func is not None:
#             chunk_size = D_chunk.shape[0]
#             D_chunk = reduce_func(D_chunk, sl.start)
#             _check_chunk_size(D_chunk, chunk_size)
#         yield D_chunk
