import os
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

import cfg


def run(exp):
    """Collect all pickled result files of one experiment into two DataFrames.

    Every ``*.pkl`` file in ``<cfg.path_exp>/<exp[0]>/<exp[1]>/results`` is
    unpickled; the attributes of the stored object (its ``__dict__``) become
    one row of the raw table, together with a ``hash`` column taken from the
    third ``-``-separated token of the file name.

    Rows sharing the same ``hash`` and ``sampling_method`` are then merged
    into one row:

    * ``random_seed`` becomes ``'merged_<number of rows>'``,
    * ``selected_training_sample_indices_by_al`` becomes ``None``,
    * plain ``str``/``int`` values and the ``classes`` column are copied from
      the first row of the group,
    * every other column is replaced by its NaN-ignoring mean, and two extra
      columns ``<col>_std`` and ``<col>_sem`` (standard error of the mean,
      ``std / sqrt(n)``) are added.

    Side effects: prints progress and writes ``results.pkl`` (raw table) and
    ``results_merged.pkl`` (merged table) into the experiment folder.

    Args:
        exp: Sequence whose first two items are joined onto ``cfg.path_exp``
            to locate the experiment folder.

    Raises:
        FileNotFoundError: If the results folder contains no ``.pkl`` file.
    """
    # define paths with result files and target df
    path_folder_df = Path(cfg.path_exp, exp[0], exp[1])
    path_results = Path(path_folder_df, 'results')

    # get a list of all result files; sorted because os.listdir returns them in
    # arbitrary order and the first row of each group is used for copied values
    files = sorted(f for f in os.listdir(path_results) if f.endswith('.pkl'))
    if not files:
        raise FileNotFoundError(f'No .pkl result files found in {path_results}')

    # init df with the attributes of the first result as columns
    # NOTE: pickle.load can execute arbitrary code - only use on trusted result files
    with open(Path(path_results, files[0]), 'rb') as f:
        result = pickle.load(f)
    columns = ['hash'] + list(result.__dict__.keys())
    # no data is passed, so all columns have object dtype and can hold arrays
    df = pd.DataFrame(index=range(len(files)), columns=columns)

    # iterate over files, add them to df
    for index, file in enumerate(files):
        with open(Path(path_results, file), 'rb') as f:
            result = pickle.load(f)
        df.loc[index, 'hash'] = file.split('-')[2]
        for key, value in result.__dict__.items():
            df.loc[index, key] = value

        print(exp, index)

    # compute the mean across random seeds
    list_merged_dicts = []
    for hash_ in df['hash'].unique():
        for sampling_method in df['sampling_method'].unique():
            # grab rows with same hash and sampling method (different random seeds)
            df_iter = df[(df['hash'] == hash_) & (df['sampling_method'] == sampling_method)].reset_index(drop=True)
            n_rows = len(df_iter)
            if n_rows == 0:
                continue

            # generate a dict with mean of all values
            dict_row = {}
            for col in df_iter.columns:
                first_value = df_iter.loc[0, col]
                if col == 'random_seed':
                    dict_row[col] = f'merged_{n_rows}'
                elif col == 'selected_training_sample_indices_by_al':
                    dict_row[col] = None
                # exact type match on purpose: only plain str/int values are copied
                elif type(first_value) in (str, int) or col == 'classes':
                    dict_row[col] = first_value
                else:
                    col_array = np.array(df_iter[col].to_list())
                    col_std = np.nanstd(col_array, axis=0)
                    dict_row[col] = np.nanmean(col_array, axis=0)
                    dict_row[col + '_std'] = col_std
                    # standard error of the mean: std / sqrt(n), not std / n
                    dict_row[col + '_sem'] = col_std / np.sqrt(n_rows)

            # save dict to list
            list_merged_dicts.append(dict_row)

    # create merged df
    df_merged = pd.DataFrame(list_merged_dicts)

    # save dfs
    df.to_pickle(Path(path_folder_df, 'results.pkl'))
    df_merged.to_pickle(Path(path_folder_df, 'results_merged.pkl'))


def approximate_learning_curve(exp):
    """Add the learning-curve constants and speedup factors to the merged results.

    Loads ``results_merged.pkl`` from ``<cfg.path_exp>/<exp[0]>/<exp[1]>``,
    runs ``get_sf_properties`` once per distinct ``hash`` value (one dataset
    each) on the rows carrying that hash, and writes the updated table back to
    the same file. A progress line is printed for every dataset.

    Args:
        exp: Sequence whose first two items locate the experiment folder
            below ``cfg.path_exp``.
    """
    merged_path = Path(cfg.path_exp, exp[0], exp[1], 'results_merged.pkl')
    df = pd.read_pickle(merged_path)

    # one dataset per distinct hash
    hashes = df['hash'].unique()
    n_datasets = len(hashes)
    for current_hash in hashes:
        position = np.where(hashes == current_hash)[0][0]
        print(f'Compute Speedup factor: {exp[0]} {exp[1]} | {position} / {n_datasets}')

        # rows of this dataset (one per sampling method)
        dataset_rows = df.loc[df['hash'] == current_hash].index

        # a_inf, a_0, b, speedup factor and best curve model for these rows
        df = get_sf_properties(df, dataset_rows)

    df.to_pickle(merged_path)


def get_sf_properties(df, indices, a0_equal_0=False):
    """Run the full curve-approximation pipeline for the rows in ``indices``.

    The steps are applied in this order: ``get_constants_a`` (a_inf, a_0_p1,
    a_0_p2), optional reset of both a_0 constants to 0, ``get_constants_b``
    (b_p1, b_p2), ``get_sf`` (speedup factors) and ``get_best_p_func``
    (selection of the better fitting curve model).

    Args:
        df: Merged results table; it is updated and returned.
        indices: Index labels of the rows that belong to one dataset.
        a0_equal_0: If true, ``a_0_p1`` and ``a_0_p2`` of the selected rows
            are overwritten with 0 before the b constants are fitted.

    Returns:
        The DataFrame handed back by the last pipeline step.
    """
    df = get_constants_a(df, indices)
    if a0_equal_0:
        df.loc[indices, ['a_0_p1', 'a_0_p2']] = 0

    # remaining steps all share the (df, indices) -> df signature
    for step in (get_constants_b, get_sf, get_best_p_func):
        df = step(df, indices)

    return df


def get_constants_a(df, indices, tag='evaluation'):
    """Compute ``a_inf``, ``a_0_p1`` and ``a_0_p2`` for the rows in ``indices``.

    The three columns are first reset to NaN for the selected rows. They stay
    NaN if, for any selected row, more than half of the values in
    ``<tag>_f1_macro`` are NaN.

    Otherwise:

    * ``a_inf`` is the NaN-ignoring mean of the last ``<tag>_f1_macro`` value
      of every row when all ``<tag>_ceiling_f1_macro`` entries are empty, and
      the NaN-ignoring mean of the ceiling column in every other case.
    * ``min_performance`` is the median over the rows of each row's smallest
      non-NaN ``<tag>_f1_macro`` value.
    * ``a_0_p1 = log(1 - min_performance / a_inf)``
    * ``a_0_p2 = log(a_inf / min_performance - 1)``

    The same three numbers are written to every selected row.

    Args:
        df: Results table whose cells in the performance columns hold
            sequences of values; it is modified in place.
        indices: Index labels of the rows to process.
        tag: Prefix of the performance columns.

    Returns:
        The updated ``df``.
    """
    constant_cols = ['a_inf', 'a_0_p1', 'a_0_p2']
    perf_col = tag + '_f1_macro'
    ceiling_col = tag + '_ceiling_f1_macro'

    # start from "unknown" for all three constants
    df.loc[indices, constant_cols] = np.nan

    # discard the dataset if any row has more than 50% missing performance values
    mostly_missing = df.loc[indices, perf_col].apply(lambda scores: np.isnan(scores).mean() > 0.5)
    if mostly_missing.any():
        return df

    # a_inf: final performance when no ceiling was recorded, ceiling performance otherwise
    ceiling_is_empty = df.loc[indices, ceiling_col].apply(lambda scores: len(scores) == 0)
    if ceiling_is_empty.all():
        final_scores = df.loc[indices, perf_col].apply(lambda scores: scores[-1])
        a_inf = np.nanmean(final_scores)
    else:
        a_inf = np.nanmean(np.nanmean(df.loc[indices, ceiling_col]))

    # a_0 constants are derived from the median of the per-row minima
    row_minima = df.loc[indices, perf_col].apply(lambda scores: np.nanmin(scores))
    min_performance = row_minima.median()
    a_0_p1 = np.log(1 - min_performance / a_inf)
    a_0_p2 = np.log(a_inf / min_performance - 1)

    # same constants for every row of the dataset
    df.loc[indices, constant_cols] = [a_inf, a_0_p1, a_0_p2]
    return df


def get_constants_b(df, indices, tag='evaluation'):
    """Fit the constants ``b_p1`` and ``b_p2`` for the rows in ``indices``.

    Both columns are reset to NaN for the selected rows. For each curve model
    (``p1``, ``p2``) the fit is only attempted when neither ``a_inf`` nor the
    model's ``a_0_<model>`` is NaN in any selected row; in that case
    ``compute_b_value`` is called once per row.

    Args:
        df: Results table; it is updated and returned.
        indices: Index labels of the rows to process.
        tag: Prefix of the performance column, passed on to
            ``compute_b_value``.

    Returns:
        The updated DataFrame.
    """
    df.loc[indices, ['b_p1', 'b_p2']] = np.nan

    for p_func in ('p1', 'p2'):
        required_constants = df.loc[indices, ['a_inf', 'a_0_' + p_func]]
        if required_constants.isna().any().any():
            # constants missing for this model -> leave b as NaN
            continue

        # fit row by row, because a fit can fail for an individual row
        for row_label in indices:
            df = compute_b_value(df, row_label, tag, p_func)

    return df


def compute_b_value(df, index, tag, p_func):
    """Fit the curve constant ``b`` of one row and store it in ``b_<p_func>``.

    The fitted model is, with ``x`` being ``nr_training_samples`` shifted so
    that its first entry is 0:

    * ``p1``: ``a_inf * (1 - exp(a_0 - x / b))``
    * ``p2``: ``a_inf * (1 / (1 + exp(a_0 - x / b)))``

    ``a_inf`` and ``a_0_<p_func>`` are read from the row and kept fixed; only
    ``b`` is optimised with ``scipy.optimize.curve_fit``. Points whose
    performance value is NaN are ignored. The start values 100, 1000, 10 and 1
    are tried in this order; the first successful fit is stored. A failed
    attempt is reported on stdout, and if all attempts fail the cell is left
    unchanged.

    Args:
        df: Results table; it is modified in place.
        index: Index label of the row to fit.
        tag: Prefix of the performance column ``<tag>_f1_macro``.
        p_func: Curve model, ``'p1'`` or ``'p2'``.

    Returns:
        The updated ``df``.
    """
    # fixed constants of the model
    a_inf = df.loc[index, 'a_inf']
    a_0 = df.loc[index, 'a_0_' + p_func]

    def model(x_func, b_qm):
        if p_func == 'p1':
            return a_inf * (1 - np.exp(a_0 - x_func / b_qm))
        elif p_func == 'p2':
            return a_inf * (1 / (1 + np.exp(a_0 - x_func / b_qm)))

    # x: number of training samples relative to the first measurement
    samples = df.loc[index, 'nr_training_samples']
    samples = samples - samples[0]

    # y: performance, restricted to the points that could be evaluated
    scores = df.loc[index, tag + '_f1_macro']
    has_score = ~np.isnan(scores)
    samples = samples[has_score]
    scores = scores[has_score]

    # the fit does not always converge, so several start values are tried
    for start_value in (100, 1000, 10, 1):
        try:
            fitted, _ = curve_fit(model, samples, scores, p0=start_value)
            df.loc[index, 'b_' + p_func] = fitted[0]
            break
        except Exception as e:
            print(f'{p_func}: Fit failed with p0={start_value}: {e}')
    return df


def get_sf(df, indices):
    """Compute the speedup factors ``sf_p1`` and ``sf_p2`` for the rows in ``indices``.

    Both columns are reset to NaN for the selected rows. If exactly one
    selected row has ``sampling_method == 'random'``, that row serves as the
    baseline: for each curve model whose ``a_inf`` and ``a_0_<model>`` contain
    no NaN in the selected rows, ``sf_<model>`` is set to
    ``b_<model> / b_<model> of the baseline row``.

    Args:
        df: Results table; it is modified in place.
        indices: Index labels of the rows that belong to one dataset.

    Returns:
        The updated ``df``.
    """
    df.loc[indices, ['sf_p1', 'sf_p2']] = np.nan

    # locate the baseline row (random sampling) within the dataset
    is_random = df.loc[indices, 'sampling_method'].apply(lambda method: method == 'random')
    index_rand = is_random.loc[is_random].index

    # without a unique baseline no speedup factor can be computed
    if len(index_rand) != 1:
        return df
    baseline_row = index_rand[0]

    for p_func in ('p1', 'p2'):
        if df.loc[indices, ['a_inf', 'a_0_' + p_func]].isna().any().any():
            continue
        b_col = 'b_' + p_func
        b_rand = df.loc[baseline_row, b_col]
        df.loc[indices, 'sf_' + p_func] = df.loc[indices, b_col] / b_rand

    return df


def _rmse_of_fit(x, y, a_inf, a_0, b, p_func):
    """Return the RMSE between ``y`` and the ``p1``/``p2`` curve evaluated at ``x``.

    ``x`` is shifted so that its first entry is 0; points where ``y`` is NaN
    are ignored.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x = x - x[0]
    has_score = ~np.isnan(y)
    x = x[has_score]
    y = y[has_score]

    if p_func == 'p1':
        fit = a_inf * (1 - np.exp(a_0 - x / b))
    else:
        fit = a_inf * 1 / (1 + np.exp(a_0 - x / b))
    return np.sqrt(np.mean((y - fit) ** 2))


def get_best_p_func(df, indices, tag='evaluation'):
    """Select the curve model (``p1`` or ``p2``) that fits the dataset best.

    For each model whose ``a_inf`` and ``a_0_<model>`` contain no NaN in the
    selected rows, the fitted curve is compared with ``<tag>_f1_macro`` of
    every row (NaN performance values are ignored) and the root mean squared
    error is averaged over the rows, skipping rows whose RMSE is NaN. The
    model with the smaller mean RMSE is written to ``best_p`` for all selected
    rows and its ``sf_<model>`` values are copied to ``sf_best_p``. If no
    model qualifies, ``best_p`` is the empty string and ``sf_best_p`` is NaN.

    Args:
        df: Results table; it is modified in place.
        indices: Index labels of the rows that belong to one dataset.
        tag: Prefix of the performance column ``<tag>_f1_macro``.

    Returns:
        The updated ``df``.
    """
    perf_col = tag + '_f1_macro'

    # 'best_p' stores strings: keep it as an object column, so that assigning
    # the model name never has to upcast a float column
    if 'best_p' not in df.columns:
        df['best_p'] = pd.Series(np.nan, index=df.index, dtype=object)
    elif df['best_p'].dtype != object:
        df['best_p'] = df['best_p'].astype(object)
    df.loc[indices, 'sf_best_p'] = np.nan

    # get the approximation function with the smallest mean RMSE over the rows
    best_p = ''
    min_rmse = np.inf
    for p_func in ['p1', 'p2']:
        if df.loc[indices, ['a_inf', 'a_0_' + p_func]].isna().any().any():
            continue

        rmse_per_row = [
            _rmse_of_fit(
                df.loc[index, 'nr_training_samples'],
                df.loc[index, perf_col],
                df.loc[index, 'a_inf'],
                df.loc[index, 'a_0_' + p_func],
                df.loc[index, 'b_' + p_func],
                p_func,
            )
            for index in indices
        ]

        # Series.mean skips NaN entries (rows for which no b could be fitted)
        mean_rmse = pd.Series(rmse_per_row, dtype=float).mean()

        # select function if mean rmse smaller (a NaN mean never wins)
        if mean_rmse < min_rmse:
            best_p = p_func
            min_rmse = mean_rmse

    # select the sf values for the best_p
    df.loc[indices, 'best_p'] = best_p
    if best_p:
        df.loc[indices, 'sf_best_p'] = df.loc[indices, 'sf_' + best_p]
    return df




#if __name__ == '__main__':
#    run(['multi_label', 'reuters'])