import numpy as np

from sklearn.metrics import auc
from scipy.stats import rankdata
from scipy.stats import kendalltau
from sklearn.preprocessing import MinMaxScaler

from multiprocessing import Pool, cpu_count
import tqdm

def compute_kendalltau_distance(pair):
    """Compute Kendall's tau between two columns of a matrix.

    Args:
        pair: A ``(i, j, X)`` tuple, where ``i`` and ``j`` are column
            indices into the 2-D array ``X``.

    Returns:
        ``(i, j, tau)`` where ``tau`` is the Kendall tau statistic of
        ``X[:, i]`` versus ``X[:, j]``. A NaN statistic is reported as
        ``0.0``.
    """
    col_a, col_b, data = pair
    correlation, _pvalue = kendalltau(data[:, col_a], data[:, col_b])
    if np.isnan(correlation):
        # An undefined correlation is treated as "no correlation".
        correlation = 0.0
    return col_a, col_b, correlation

def pairwise_kendalltau_matrix_multiprocessing(X, processes=None):
    """Build the symmetric matrix of pairwise Kendall tau values of X's columns.

    Each unordered column pair ``(i, j)`` with ``i < j`` is evaluated by
    ``compute_kendalltau_distance`` in a worker process, and the result is
    mirrored into both ``[i, j]`` and ``[j, i]``. Despite the "distance"
    naming, the stored values are the tau statistics returned by that
    helper.

    Args:
        X: 2-D array; the columns are the series being compared.
        processes: Number of worker processes. ``None`` (default) uses
            ``min(10, cpu_count())``. The value is additionally capped at
            the number of column pairs so no idle workers are started.

    Returns:
        An ``(n, n)`` float array, ``n = X.shape[1]``. The diagonal is left
        at ``0`` (self-pairs are never evaluated).
    """
    n = X.shape[1]
    distance_matrix = np.zeros((n, n))
    n_pairs = n * (n - 1) // 2
    if n_pairs == 0:
        # Fewer than two columns: nothing to compare, so skip the pool.
        return distance_matrix

    if processes is None:
        processes = min(10, cpu_count())
    processes = min(processes, n_pairs)

    # Generated lazily; imap_unordered consumes the pairs as it dispatches.
    pairs = ((i, j, X) for i in range(n) for j in range(i + 1, n))

    with Pool(processes=processes) as pool:
        results = pool.imap_unordered(compute_kendalltau_distance, pairs)
        for i, j, distance in tqdm.tqdm(results, total=n_pairs):
            distance_matrix[i, j] = distance
            distance_matrix[j, i] = distance  # Symmetric matrix
    return distance_matrix


def hits(score_mat, max_iter=500, tol=1e-10):
    """Select a model with HITS-style hub/authority power iteration.

    The scores are ranked per column (``rankdata(..., axis=0)``) and the
    element-wise inverse ranks form the link matrix. Authority (per-row)
    and hub (per-column) vectors are updated alternately and normalised to
    unit Euclidean norm until neither changes any more.

    Args:
        score_mat: 2-D array of shape ``(n_samples, n_models)``.
        max_iter: Maximum number of update rounds (default 500).
        tol: Convergence tolerance; iteration stops once the largest
            absolute element-wise change of both vectors is ``<= tol``.

    Returns:
        ``(score, selected_idx)``: the largest entry of the final hub
        vector and the column index where it occurs.
    """
    rank_mat = rankdata(score_mat, axis=0)
    inv_rank_mat = 1 / rank_mat

    n_samples, n_models = score_mat.shape[0], score_mat.shape[1]

    hub_vec = np.full([n_models, 1], 1 / n_models)
    auth_vec = np.zeros([n_samples, 1])

    for i in range(max_iter):
        # Only the previous iterate is needed for the stopping test.
        prev_auth_vec, prev_hub_vec = auth_vec, hub_vec

        auth_vec = np.dot(inv_rank_mat, prev_hub_vec)
        auth_vec = auth_vec / np.linalg.norm(auth_vec)

        # update hub_vec
        hub_vec = np.dot(inv_rank_mat.T, auth_vec)
        hub_vec = hub_vec / np.linalg.norm(hub_vec)

        # stopping criteria: use the largest absolute change, because a
        # signed sum/mean can cancel to ~0 while the vectors still move.
        auth_change = np.abs(auth_vec - prev_auth_vec).max()
        hub_change = np.abs(hub_vec - prev_hub_vec).max()
        if auth_change <= tol and hub_change <= tol:
            print('break at', i)
            break

    score = np.max(hub_vec)
    selected_idx = np.argmax(hub_vec)
    return score, selected_idx


def mc(score_mat):
    """Select the model whose ranking agrees most with all other models.

    Scores are ranked per column, min-max scaled per column, and compared
    pairwise with Kendall's tau. Each model's similarity is its mean tau
    against the *other* models.

    Args:
        score_mat: 2-D array of shape ``(n_samples, n_models)``. At least
            two models are required; with one model the denominator
            ``n_models - 1`` is zero.

    Returns:
        ``(score, selected_idx)``: the highest mean pairwise similarity and
        the column index of the model that attains it.
    """
    n_models = score_mat.shape[1]
    output_mat_r = rankdata(score_mat, axis=0)
    output_mat = MinMaxScaler().fit_transform(output_mat_r)

    similar_mat = pairwise_kendalltau_matrix_multiprocessing(output_mat)

    B = (similar_mat + similar_mat.T) / 2
    # fix nan problem
    B = np.nan_to_num(B)
    # Exclude each model's self-similarity by subtracting the actual
    # diagonal. The pairwise helper leaves the diagonal at 0, so a
    # hard-coded "- 1" would shift every score by -1 / (n_models - 1).
    similarity = (np.sum(B, axis=1) - np.diag(B)) / (n_models - 1)
    score = np.max(similarity)
    selected_idx = np.argmax(similarity)
    return score, selected_idx


def em(t, t_max, volume_support, s_unif, s_X, n_generated):
    """Compute a curve over ``t`` and the area under its leading part.

    For every distinct value ``u`` of ``s_X`` the candidate curve is
    ``mean(s_X > u) - t * (count(s_unif > u) / n_generated) * volume_support``.
    The result is the point-wise maximum over all ``u``, with the first
    entry initialised to 1. (Presumably the excess-mass curve, going by
    the function name.)

    Args:
        t: 1-D array of curve abscissae.
        t_max: The curve is integrated up to and including the first point
            whose value is ``<= t_max``.
        volume_support: Scalar multiplier applied to the uniform-sample term.
        s_unif: Scores of the uniformly generated points.
        s_X: Scores of the data points.
        n_generated: Divisor for the uniform-sample counts.

    Returns:
        ``(AUC, EM_t)``: the area under the truncated curve and the full
        curve evaluated on ``t``.
    """
    n_samples = s_X.shape[0]
    curve = np.zeros(t.shape[0])
    curve[0] = 1.
    for threshold in np.unique(s_X):
        data_mass = 1. / n_samples * (s_X > threshold).sum()
        unif_term = t * (s_unif > threshold).sum() / n_generated * volume_support
        curve = np.maximum(curve, data_mass - unif_term)

    # +1 so the slice includes the first point at or below t_max.
    stop = np.argmax(curve <= t_max) + 1
    if stop == 1:
        # argmax returned 0: either no point is <= t_max or the very first
        # one is. In both cases integrate everything except the last point.
        stop = -1
    return auc(t[:stop], curve[:stop]), curve


def mv(axis_alpha, volume_support, s_unif, s_X, n_generated):
    """Compute a curve over ``axis_alpha`` and the area under it.

    For each ``alpha`` the data scores are consumed from highest to lowest
    until the consumed fraction reaches ``alpha``. The curve value is the
    fraction of uniform scores ``>=`` the last consumed score, multiplied
    by ``volume_support``. (Presumably the mass-volume curve, going by the
    function name.)

    Args:
        axis_alpha: 1-D array of mass levels; the running state is carried
            from one level to the next, so it is expected to be
            non-decreasing.
        volume_support: Scalar multiplier applied to the uniform fraction.
        s_unif: Scores of the uniformly generated points.
        s_X: Scores of the data points.
        n_generated: Divisor for the uniform-sample counts.

    Returns:
        ``(AUC, mv)``: the area under the curve and the curve values, one
        per entry of ``axis_alpha``.
    """
    n_samples = s_X.shape[0]
    ascending_order = s_X.argsort()
    n_taken = 0
    covered_mass = 0
    threshold = s_X[ascending_order[-1]]
    volumes = np.zeros(axis_alpha.shape[0])
    for k, alpha in enumerate(axis_alpha):
        # Lower the threshold one data point at a time until enough of the
        # data lies at or above it.
        while covered_mass < alpha:
            n_taken += 1
            threshold = s_X[ascending_order[-n_taken]]
            covered_mass = 1. / n_samples * n_taken
        n_unif_above = float((s_unif >= threshold).sum())
        volumes[k] = n_unif_above / n_generated * volume_support
    return auc(axis_alpha, volumes), volumes


def get_em_mv_original(score, unif_score, X, alpha_min=0.9, alpha_max=0.999,
                       n_generated=10000, t_max=0.9):
    """Compute the ``em`` and ``mv`` areas for a set of scores.

    Args:
        score: Scores of the data points; negated before use.
        unif_score: Scores of the uniformly generated points; negated
            before use.
        X: 2-D data array; only its per-column min/max are used, to get
            the volume of the bounding box.
        alpha_min: Start of the ``mv`` mass axis.
        alpha_max: End (exclusive) of the ``mv`` mass axis, step 0.0001.
        n_generated: Number of uniform points behind ``unif_score``.
        t_max: Cut-off forwarded to ``em``.

    Returns:
        ``(em_clf, mv_clf)``: the first element returned by ``em`` and by
        ``mv`` respectively.
    """
    # Not used below; the lookup is kept so an X without a second axis
    # still raises here.
    n_features = X.shape[1]

    feature_ranges = X.max(axis=0) - X.min(axis=0)
    volume_support = feature_ranges.prod()
    if volume_support == 0:
        # Degenerate box (some column is constant): avoid dividing by zero.
        volume_support = feature_ranges.prod() + 0.00001
    elif volume_support == np.inf:
        # Overflowed product: fall back to a fixed finite volume.
        volume_support = 1000

    t = np.arange(0, 100 / volume_support, 0.01 / volume_support)
    axis_alpha = np.arange(alpha_min, alpha_max, 0.0001)

    # Both score sets are sign-flipped before being passed on.
    s_X_clf = score * -1
    s_unif_clf = unif_score * -1

    em_clf, _ = em(t, t_max, volume_support, s_unif_clf, s_X_clf, n_generated)
    mv_clf, _ = mv(axis_alpha, volume_support, s_unif_clf, s_X_clf,
                   n_generated)
    return em_clf, mv_clf



def get_unif_score(X, clf, n_generated=10000):
    """Score points drawn uniformly from the bounding box of ``X``.

    Args:
        X: 2-D data array; its per-column min and max define the box.
        clf: Object exposing ``decision_function(points)``.
        n_generated: Number of uniform points to draw.

    Returns:
        Whatever ``clf.decision_function`` returns for the
        ``(n_generated, X.shape[1])`` array of sampled points.
    """
    lower, upper = X.min(axis=0), X.max(axis=0)
    sample_shape = (n_generated, X.shape[1])
    uniform_points = np.random.uniform(lower, upper, size=sample_shape)
    return clf.decision_function(uniform_points)