"""
Baseline separability metrics: Calinski-Harabasz, Thornton index, balanced accuracy, ROC-AUC
"""

import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    balanced_accuracy_score,
    calinski_harabasz_score,
    make_scorer,
    roc_auc_score,
)
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import NearestNeighbors


def kmeans_ch_score(data, n_clusters: int = 3, random_state: int = 0):
    """
    Cluster ``data`` with KMeans and score how well separated the
    resulting clusters are, using the Calinski-Harabasz index.

    The Calinski-Harabasz index is the ratio of between-cluster
    dispersion to within-cluster dispersion, so tight, well separated
    clusters yield larger values.

    Args:
        data: np.array, the data to assess separability over.
        n_clusters: int, number of clusters KMeans should find.
        random_state: int, seed passed to KMeans for reproducibility.

    Returns:
        A non-negative, unbounded separability score; larger means
        more separable.
    """
    clusterer = KMeans(
        n_clusters=n_clusters,
        random_state=random_state,
        n_init="auto",
    )
    # fit_predict fits the model and returns its ``labels_`` in one step.
    cluster_assignments = clusterer.fit_predict(data)
    return calinski_harabasz_score(data, cluster_assignments)


def thornton_separability_index(data, labels, n_neighbors=5):
    """
    Estimate Thornton's separability index: the fraction of
    (point, nearest-neighbor) pairs in which both share the same label.

    Works for binary and multi-class labels, but not the multi-label case.

    Each sample is excluded from its own neighbor list, so exactly
    ``n_neighbors`` *other* samples are compared against every point
    (``n_neighbors=1`` gives the classic Thornton index). If the dataset
    has fewer than ``n_neighbors + 1`` samples, all other samples are used.

    Args:
        data: np.array of shape (n_samples, n_features), the data we want
            to check for separability.
        labels: array of ints of length n_samples, labels corresponding
            to the rows of ``data``.
        n_neighbors: int >= 1, number of nearest neighbors (excluding the
            point itself) to compare each point against.

    Returns:
        A score between 0 and 1 capturing separability. Interpret it as the
        probability that a point and one of its nearest neighbors share
        the same label. A dataset with a single sample scores 0.0, since
        that point has no neighbors to agree with.

    Raises:
        ValueError: if ``n_neighbors < 1``, ``data`` is empty, or ``labels``
            does not have one entry per row of ``data``.
    """
    if n_neighbors < 1:
        raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}.")

    data = np.asarray(data)
    labels = np.asarray(labels)
    n_samples = data.shape[0]
    if n_samples == 0:
        raise ValueError("data must contain at least one sample.")
    if labels.shape[0] != n_samples:
        raise ValueError(
            f"labels has {labels.shape[0]} entries but data has "
            f"{n_samples} rows."
        )

    # A point cannot have more neighbors than there are other points.
    n_used = min(n_neighbors, n_samples - 1)
    if n_used == 0:
        return 0.0

    nearest_neighbors = NearestNeighbors(n_neighbors=n_used).fit(data)
    # Calling kneighbors() without X queries the fitted points themselves
    # and drops each sample from its own neighbor list (duplicates included),
    # so every row holds exactly n_used *other* samples.
    _, neighbor_indexes = nearest_neighbors.kneighbors()

    # (n_samples, n_used) boolean matrix: does neighbor j share point i's label?
    agreement = labels[neighbor_indexes] == labels[:, np.newaxis]
    return float(np.mean(agreement))


def balanced_accuracy_index(
    data, labels, n_splits=20, **logistic_regression_args
) -> float:
    """
    Estimates the balanced-accuracy over data splits to assess separability.

    A logistic regression model is cross-validated: it is refit on the
    training portion of each of ``n_splits`` folds and scored on the
    held-out portion.

    Args:
        data: np.array, contains the data we want to check for
            separability.
        labels: array of ints, labels corresponding to the data.
        n_splits: int, how many splits to use for cross-validation.
        **logistic_regression_args: any other keyword-style arguments
            are passed to the logistic regression model.

    Returns:
        The average balanced-accuracy over the splits.
    """
    # cross_val_score clones and refits the estimator on every fold, so the
    # model is deliberately left unfitted here.
    clf = LogisticRegression(**logistic_regression_args)

    # The built-in "balanced_accuracy" scorer is equivalent to
    # make_scorer(balanced_accuracy_score) but avoids make_scorer's
    # ``needs_proba`` keyword, deprecated in scikit-learn 1.4 and removed in 1.6.
    fold_scores = cross_val_score(
        clf, data, labels, scoring="balanced_accuracy", cv=n_splits
    )
    return float(np.mean(fold_scores))


def roc_auc_index(
    data, labels, n_splits=20, **logistic_regression_args
) -> float:
    """
    Estimates the roc-auc over data splits to assess separability.

    A logistic regression model is cross-validated: it is refit on the
    training portion of each of ``n_splits`` folds and scored on the
    held-out portion. With exactly two classes the standard binary ROC-AUC
    is used; otherwise the one-vs-one, macro-averaged ROC-AUC computed from
    predicted probabilities is used.

    Args:
        data: np.array, contains the data we want to check for
            separability.
        labels: array of ints, labels corresponding to the data.
        n_splits: int, how many splits to use for cross-validation.
        **logistic_regression_args: any other keyword-style arguments
            are passed to the logistic regression model.

    Returns:
        The average roc_auc_score over the splits.
    """
    multi_class = np.unique(labels).size != 2

    # "roc_auc_ovo" is scikit-learn's built-in scorer for
    # roc_auc_score(multi_class="ovo", average="macro") on predict_proba.
    # It replaces make_scorer(..., needs_proba=True), whose ``needs_proba``
    # keyword was deprecated in scikit-learn 1.4 and removed in 1.6.
    scoring = "roc_auc_ovo" if multi_class else "roc_auc"

    # cross_val_score clones and refits the estimator on every fold, so the
    # model is deliberately left unfitted here.
    clf = LogisticRegression(**logistic_regression_args)
    fold_scores = cross_val_score(
        clf, data, labels, scoring=scoring, cv=n_splits
    )
    return float(np.mean(fold_scores))


def multiclass_roc_auc_score(y_true, y_pred, average="macro"):
    """
    Compute a one-vs-one multi-class ROC-AUC.

    Args:
        y_true: true class labels.
        y_pred: per-class scores/probabilities for each sample, as
            accepted by ``roc_auc_score`` in multi-class mode.
        average: averaging strategy forwarded to ``roc_auc_score``.

    Returns:
        The ROC-AUC returned by ``roc_auc_score`` with ``multi_class="ovo"``.
    """
    return roc_auc_score(
        y_true,
        y_pred,
        multi_class="ovo",
        average=average,
    )
