import os
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

import cfg


def _merge_seed_rows(df_iter):
    """Collapse all runs of one configuration into a single result dict.

    Args:
        df_iter: rows of the per-run table that share one configuration.
            The index must be a fresh 0..n-1 range because the first row is
            read via ``df_iter.loc[0, col]``.

    Returns:
        dict with one entry per column. ``random_seed`` becomes
        ``'merged_<n>'``, ``selected_training_sample_indices_by_al`` becomes
        ``None``, ``str``/``int`` cells and the ``classes`` column are taken
        from the first row, and every other column is replaced by its
        NaN-aware mean plus ``<col>_std`` and ``<col>_sem`` entries.
    """
    n_runs = len(df_iter)
    merged = {}
    for col in df_iter.columns:
        first_value = df_iter.loc[0, col]
        value_type = type(first_value)
        if col == 'random_seed':
            merged[col] = f'merged_{n_runs}'
        elif col == 'selected_training_sample_indices_by_al':
            merged[col] = None
        elif value_type is str or value_type is int or col == 'classes':
            # exact type match on purpose: bool / numpy scalars are averaged below
            merged[col] = first_value
        else:
            col_array = np.array(df_iter[col].to_list())
            col_std = np.nanstd(col_array, axis=0)
            merged[col] = np.nanmean(col_array, axis=0)
            merged[col + '_std'] = col_std
            # standard error of the mean: std / sqrt(number of runs)
            merged[col + '_sem'] = col_std / np.sqrt(n_runs)
    return merged


def create():
    """Merge the pickled per-run results of every dataset / dataset type.

    For each ``<cfg.path_exp>/<dataset>/<dataset_type>/results`` folder all
    ``*.pkl`` files are loaded into one table (one row per file) and written
    to ``results.pkl``. Rows that share the same configuration are then
    aggregated (see ``_merge_seed_rows``) and written to
    ``results_merged.pkl``. Both files are stored next to the ``results``
    folder.

    ``train_paradigm``, ``weight_init`` and ``training`` are taken from the
    ``_``-separated fields 6, 7 and 8 of the file name, ``dataset_type`` from
    field 1; all attributes of the unpickled result object are copied as
    columns afterwards. Folders without any ``*.pkl`` file are skipped.
    """
    group_cols = ['dataset', 'dataset_type', 'init_train_samples', 'add_train_samples',
                  'max_train_samples', 'train_paradigm', 'weight_init', 'training', 'qm']

    datasets = [d for d in os.listdir(cfg.path_exp) if os.path.isdir(os.path.join(cfg.path_exp, d))]
    # iterate over datasets
    for dataset in datasets:
        path_dataset = Path(cfg.path_exp, dataset)
        dataset_types = [d for d in os.listdir(path_dataset) if os.path.isdir(os.path.join(path_dataset, d))]
        # iterate over dataset types
        for dataset_type in dataset_types:
            print(f'Merge Results: {dataset} | {dataset_type}')
            path_results = Path(cfg.path_exp, dataset, dataset_type, 'results')
            files = []
            if path_results.is_dir():
                files = [f for f in os.listdir(path_results) if f.endswith('.pkl')]
            if not files:
                # nothing to merge for this dataset type
                print(f'No result files found in {path_results}, skipping')
                continue

            # init df from the attributes of the first result
            # NOTE: pickle.load can execute arbitrary code; only use on result files
            # written by this project.
            with open(Path(path_results, files[0]), 'rb') as f:
                result = pickle.load(f)
            columns = ['dataset_type', 'train_paradigm', 'weight_init', 'training', 'qm'] + list(
                result.__dict__.keys())
            df_new = pd.DataFrame(index=range(len(files)), columns=columns)

            # iterate over files, add them to df
            for index, file in enumerate(files):
                with open(Path(path_results, file), 'rb') as f:
                    result = pickle.load(f)
                name_parts = file.split('_')
                df_new.loc[index, 'dataset_type'] = name_parts[1]
                df_new.loc[index, 'train_paradigm'] = name_parts[6]
                df_new.loc[index, 'weight_init'] = name_parts[7]
                df_new.loc[index, 'training'] = name_parts[8]
                # attributes of the result object win over the file name fields
                for key, value in result.__dict__.items():
                    df_new.loc[index, key] = value

            # copy qm col
            df_new['qm'] = df_new['sampling_method']

            # aggregate across random seeds: one merged row per configuration
            list_merged_dicts = []
            unique_combinations = df_new[group_cols].drop_duplicates()
            for _, row in unique_combinations.iterrows():
                same_config = (df_new[group_cols] == row).all(axis=1)
                df_iter = df_new[same_config].reset_index(drop=True)
                list_merged_dicts.append(_merge_seed_rows(df_iter))

            # create merged df
            df_merged = pd.DataFrame(list_merged_dicts)

            # save dfs
            df_new.to_pickle(Path(cfg.path_exp, dataset, dataset_type, 'results.pkl'))
            df_merged.to_pickle(Path(cfg.path_exp, dataset, dataset_type, 'results_merged.pkl'))

def approximate_learning_curve():
    """Fit learning-curve parameters for every merged result table.

    For each ``<cfg.path_exp>/<dataset>/<dataset_type>/results_merged.pkl``
    the rows are grouped by (dataset, dataset_type, train_paradigm,
    weight_init, training); ``get_sf_properties`` is run once per group and
    the enriched table is written back to the same file. Dataset types
    without a merged table are skipped.
    """
    datasets = [d for d in os.listdir(cfg.path_exp) if os.path.isdir(os.path.join(cfg.path_exp, d))]
    # iterate over datasets
    for dataset in datasets:
        path_dataset = Path(cfg.path_exp, dataset)
        dataset_types = [d for d in os.listdir(path_dataset) if os.path.isdir(os.path.join(path_dataset, d))]
        # iterate over dataset types
        for dataset_type in dataset_types:
            print(f'Approximate Learning Curve: {dataset} | {dataset_type}')

            # load merged df
            path_merged_df = Path(cfg.path_exp, dataset, dataset_type, 'results_merged.pkl')
            if not path_merged_df.is_file():
                print(f'No merged results found at {path_merged_df}, skipping')
                continue
            df = pd.read_pickle(path_merged_df)

            # one group per experimental setting; all query methods of a setting share a group
            unique_combinations = df[['dataset', 'dataset_type', 'train_paradigm', 'weight_init',
                                      'training']].drop_duplicates()
            n_combinations = len(unique_combinations)

            # iterrows yields the original index labels, so count the groups separately
            for counter, (_, row) in enumerate(unique_combinations.iterrows(), start=1):
                indices = df[
                    (df['dataset'] == row['dataset']) &
                    (df['dataset_type'] == row['dataset_type']) &
                    (df['train_paradigm'] == row['train_paradigm']) &
                    (df['weight_init'] == row['weight_init']) &
                    (df['training'] == row['training'])
                    ].index

                print(f'Compute Speedup factor: {dataset} {dataset_type} | {counter} / {n_combinations}')

                df = get_sf_properties(df, indices)

            # save df
            df.to_pickle(path_merged_df)


def get_sf_properties(df, indices, a0_equal_0=False):
    """Run the complete speedup-factor pipeline for one group of rows.

    Args:
        df: merged result table.
        indices: index labels of the rows that belong to the group.
        a0_equal_0: if True, ``a_0_p1`` and ``a_0_p2`` are overwritten with 0
            after the constants have been computed.

    Returns:
        ``df`` after ``get_constants_a``, ``get_constants_b``, ``get_sf`` and
        ``get_best_p_func`` have been applied in that order.
    """
    # a_inf and the offsets a_0_p1 / a_0_p2 come first, everything else builds on them
    df = get_constants_a(df, indices)
    if a0_equal_0:
        df.loc[indices, ['a_0_p1', 'a_0_p2']] = 0

    # remaining stages: fit b, derive the speedup factor, pick the better curve
    for stage in (get_constants_b, get_sf, get_best_p_func):
        df = stage(df, indices)

    return df


def get_constants_a(df, indices, tag='evaluation'):
    """Compute ``a_inf``, ``a_0_p1`` and ``a_0_p2`` for one group of rows.

    Args:
        df: merged result table; reads ``<tag>_f1_macro`` and
            ``<tag>_ceiling_f1_macro`` of the rows in ``indices``.
        indices: index labels of the rows that belong to the group.
        tag: column prefix of the performance columns.

    Returns:
        ``df`` with the three constant columns set for ``indices``. They stay
        NaN if, for any row, more than half of the performance values are NaN.
    """
    constant_cols = ['a_inf', 'a_0_p1', 'a_0_p2']
    score_col = tag + '_f1_macro'
    ceiling_col = tag + '_ceiling_f1_macro'

    # start from NaN so a discarded group is recognisable downstream
    df.loc[indices, constant_cols] = np.nan

    # discard the whole group as soon as one curve is more than 50% NaN
    scores = df.loc[indices, score_col]
    mostly_nan = scores.apply(lambda curve: np.isnan(curve).mean() > 0.5)
    if mostly_nan.any():
        return df

    # a_inf: mean final performance if no ceiling values exist, else the mean ceiling
    no_ceiling = df.loc[indices, ceiling_col].apply(lambda values: len(values) == 0).all()
    if no_ceiling:
        a_inf = np.nanmean(scores.apply(lambda curve: curve[-1]))
    else:
        a_inf = np.nanmean(np.nanmean(df.loc[indices, ceiling_col]))

    # a_0 for both curve families, derived from the median of the per-row minima
    min_performance = scores.apply(lambda curve: np.nanmin(curve)).median()
    offset_p1 = np.log(1 - min_performance / a_inf)
    offset_p2 = np.log(a_inf / min_performance - 1)

    df.loc[indices, constant_cols] = [a_inf, offset_p1, offset_p2]
    return df


def get_constants_b(df, indices, tag='evaluation'):
    """Fit the constants ``b_p1`` and ``b_p2`` for every row of one group.

    Args:
        df: merged result table that already holds ``a_inf``, ``a_0_p1`` and
            ``a_0_p2``.
        indices: index labels of the rows that belong to the group.
        tag: column prefix of the performance column, passed on to
            ``compute_b_value``.

    Returns:
        ``df`` with ``b_p1`` / ``b_p2`` reset to NaN for ``indices`` and then
        filled row by row. A curve family is skipped for the whole group if
        ``a_inf`` or its ``a_0`` is NaN in any row of the group.
    """
    df.loc[indices, ['b_p1', 'b_p2']] = np.nan

    for p_func in ('p1', 'p2'):
        constants = df.loc[indices, ['a_inf', 'a_0_' + p_func]]
        if constants.isna().any().any():
            continue
        # fit each row on its own: a single failing fit must not affect the other rows
        for index in indices:
            df = compute_b_value(df, index, tag, p_func)
    return df


def compute_b_value(df, index, tag, p_func):
    """Fit the constant ``b`` of one curve family to a single row.

    The fitted models are
    ``p1``: ``a_inf * (1 - exp(a_0 - x / b))`` and
    ``p2``: ``a_inf / (1 + exp(a_0 - x / b))``,
    where ``x`` is ``nr_training_samples`` shifted so the first value is 0 and
    ``a_0`` / ``a_inf`` are read from the row.

    Args:
        df: merged result table.
        index: index label of the row to fit.
        tag: column prefix of the performance column ``<tag>_f1_macro``.
        p_func: ``'p1'`` or ``'p2'``.

    Returns:
        ``df`` with ``b_<p_func>`` set for ``index`` if one of the start
        values led to a successful fit; otherwise ``df`` is returned with that
        cell untouched and one message printed per failed attempt.
    """
    # x: number of training samples relative to the first measurement
    samples = df.loc[index, 'nr_training_samples']
    samples = samples - samples[0]

    # y: performance; drop every point without a measured performance
    scores = df.loc[index, tag + '_f1_macro']
    measured = ~np.isnan(scores)
    samples = samples[measured]
    scores = scores[measured]

    # fixed constants of the curve, only b is free
    a_0 = df.loc[index, 'a_0_' + p_func]
    a_inf = df.loc[index, 'a_inf']

    def approx_func(x_func, b_qm):
        if p_func == 'p1':
            return a_inf * (1 - np.exp(a_0 - x_func / b_qm))
        if p_func == 'p2':
            return a_inf * (1 / (1 + np.exp(a_0 - x_func / b_qm)))
        return None

    # the optimiser does not converge from every start value, so try several in turn
    for start_b in (100, 1000, 10, 1):
        try:
            fitted, _ = curve_fit(approx_func, samples, scores, p0=start_b)
            df.loc[index, 'b_' + p_func] = fitted[0]
        except Exception as err:
            print(f'{p_func}: Fit failed with p0={start_b}: {err}')
        else:
            break
    return df


def get_sf(df, indices):
    """Compute the speedup factors ``sf_p1`` / ``sf_p2`` for one group of rows.

    The speedup factor of a row is its fitted ``b`` divided by the ``b`` of
    the row with ``qm == 'random'`` that has the same dataset, dataset type,
    sample counts, train paradigm, weight init and training. It is only
    computed if exactly one such baseline row exists.

    Args:
        df: merged result table that already holds ``a_inf``, ``a_0_p1``,
            ``a_0_p2``, ``b_p1`` and ``b_p2``.
        indices: index labels of the rows that belong to the group. Only these
            rows are written.

    Returns:
        ``df`` with ``sf_p1`` / ``sf_p2`` reset to NaN for ``indices`` and
        filled where a baseline exists. A curve family is skipped for the
        whole group if ``a_inf`` or its ``a_0`` is NaN in any row of the
        group.
    """
    # init df values for sf_p1 and sf_p2
    df.loc[indices, ['sf_p1', 'sf_p2']] = np.nan

    # columns that must be equal between a row and its random baseline
    cols = ['dataset', 'dataset_type', 'init_train_samples', 'add_train_samples',
            'max_train_samples', 'train_paradigm', 'weight_init', 'training']

    # whether a curve family is usable depends on the group only, so decide it once
    valid_p_funcs = [
        p_func for p_func in ('p1', 'p2')
        if not df.loc[indices, ['a_inf', 'a_0_' + p_func]].isna().any().any()
    ]
    if not valid_p_funcs:
        return df

    is_random = df['qm'] == 'random'

    # only the rows of this group are processed; the baseline is searched in the whole table
    for index, row in df.loc[indices, cols].iterrows():
        mask = (df[cols] == row).all(axis=1) & is_random
        index_rand = df.index[mask.to_numpy()]

        # without a unique baseline the speedup factor is undefined and stays NaN
        if len(index_rand) != 1:
            continue

        for p_func in valid_p_funcs:
            b_rand = df.loc[index_rand[0], 'b_' + p_func]
            df.loc[index, 'sf_' + p_func] = df.loc[index, 'b_' + p_func] / b_rand

    return df


def get_best_p_func(df, indices, tag='evaluation'):
    """Select the curve family that fits one group of rows best.

    For every usable curve family (``p1`` / ``p2``) the fitted curve is
    evaluated at the measured sample counts of each row, the RMSE against the
    measured performance is computed per row, and the family with the smallest
    mean RMSE over the group is chosen.

    Args:
        df: merged result table that already holds ``a_inf``, ``a_0_p1``,
            ``a_0_p2``, ``b_p1``, ``b_p2``, ``sf_p1`` and ``sf_p2``.
        indices: index labels of the rows that belong to the group.
        tag: column prefix of the performance column ``<tag>_f1_macro``.

    Returns:
        ``df`` with ``best_p`` set to ``'p1'``, ``'p2'`` or ``''`` (no usable
        family) for ``indices`` and ``sf_best_p`` set to the speedup factor of
        the chosen family (NaN if none was chosen).
    """
    score_col = tag + '_f1_macro'

    # init sf_best_p; best_p is always written below, so it needs no NaN placeholder
    # (a float placeholder would force a float -> str upcast of the column)
    df.loc[indices, 'sf_best_p'] = np.nan

    # x: added training samples relative to the first measurement, y: performance
    df_xy_all = df.loc[indices, ['nr_training_samples', score_col]].copy()
    df_xy_all['nr_training_samples'] = df_xy_all['nr_training_samples'].apply(lambda x: x - x[0])

    # drop every point without a measured performance from both x and y
    df_xy_nonan = df_xy_all.apply(lambda row: [col[~np.isnan(row[score_col])] for col in row], axis=1)
    df_xy = pd.DataFrame(df_xy_nonan.tolist(), columns=df_xy_all.columns, index=df_xy_all.index)

    # get a_0, a_inf, b_p1, b_p2 cols
    columns_to_copy = ['a_0_p1', 'a_0_p2', 'a_inf', 'b_p1', 'b_p2']
    df_xy[columns_to_copy] = df.loc[indices, columns_to_copy]

    # get the approximation function with the smallest mean RMSE over the group
    best_p = ''
    min_rmse = np.inf
    for p_func in ['p1', 'p2']:
        if not df.loc[indices, ['a_inf', 'a_0_' + p_func]].isna().any().any():
            # compute fitted curve (y-fit) for given x samples
            if p_func == 'p1':
                df_xy['fit_p1'] = df_xy.apply(lambda row: row['a_inf'] * (1 - np.exp(row['a_0_p1'] - row['nr_training_samples'] / row['b_p1'])), axis=1)
            elif p_func == 'p2':
                df_xy['fit_p2'] = df_xy.apply(lambda row: row['a_inf'] * 1 / (1 + np.exp(row['a_0_p2'] - row['nr_training_samples'] / row['b_p2'])), axis=1)

            # compute the root mean squared error per row for the selected function
            df_xy['rmse_' + p_func] = df_xy.apply(lambda row: np.sqrt(np.mean((row[score_col] - row['fit_' + p_func]) ** 2)), axis=1)

            # average the per-row RMSE over the group
            mean_rmse = df_xy['rmse_' + p_func].mean()

            # select function if mean rmse smaller (a NaN mean never wins)
            if mean_rmse < min_rmse:
                best_p = p_func
                min_rmse = mean_rmse

    # select the sf values for the best_p
    df.loc[indices, 'best_p'] = best_p
    if best_p:
        df.loc[indices, 'sf_best_p'] = df.loc[indices, 'sf_' + best_p]
    return df


if __name__ == '__main__':
    # first merge the per-run result files, then fit the learning curves on the merged tables
    for pipeline_step in (create, approximate_learning_curve):
        pipeline_step()