import os.path
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.stats as stats

from src.figures import merge_results as merge_results, result_table
import cfg


def load(df):
    """Return the metric-change dict for ``df``, backed by an on-disk pickle cache.

    The cache file name is built from configuration columns of the first row of
    ``df`` whose ``qm`` is not ``'random'``, and the file lives in
    ``cfg.path_exp / 'metric_changes'`` (the directory is created if needed).
    If the file exists it is unpickled and returned; otherwise the dict is
    computed with ``create_dict(df)``, written to that file, and returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of one experiment; must contain at least one row with
        ``qm == 'random'`` and one with another ``qm`` (see ``create_dict``).

    Returns
    -------
    dict
        The dict produced by ``create_dict`` (fresh or loaded from the cache).
    """
    # configuration of the first non-random row determines the cache key
    index_al = df[df['qm'] != 'random'].index[0]

    key_columns = [
        'dataset',
        'dataset_type',
        'qm',
        'init_train_samples',
        'add_train_samples',
        'max_train_samples',
        'train_paradigm',
        'weight_init',
        'training',
        'nr_classes',
        'classes',
    ]
    (
        dataset,
        dataset_type,
        qm,
        init_train_samples,
        add_train_samples,
        max_train_samples,
        train_paradigm,
        weight_init,
        training,
        nr_classes,
        classes,
    ) = df.loc[index_al, key_columns]

    # NOTE(review): only classes[0] enters the cache key, so configurations that
    # differ only in later entries of `classes` would share one cache file — confirm intended.
    filename = (
        f'metric-change_{dataset}_{dataset_type}_{qm}_{init_train_samples}_{add_train_samples}_'
        f'{max_train_samples}_{train_paradigm}_{weight_init}_{training}_{nr_classes}_{classes[0]}.pkl'
    )

    file_path = Path(cfg.path_exp, 'metric_changes', filename)
    file_path.parent.mkdir(parents=True, exist_ok=True)

    if file_path.exists():
        # NOTE: unpickling can execute arbitrary code; only use trusted cache directories.
        with file_path.open('rb') as f:
            return pickle.load(f)

    print(f'Create {filename}')
    dict_performance_metric = create_dict(df)
    with file_path.open('wb') as f:
        pickle.dump(dict_performance_metric, f)
    return dict_performance_metric


def _nan_aware_aulc(x, y):
    """Return the trapezoidal area under ``y`` over ``x``, skipping points where ``y`` is NaN."""
    valid = ~np.isnan(y)
    return np.trapezoid(y[valid], x[valid])


def create_dict(df):
    """Compute how learning-curve metrics evolve as the curve grows.

    The first row of ``df`` with ``qm == 'random'`` is the baseline and the
    first row with any other ``qm`` is the AL run. The ``nr_training_samples``
    of the random row serve as x-values. All three sequences are truncated to
    the shortest one. Then, for every prefix with at least two points, the
    following are recorded:

    - ``nr_training_samples``: last x-value of the prefix,
    - ``lc_mean``: NaN-ignoring mean of the AL ``evaluation_f1_macro`` values,
    - ``lc_aulc``: trapezoidal area under the AL curve (NaN points skipped),
    - ``lc_aulc_norm``: ``lc_aulc`` divided by the same area for the random row,
    - ``cutpoints``: ``compute_cutpoints_p_value(al, random)`` over points where
      both curves are non-NaN,
    - ``s_fix`` / ``s``: the AL row's ``sf_p1`` / ``sf_best_p`` as returned by
      ``result_table.get_sf_properties`` on a copy of ``df`` whose curves are
      replaced by the prefix.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns ``qm``, ``nr_training_samples`` and ``evaluation_f1_macro``.
        The latter two hold per-row numeric sequences; presumably numpy arrays —
        TODO confirm. Lists are also accepted here.

    Returns
    -------
    dict[str, numpy.ndarray]
        One array per key listed above, each with one entry per prefix. The
        arrays are empty when fewer than two points are available.
    """
    # first random-baseline row and first non-random (AL) row
    index_random = df[df['qm'] == 'random'].index[0]
    index_al = df[df['qm'] != 'random'].index[0]

    # np.asarray so the boolean-mask indexing below also works for list-valued cells
    nr_training_samples = np.asarray(df.loc[index_random, 'nr_training_samples'])
    performance_random = np.asarray(df.loc[index_random, 'evaluation_f1_macro'])
    performance_al = np.asarray(df.loc[index_al, 'evaluation_f1_macro'])

    # truncate to the shortest sequence so positions line up
    min_len = min(len(nr_training_samples), len(performance_random), len(performance_al))
    nr_training_samples = nr_training_samples[:min_len]
    performance_random = performance_random[:min_len]
    performance_al = performance_al[:min_len]

    # insertion order defines the key order of the returned dict
    results = {
        'nr_training_samples': [],
        'lc_mean': [],
        'lc_aulc': [],
        'lc_aulc_norm': [],
        'cutpoints': [],
        's_fix': [],
        's': [],
    }

    # every prefix containing at least two points
    for end in range(2, min_len + 1):
        samples_cut = nr_training_samples[:end]
        performance_random_cut = performance_random[:end]
        performance_al_cut = performance_al[:end]

        # area under learning curve (AL and random baseline)
        lc_aulc = _nan_aware_aulc(samples_cut, performance_al_cut)
        lc_rand_aulc = _nan_aware_aulc(samples_cut, performance_random_cut)

        # cutpoint p-value over positions where both curves are defined
        both_valid = ~(np.isnan(performance_al_cut) | np.isnan(performance_random_cut))
        cut_points = compute_cutpoints_p_value(
            performance_al_cut[both_valid], performance_random_cut[both_valid]
        )

        # sf properties on a copy of df whose curves are replaced by the prefix
        df_for_sf = df.copy()
        df_for_sf['nr_training_samples'] = [samples_cut] * len(df_for_sf)
        df_for_sf.at[index_al, 'evaluation_f1_macro'] = performance_al_cut
        df_for_sf.at[index_random, 'evaluation_f1_macro'] = performance_random_cut
        df_for_sf = result_table.get_sf_properties(df_for_sf, [index_random, index_al], a0_equal_0=False)

        results['nr_training_samples'].append(samples_cut[-1])
        results['lc_mean'].append(np.nanmean(performance_al_cut))
        results['lc_aulc'].append(lc_aulc)
        # inf/nan (with a numpy RuntimeWarning) when the random-baseline area is 0
        results['lc_aulc_norm'].append(lc_aulc / lc_rand_aulc)
        results['cutpoints'].append(cut_points)
        results['s_fix'].append(df_for_sf.loc[index_al, 'sf_p1'])
        results['s'].append(df_for_sf.loc[index_al, 'sf_best_p'])

    return {key: np.array(values) for key, values in results.items()}


def compute_cutpoints_p_value(y_alpha, y_beta):
    """Approximate one-sided p-value of Page's L trend test on ``y_alpha - y_beta``.

    The paired differences are ranked (rank 1 = largest difference). Page's L
    statistic is the sum of each rank times the ideal ranking
    ``[k, k-1, ..., 1]`` for a single block (``m = 1``). L is largest when the
    difference increases along the sequence, so a small p-value indicates an
    increasing trend of ``y_alpha`` relative to ``y_beta``. The p-value uses the
    normal approximation with a 0.5 continuity correction on L.

    Parameters
    ----------
    y_alpha, y_beta : array-like of float
        Equal-length 1-D sequences (lists or numpy arrays).

    Returns
    -------
    numpy.float64
        The approximate p-value. With one paired observation there is no trend
        information and ``1.0`` is returned; with none the result is ``nan``.
        These are the values the closed-form expression degenerates to, returned
        directly to avoid divide-by-zero / invalid-sqrt warnings.
    """
    # cutpoint score
    c_score = np.asarray(y_alpha) - np.asarray(y_beta)
    k = len(c_score)
    if k == 0:
        return np.float64(np.nan)
    if k == 1:
        return np.float64(1.0)

    # double argsort turns sort positions into ranks; rank 1 = largest difference
    ranking = 1 + np.argsort(np.argsort(c_score)[::-1])
    # ranking is a permutation of 1..k, so the ideal (descending) ranking is k..1
    ideal_ranking = np.arange(k, 0, -1)
    l_value = np.sum(ranking * ideal_ranking)

    # normal approximation of Page's L with continuity correction
    m = 1  # number of blocks
    z = (12 * (l_value - 0.5) - 3 * m * k * (k + 1) ** 2) / (k * (k + 1) * np.sqrt(m * (k - 1)))
    return stats.norm.sf(z)
