import os

# The C module must be mocked for Read the Docs, because it has
# no support for compiling C code.
##on_rtd = os.environ.get('READTHEDOCS') == 'True'
##if on_rtd:
##    
##else:
##    from _coranking import metrics_cy

import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
from sklearn.metrics import normalized_mutual_info_score, silhouette_score
from sklearn.cluster import KMeans

from unittest.mock import MagicMock
metrics_cy = MagicMock()

import _coranking


def get_coranking(x, z):
    """Return the co-ranking matrix computed by ``_coranking.coranking_matrix``.

    :param x: first argument forwarded to ``coranking_matrix``
        (presumably the original high-dimensional data -- confirm against
        ``_coranking``).
    :param z: second argument forwarded to ``coranking_matrix``
        (presumably the low-dimensional embedding -- confirm against
        ``_coranking``).
    :returns: whatever ``_coranking.coranking_matrix(x, z)`` returns.
    """
    crm = _coranking.coranking_matrix(x, z)
    return crm

def qnx_crm(crm, k):
    """Average normalized agreement between K-ary neighborhoods (QNX).

    QNX measures the degree to which an embedding preserves the local
    neighborhood around each observation. For a value of K, the K closest
    neighbors of each observation are retrieved in the input and output
    space. For each observation, the number of shared neighbors can vary
    between 0 and K. QNX is the average number of shared neighbors,
    normalized by K, so that perfectly preserved neighborhoods give 1 and no
    neighborhood preservation gives 0.

    For a random embedding, the expected value of QNX is approximately
    K / (N - 1) where N is the number of observations. Using RNX
    (``rnx_crm``) removes this dependency on K and the number of
    observations.

    :param crm: co-ranking matrix (2-D array supporting ``crm[:k, :k]``).
    :param k: neighborhood size.
    :returns: sum of the top-left ``k`` x ``k`` block of ``crm`` divided by
        ``k * len(crm)``.

    Reference:
        Lee, J. A., & Verleysen, M. (2009). Quality assessment of
        dimensionality reduction: Rank-based criteria.
        Neurocomputing, 72(7), 1431-1443.

    Python reimplementation of code by jlmelville
    (https://github.com/jlmelville/quadra/blob/master/R/neighbor.R)
    """
    top_left_block = crm[:k, :k]
    normaliser = k * len(crm)
    return np.sum(top_left_block) / normaliser

def rnx_crm(crm, k):
    """Rescaled agreement between K-ary neighborhoods (RNX).

    RNX is a scaled version of QNX which measures the agreement between two
    embeddings in terms of the shared number of k-nearest neighbors for each
    observation. RNX gives a value of 1 if the neighbors are all preserved
    perfectly and a value of 0 for a random embedding.

    :param crm: co-ranking matrix; ``len(crm)`` is used as ``n``.
    :param k: neighborhood size.
    :returns: ``(qnx_crm(crm, k) * (n - 1) - k) / (n - 1 - k)``.

    Reference:
        Lee, J. A., Renard, E., Bernard, G., Dupont, P., & Verleysen, M.
        (2013). Type 1 and 2 mixtures of Kullback-Leibler divergences as
        cost functions in dimensionality reduction based on similarity
        preservation. Neurocomputing, 112, 92-108.

    Python reimplementation of code by jlmelville
    (https://github.com/jlmelville/quadra/blob/master/R/neighbor.R)
    """
    size = len(crm)
    agreement = qnx_crm(crm, k)
    rescaled = ((agreement * (size - 1)) - k) / (size - 1 - k)
    return rescaled


# @numba.njit(fastmath=True)
def rnx_auc_crm(crm):
    """Area under the RNX curve.

    The RNX curve is formed by evaluating the RNX metric for neighborhood
    sizes ``k = 1 .. len(crm) - 3``. Each RNX value is weighted by ``1 / k``,
    which gives a higher weight to smaller neighborhoods, and the weighted
    sum is divided by the sum of the weights. An AUC of 1 indicates perfect
    neighborhood preservation, an AUC of 0 is due to random results.

    :param crm: co-ranking matrix (2-D array).
    :returns: weighted average of RNX over the neighborhood sizes.
    :raises ZeroDivisionError: if ``len(crm) <= 3``, because no neighborhood
        size is evaluated and the total weight stays 0.

    Reference:
        Lee, J. A., Peluffo-Ordonez, D. H., & Verleysen, M. (2015).
        Multi-scale similarities in stochastic neighbour embedding: Reducing
        dimensionality while preserving both local and global structure.
        Neurocomputing, 169, 246-261.

    Python reimplementation of code by jlmelville
    (https://github.com/jlmelville/quadra/blob/master/R/neighbor.R)

    from https://timsainburg.com/coranking-matrix-python-numba.html
    """
    size = len(crm)
    weighted_total = 0
    weight_total = 0

    # Running sum of the top-left k x k block, grown by one row/column per step.
    block_sum = 0
    for k in range(1, size - 2):
        last = k - 1
        # Row `last` and column `last` both contain the corner cell, so it is
        # subtracted once to avoid counting it twice.
        block_sum += np.sum(crm[last, :k]) + np.sum(crm[:k, last]) - crm[last, last]
        qnx = block_sum / (k * size)
        rnx = ((qnx * (size - 1)) - k) / (size - 1 - k)
        weighted_total += rnx / k
        weight_total += 1 / k
    return weighted_total / weight_total


def trustworthiness(Q, min_k=1, max_k=None):
    """Compute the trustworthiness metric over a range of K values.

    :param Q: coranking matrix. Any array-like is accepted; it is converted
        to an ``np.int64`` array (without copying when it already is one).
    :param min_k: the lowest K value to compute. Default 1.
    :param max_k: exclusive upper bound on K. If None, ``Q.shape[0] - 1`` is
        used, so K runs from ``min_k`` to ``Q.shape[0] - 2``.

    :returns: array with one entry per K in ``range(min_k, max_k)``, each
        produced by ``metrics_cy.trustworthiness(Q, K)``.

    NOTE(review): this module binds ``metrics_cy`` to a ``MagicMock``
    placeholder, so the entries are mock objects until the compiled
    extension is bound in its place.
    """
    # The previous `isinstance(Q, np.int64)` guard compared the array against
    # a scalar type and was therefore never satisfied; converting by dtype is
    # what was intended.
    Q = np.asarray(Q, dtype=np.int64)

    if max_k is None:
        max_k = Q.shape[0] - 1

    return np.array([metrics_cy.trustworthiness(Q, k) for k in range(min_k, max_k)])


def continuity(Q, min_k=1, max_k=None):
    """Compute the continuity metric over a range of K values.

    :param Q: coranking matrix. Any array-like is accepted; it is converted
        to an ``np.int64`` array (without copying when it already is one).
    :param min_k: the lowest K value to compute. Default 1.
    :param max_k: exclusive upper bound on K. If None, ``Q.shape[0] - 1`` is
        used, so K runs from ``min_k`` to ``Q.shape[0] - 2``.

    :returns: array with one entry per K in ``range(min_k, max_k)``, each
        produced by ``metrics_cy.continuity(Q, K)``.

    NOTE(review): this module binds ``metrics_cy`` to a ``MagicMock``
    placeholder, so the entries are mock objects until the compiled
    extension is bound in its place.
    """
    # The previous `isinstance(Q, np.int64)` guard compared the array against
    # a scalar type and was therefore never satisfied; converting by dtype is
    # what was intended.
    Q = np.asarray(Q, dtype=np.int64)

    if max_k is None:
        max_k = Q.shape[0] - 1

    return np.array([metrics_cy.continuity(Q, k) for k in range(min_k, max_k)])


def LCMC(Q, min_k=1, max_k=None):
    """Compute the local continuity meta-criteria (LCMC) metric over a range
    of K values.

    :param Q: coranking matrix. Any array-like is accepted; it is converted
        to an ``np.int64`` array (without copying when it already is one).
    :param min_k: the lowest K value to compute. Default 1.
    :param max_k: exclusive upper bound on K. If None, ``Q.shape[0] - 1`` is
        used, so K runs from ``min_k`` to ``Q.shape[0] - 2``.

    :returns: array with one entry per K in ``range(min_k, max_k)``, each
        produced by ``metrics_cy.LCMC(Q, K)``.

    NOTE(review): this module binds ``metrics_cy`` to a ``MagicMock``
    placeholder, so the entries are mock objects until the compiled
    extension is bound in its place.
    """
    # The previous `isinstance(Q, np.int64)` guard compared the array against
    # a scalar type and was therefore never satisfied; converting by dtype is
    # what was intended.
    Q = np.asarray(Q, dtype=np.int64)

    if max_k is None:
        max_k = Q.shape[0] - 1

    return np.array([metrics_cy.LCMC(Q, k) for k in range(min_k, max_k)])


def _check_square_matrix(M):
    """Raise ``RuntimeError`` unless the first two dimensions of ``M`` match.

    :param M: object exposing a ``shape`` sequence with at least two entries.
    :raises RuntimeError: if ``M.shape[0]`` differs from ``M.shape[1]``.
    """
    n_rows, n_cols = M.shape[0], M.shape[1]
    if n_rows == n_cols:
        return
    raise RuntimeError(
        "Expected square matrix, but matrix had dimensions (%d, %d)" % M.shape
    )


def get_ca_knn(z, y, k, test_size=0.3, random_state=None):
    """Hold-out accuracy of a k-nearest-neighbours classifier fitted on ``z``.

    The data are split once with ``train_test_split``; a
    ``KNeighborsClassifier`` is fitted on the training part and scored on the
    held-out part.

    :param z: feature matrix passed to ``train_test_split``.
    :param y: labels aligned with ``z``.
    :param k: number of neighbours used by the classifier.
    :param test_size: value forwarded to ``train_test_split`` as
        ``test_size``. Default 0.3 (the value that was previously hard-coded).
    :param random_state: seed forwarded to ``train_test_split``. The default
        ``None`` keeps the previous unseeded behaviour; pass an int to make
        the split, and therefore the score, reproducible.
    :returns: the value of ``KNeighborsClassifier.score`` on the held-out part.
    """
    z_train, z_test, y_train, y_test = train_test_split(
        z, y, test_size=test_size, random_state=random_state
    )
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(z_train, y_train)
    return classifier.score(z_test, y_test)


def get_kmeans(z, n_clusters=10, n_init=10, random_state=None):
    """Cluster ``z`` with k-means and return the predicted cluster labels.

    :param z: data passed to ``KMeans.fit_predict``.
    :param n_clusters: number of clusters. Default 10.
    :param n_init: value forwarded to ``KMeans`` as ``n_init``. Default 10.
    :param random_state: seed forwarded to ``KMeans``. The default ``None``
        keeps the previous unseeded behaviour; pass an int to make the
        clustering reproducible.
    :returns: the result of ``KMeans.fit_predict(z)``.
    """
    kmeans = KMeans(n_clusters=n_clusters, n_init=n_init, random_state=random_state)
    return kmeans.fit_predict(z)


def get_nmi(y, y_pred):
    """Return ``normalized_mutual_info_score(y, y_pred)``.

    :param y: reference labels.
    :param y_pred: labels to compare against ``y``.
    :returns: the normalized mutual information (NMI) score.
    """
    # Compute the NMI between the two labelings.
    return normalized_mutual_info_score(y, y_pred)


def get_sc(z, y_pred):
    """Return ``silhouette_score(z, y_pred)``.

    :param z: data the labels refer to.
    :param y_pred: cluster label of each sample in ``z``.
    :returns: the silhouette score.
    """
    return silhouette_score(z, y_pred)


def calculate_overlap_ratio(high_dim_data, low_dim_data, k):
    """Mean fraction of k-nearest neighbours shared between two spaces.

    For every point, the ``k`` nearest neighbours are looked up separately in
    ``high_dim_data`` and in ``low_dim_data``; the size of the intersection of
    the two neighbour sets is divided by ``k`` and the result is averaged over
    all points.

    :param high_dim_data: samples in the original space, one row per point
        (e.g. a 40000 x 784 matrix).
    :param low_dim_data: the same points in the embedding space, one row per
        point (e.g. a 40000 x 2 matrix).
    :param k: number of neighbours compared per point.
    :returns: mean overlap ratio over all points.
    :raises ValueError: if the two inputs contain a different number of
        points, since the rows could not be matched up.
    """
    # k + 1 neighbours are requested because the first column is dropped below
    # as the query point itself (assumes each point is returned as its own
    # nearest neighbour -- may not hold when there are duplicate points).
    # Only the neighbour indices are needed; distances are discarded.
    _, indices_high = NearestNeighbors(n_neighbors=k + 1).fit(high_dim_data).kneighbors(high_dim_data)
    _, indices_low = NearestNeighbors(n_neighbors=k + 1).fit(low_dim_data).kneighbors(low_dim_data)

    if len(indices_high) != len(indices_low):
        raise ValueError(
            "high_dim_data and low_dim_data must contain the same number of points"
        )

    # Per point: share of neighbours common to both spaces.
    overlap_ratios = [
        len(set(neighbours_high[1:]) & set(neighbours_low[1:])) / k
        for neighbours_high, neighbours_low in zip(indices_high, indices_low)
    ]

    return np.mean(overlap_ratios)


def get_npa(x, z, k):
    """Return ``calculate_overlap_ratio(x, z, k=k)``.

    :param x: data forwarded as the first (high-dimensional) argument.
    :param z: data forwarded as the second (low-dimensional) argument.
    :param k: neighbourhood size.
    :returns: the overlap ratio computed by ``calculate_overlap_ratio``.
    """
    return calculate_overlap_ratio(x, z, k=k)


def calculate_all_metrics(x, z, y):
    """Evaluate an embedding ``z`` of ``x`` with labels ``y``.

    :param x: data forwarded to ``get_coranking`` and ``get_npa`` as the
        first argument.
    :param z: embedding forwarded to every metric helper.
    :param y: labels; the number of distinct values sets the cluster count.
    :returns: list of nine values in this order: RNX AUC, NMI, silhouette
        score, kNN accuracy for k = 1, 10, 50, then neighbour overlap for
        k = 1, 10, 50.
    """
    neighbourhood_sizes = (1, 10, 50)

    # Trustworthiness, continuity and LCMC are intentionally not computed here.
    coranking_auc = rnx_auc_crm(get_coranking(x, z))

    knn_accuracies = [get_ca_knn(z, y, k) for k in neighbourhood_sizes]

    # The class count is used both as the number of clusters and as n_init.
    n_class = len(set(y))
    cluster_labels = get_kmeans(z, n_class, n_class)

    nmi = get_nmi(y, cluster_labels)
    sc = get_sc(z, cluster_labels)

    overlap_scores = [get_npa(x, z, k) for k in neighbourhood_sizes]

    return [coranking_auc, nmi, sc, *knn_accuracies, *overlap_scores]


def get_nmi_sc(z, y):
    """Cluster ``z`` with k-means and score the clustering.

    The number of distinct values in ``y`` is used both as the number of
    clusters and as ``n_init`` for ``get_kmeans``.

    :param z: data to cluster.
    :param y: reference labels.
    :returns: tuple ``(nmi, sc)`` from ``get_nmi(y, labels)`` and
        ``get_sc(z, labels)``.
    """
    n_class = len(set(y))
    cluster_labels = get_kmeans(z, n_class, n_class)
    return get_nmi(y, cluster_labels), get_sc(z, cluster_labels)

