from pathlib import Path

import numpy as np
import pandas as pd
import scipy.stats as stats

from src.figures import merge_results as merge_results
import cfg


def _aulc_ignoring_nan(x, y):
    """Return the trapezoidal area under ``y`` over ``x``, skipping NaN entries of ``y``."""
    valid = ~np.isnan(y)
    return np.trapezoid(y[valid], x[valid])


def _render_metric_table(results, n1, n2):
    """Build the LaTeX table comparing the metrics stored for budgets ``n1`` and ``n2``.

    One row is produced per dataset, using the first sampling method of that
    dataset that is not ``'random'``.
    """
    # NOTE(review): the caption always says "single-label", also when the
    # multi-label folder is processed -- confirm whether that is intended.
    s_start = ('\\begin{table*} \n'
               '\t \\caption{Active learning performance metrics on selected single-label datasets. The performance metrics were computed} \n'
               '\t \\label{tab:exp_performance} \n'
               '\t \\begin{tabular}{llllllllllll} \n'
               '\t\t \\toprule \n'
               '\t\t & & \\multicolumn{2}{c}{LC Mean [\\%]} & '
               '\\multicolumn{2}{c}{$\\frac{\\text{AULC}_{qm}}{\\text{AULC}_{rand}}$} & '
               '\\multicolumn{2}{c}{Cut Points (p-value)} & '
               '\\multicolumn{2}{c}{Speedup Factor fixed} &  '
               '\\multicolumn{2}{c}{Speedup Factor} \\\\ \n'
               '\t\t \\cmidrule(lr){3-4} \n'
               '\t\t \\cmidrule(lr){5-6} \n'
               '\t\t \\cmidrule(lr){7-8} \n'
               '\t\t \\cmidrule(lr){9-10} \n'
               '\t\t \\cmidrule(lr){11-12} \n'
               f'\t\t Dataset & Label & {n1} & {n2} & {n1} & {n2} & {n1} & {n2} & {n1} & {n2} & {n1} & {n2} \\\\ \n'
               '\t\t \\midrule \n')

    rows = []
    for dataset, methods in results.items():
        sampling_method = [key for key in methods if key != 'random'][0]
        entry = methods[sampling_method]
        at_n1 = entry[n1]
        at_n2 = entry[n2]
        # The label column shows the first class name without its first six
        # characters (presumably a fixed prefix -- verify against the data).
        rows.append(f'\t\t {cfg.figures[dataset]["text"]} & '
                    f' {entry["label"][0][6:]} & '
                    f' {100 * at_n1["LC_mean"]:.1f} &'
                    f' {100 * at_n2["LC_mean"]:.1f} &'
                    f' {at_n1["LC_aulc_frac"]:.3f} &'
                    f' {at_n2["LC_aulc_frac"]:.3f} &'
                    f' {at_n1["cut_points_p_al_better_rand"]:.6f} &'
                    f' {at_n2["cut_points_p_al_better_rand"]:.6f} &'
                    f' {at_n1["S_fix"]:.2f} &'
                    f' {at_n2["S_fix"]:.2f} &'
                    f' {at_n1["S_general"]:.2f} & '
                    f' {at_n2["S_general"]:.2f}\\\\ \n')

    s_end = ('\t\t \\bottomrule \n'
             '\t \\end{tabular} \n'
             '\\end{table*}')

    return s_start + ''.join(rows) + s_end


def table_al_metric_comparison(data, folder, tag='evaluation'):
    """Compute learning-curve metrics per dataset and print them as a LaTeX table.

    For every dataset in ``data`` the merged results pickle is loaded from
    ``cfg.path_exp / folder / dataset / 'results_merged.pkl'``. For each
    sampling method the learning curve is truncated at a range of budgets and
    the following is stored per budget: NaN-aware mean, area under the
    learning curve (absolute and relative to the ``'random'`` method),
    cut-point p-values in both directions and two speedup factors obtained
    from ``merge_results.get_sf_properties``.

    Args:
        data: Iterable of tuples. With ``folder == 'single_label'`` each tuple
            is ``(dataset, hashs, sampling_methods)``; otherwise each tuple has
            two additional trailing items that are ignored here. Every dataset
            must include the sampling method ``'random'``.
        folder: Sub-folder of ``cfg.path_exp`` holding the datasets; also
            selects the tuple layout and the budget range.
        tag: Prefix of the ``'<tag>_f1_macro'`` column used as performance.

    Returns:
        Nested dictionary ``results[dataset][sampling_method]`` containing the
        curve (``'x'``, ``'y'``), ``'label'``,
        ``'evaluation_ceiling_f1_macro'`` and one metrics dictionary per budget.

    Side effects:
        Prints progress messages and the LaTeX table to stdout.
    """
    # create dict with performance, nr_added_training_samples
    results = {}
    for _tuple in data:
        # get current dataset and hash
        if folder == 'single_label':
            dataset, hashs, sampling_methods = _tuple
        else:
            dataset, hashs, sampling_methods, _, _ = _tuple
        print(f'Load Dataset: {dataset}')

        # Check against the results dictionary (not the input list of tuples),
        # so an entry that already exists for this dataset is not discarded.
        if dataset not in results:
            results[dataset] = {}

        # read dataset
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        df = pd.read_pickle(Path(cfg.path_exp, folder, dataset, 'results_merged.pkl'))

        for hash_ in hashs:
            for sampling_method in sampling_methods:
                # row index of this hash / sampling method combination
                index = int(df.query('hash == @hash_ and sampling_method == @sampling_method').index[0])

                # get n_training_samples and performances
                nr_training_samples = df.loc[index, 'nr_training_samples']
                nr_added_training_samples = nr_training_samples - nr_training_samples[0]

                results[dataset][sampling_method] = {
                    'x': nr_added_training_samples,
                    'y': df.loc[index, tag + '_f1_macro'],
                    'label': df.loc[index, 'classes'],
                    'evaluation_ceiling_f1_macro': df.loc[index, 'evaluation_ceiling_f1_macro'],
                }

    # compute evaluation values for n = 700 and n = 1360
    # (both values have to be part of n_all, otherwise the table lookup fails)
    n1 = 700
    n2 = 1360
    if folder == 'single_label':
        n_all = np.arange(40, 1980, 20)
    else:
        n_all = np.arange(40, 1380, 20)

    # iterate over datasets and sampling methods
    for dataset, methods in results.items():
        random_entry = methods['random']
        x_rand = random_entry['x']
        y_rand = random_entry['y']

        for sampling_method, entry in methods.items():
            print(f'Process: {dataset} | {sampling_method}')
            x = entry['x']
            y = entry['y']

            for n in n_all:
                within_budget = x <= n
                x_n = x[within_budget]
                y_n = y[within_budget]

                # NOTE(review): the random baseline is sliced with the mask
                # derived from ``x``; this assumes both curves share the same
                # sample grid -- confirm.
                x_rand_n = x_rand[within_budget]
                y_rand_n = y_rand[within_budget]

                # mean of the learning curve
                lc_mean = np.nanmean(y_n)

                # area under learning curve (method and random baseline)
                lc_aulc = _aulc_ignoring_nan(x_n, y_n)
                lc_rand_aulc = _aulc_ignoring_nan(x_rand_n, y_rand_n)

                # cut points: only positions where both curves are defined
                both_valid = ~(np.isnan(y_n) | np.isnan(y_rand_n))
                y_valid = y_n[both_valid]
                y_rand_valid = y_rand_n[both_valid]
                cut_points_p_al_better_rand = compute_cutpoints_p_value(y_valid, y_rand_valid)
                cut_points_p_rand_better_al = compute_cutpoints_p_value(y_rand_valid, y_valid)

                # speedup factor (fixed and general): row 0 is the sampling
                # method, row 1 the random baseline
                df_data = {'evaluation_ceiling_f1_macro': [entry['evaluation_ceiling_f1_macro'],
                                                           random_entry['evaluation_ceiling_f1_macro']],
                           'evaluation_f1_macro': [y_n,
                                                   y_rand_n],
                           'nr_training_samples': [x_n,
                                                   x_rand_n],
                           'sampling_method': [sampling_method,
                                               'random']
                           }
                df_for_sf_fix = pd.DataFrame(data=df_data)
                df_for_sf_fix = merge_results.get_sf_properties(df_for_sf_fix, [0, 1], a0_equal_0=True)

                df_for_sf_general = pd.DataFrame(data=df_data)
                df_for_sf_general = merge_results.get_sf_properties(df_for_sf_general, [0, 1])

                # save results
                entry[n] = {
                    'LC_mean': lc_mean,
                    'LC_aulc': lc_aulc,
                    'LC_aulc_frac': lc_aulc / lc_rand_aulc,
                    'cut_points_p_rand_better_al': cut_points_p_rand_better_al,
                    'cut_points_p_al_better_rand': cut_points_p_al_better_rand,
                    'S_fix': df_for_sf_fix.loc[0, 'sf_p1'],
                    'S_general': df_for_sf_general.loc[0, 'sf_best_p'],
                }

    # print table
    s = _render_metric_table(results, n1, n2)
    print(s)
    return results

def compute_cutpoints_p_value(y_alpha, y_beta):
    """Return an approximate p-value for the cut-point comparison of two curves.

    The element-wise difference ``y_alpha - y_beta`` is ranked in descending
    order (rank 1 = largest difference). The ranks are multiplied with a
    descending reference ranking and summed into an L statistic, which is
    converted to a z-score and passed to the survival function of the standard
    normal distribution.

    NOTE(review): the z-score formula looks like the normal approximation of
    Page's L trend test with a 0.5 continuity correction -- confirm.

    Args:
        y_alpha: Array of performance values of the first curve.
        y_beta: Array of performance values of the second curve; must support
            element-wise subtraction with ``y_alpha``.

    Returns:
        The approximated p-value.
    """
    differences = y_alpha - y_beta

    # positions sorted from largest to smallest difference, turned into ranks
    descending_order = np.argsort(differences)[::-1]
    ranks = np.argsort(descending_order) + 1

    # reference ranking: k, k - 1, ..., 1
    reference_ranks = np.sort(ranks)[::-1]

    l_statistic = np.sum(ranks * reference_ranks)

    # approximated p-value
    n_blocks = 1
    n_points = len(reference_ranks)
    numerator = 12 * (l_statistic - 0.5) - 3 * n_blocks * n_points * (n_points + 1) ** 2
    denominator = n_points * (n_points + 1) * np.sqrt(n_blocks * (n_points - 1))
    return stats.norm.sf(numerator / denominator)






if __name__ == '__main__':
    # Each entry is (dataset, hashes, sampling methods, number, label). Only the
    # first three items are read by table_al_metric_comparison for this folder.
    # NOTE(review): the number is presumably the dataset size and the last item
    # the class label of interest -- confirm.
    datasets_ml_rare_freq = [
        (
            'carina',
            ['1e38f5524602ce1e190fe32b26931a59'],
            ['random', 'multilabel_simple_crw'],
            98663,
            'g',
        ),
        (
            'mscoco',
            ['2f0debf09caf973af3512632ff386f88'],
            ['random', 'ratio_max'],
            118287,
            'car',
        ),
        (
            'reuters',
            ['4fbeb7601b0e58d7e5942b6aa618257e'],
            ['random', 'kmeans'],
            53571,
            'interest',
        ),
        (
            'scene',
            ['a2d204afbc000092ba4f929212989359'],
            ['random', 'bald'],
            2407,
            'Sunset',
        ),
    ]
    table_al_metric_comparison(data=datasets_ml_rare_freq, folder='multi_label')