import os

# The C module must be mocked on Read the Docs, because Read the Docs
# has no support for compiling C code.
##on_rtd = os.environ.get('READTHEDOCS') == 'True'
##if on_rtd:
##    
##else:
##    from _coranking import metrics_cy

import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
from sklearn.metrics import normalized_mutual_info_score, silhouette_score
from sklearn.cluster import KMeans

from unittest.mock import MagicMock
# `metrics_cy` is unconditionally bound to a MagicMock here; the import of the
# real compiled module (`from _coranking import metrics_cy`) is commented out
# near the top of this file. Any attribute access or call on this object
# therefore yields another mock rather than a real metric value.
metrics_cy = MagicMock()


def get_ca_knn(z, y, k, *, test_size=0.3, random_state=None):
    """Return the hold-out accuracy of a k-NN classifier fitted on ``z``.

    The samples are split into a training and a test part with
    ``train_test_split``; a ``KNeighborsClassifier`` is fitted on the training
    part and scored on the test part.

    Args:
        z: Feature matrix, one row per sample.
        y: Labels aligned with the rows of ``z``.
        k: Number of neighbours used by the classifier.
        test_size: Forwarded to ``train_test_split``; defaults to the
            previously hard-coded value of 0.3.
        random_state: Forwarded to ``train_test_split``. The default ``None``
            keeps the original unseeded (non-reproducible) split; pass an
            integer to obtain the same split, and hence the same accuracy,
            on every call.

    Returns:
        The mean accuracy reported by ``KNeighborsClassifier.score`` on the
        held-out samples.
    """
    z_train, z_test, y_train, y_test = train_test_split(
        z, y, test_size=test_size, random_state=random_state
    )
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(z_train, y_train)
    return classifier.score(z_test, y_test)


def get_kmeans(z, n_clusters=10, n_init=10):
    """Cluster the rows of ``z`` with KMeans and return the cluster labels.

    Args:
        z: Feature matrix, one row per sample.
        n_clusters: Number of clusters passed to ``KMeans``.
        n_init: ``n_init`` value passed to ``KMeans``.

    Returns:
        The result of ``KMeans.fit_predict(z)``: one cluster index per sample.
    """
    clusterer = KMeans(n_clusters=n_clusters, n_init=n_init)
    return clusterer.fit_predict(z)


def get_nmi(y, y_pred):
    """Return the normalized mutual information between ``y`` and ``y_pred``.

    Both label sequences are handed to
    ``sklearn.metrics.normalized_mutual_info_score`` in that order.
    """
    return normalized_mutual_info_score(y, y_pred)


def get_sc(z, y_pred):
    """Return the silhouette score of the labelling ``y_pred`` over ``z``.

    ``z`` holds the samples and ``y_pred`` the cluster label of each sample;
    both are passed unchanged to ``sklearn.metrics.silhouette_score``.
    """
    return silhouette_score(z, y_pred)


def calculate_overlap_ratio(high_dim_data, low_dim_data, k):
    """Return the mean k-nearest-neighbour overlap between two representations.

    For every sample, its ``k`` nearest neighbours are looked up once in
    ``high_dim_data`` and once in ``low_dim_data``. The per-sample score is the
    fraction of neighbours common to both lists, and the function returns the
    average of these fractions over all samples.

    Args:
        high_dim_data: Samples in the original space, one row per sample
            (the original note gives 40000 x 784 as an example).
        low_dim_data: The same samples in the embedded space, row ``i``
            corresponding to row ``i`` of ``high_dim_data`` (example from the
            original note: 40000 x 2).
        k: Number of neighbours to compare; must be at least 1.

    Returns:
        The mean overlap fraction, as computed by ``np.mean``.

    Raises:
        ValueError: If ``k`` is smaller than 1, or if the two inputs do not
            contain the same number of samples.
    """
    if k < 1:
        # Previously k == 0 surfaced as an opaque ZeroDivisionError.
        raise ValueError(f"k must be a positive integer, got {k!r}")
    if len(high_dim_data) != len(low_dim_data):
        # Neighbour indices are only comparable when both inputs describe
        # the same set of samples.
        raise ValueError(
            "high_dim_data and low_dim_data must contain the same number of "
            f"samples, got {len(high_dim_data)} and {len(low_dim_data)}"
        )

    # Calling kneighbors() without a query matrix makes scikit-learn return
    # the neighbours of each fitted point with the point itself excluded.
    # This replaces the former "ask for k + 1 and drop column 0" approach,
    # which dropped the wrong entry whenever duplicate points caused a
    # different sample to be listed first.
    indices_high = (
        NearestNeighbors(n_neighbors=k)
        .fit(high_dim_data)
        .kneighbors(return_distance=False)
    )
    indices_low = (
        NearestNeighbors(n_neighbors=k)
        .fit(low_dim_data)
        .kneighbors(return_distance=False)
    )

    # Fraction of shared neighbours for each sample.
    overlap_ratios = [
        len(set(row_high).intersection(row_low)) / k
        for row_high, row_low in zip(indices_high, indices_low)
    ]
    return np.mean(overlap_ratios)


def get_npa(x, z, k):
    """Return the neighbourhood overlap between ``x`` and ``z``.

    Thin wrapper that forwards to ``calculate_overlap_ratio`` with ``x`` as
    the high-dimensional data, ``z`` as the low-dimensional data and ``k`` as
    the neighbour count.
    """
    return calculate_overlap_ratio(x, z, k=k)


def get_nmi_sc(z, y):
    """Cluster ``z`` with KMeans and return ``(nmi, silhouette)``.

    The number of clusters is the number of distinct values in ``y``. The
    resulting labelling is scored against ``y`` with ``get_nmi`` and against
    the geometry of ``z`` with ``get_sc``.

    Returns:
        A tuple ``(nmi, sc)``.
    """
    class_count = len(set(y))
    # The class count is used for both n_clusters and n_init.
    # NOTE(review): confirm that tying n_init to the class count is intended.
    labels = get_kmeans(z, n_clusters=class_count, n_init=class_count)
    return get_nmi(y, labels), get_sc(z, labels)

