from collections.abc import Iterable
import os
from os.path import join

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

from eeg_augment.utils import load_obj, flexible_float


def pad_list_of_dfs(df_list, max_len):
    """Pad every DataFrame of a list to ``max_len`` rows, in place.

    Each frame is extended at the bottom by repeating its last row ('edge'
    padding). The first column of the padded frame is then overwritten with
    the sequence 1, ..., max_len, and the original column names and dtypes
    are restored.

    Parameters
    ----------
    df_list : list of pandas.DataFrame
        Frames to pad. The list itself is modified in place.
    max_len : int
        Number of rows every frame should have after padding.

    Returns
    -------
    list of pandas.DataFrame
        The same list object, now holding the padded frames.
    """
    for position, frame in enumerate(df_list):
        n_missing_rows = max_len - frame.shape[0]
        # Rows are only appended at the bottom; columns are left untouched.
        values = np.pad(frame.values, ((0, n_missing_rows), (0, 0)), 'edge')
        # NOTE(review): presumably the first column is the epoch counter,
        # which has to keep increasing over the padded rows.
        values[:, 0] = np.arange(1, max_len + 1)
        padded = pd.DataFrame(values, columns=frame.columns)
        df_list[position] = padded.astype(frame.dtypes)
    return df_list


def _iter_history_dirs(experiement_folder):
    """Yield (transform, fold, data_size, path) for each training directory."""
    for transform in os.listdir(experiement_folder):
        transform_dir = join(experiement_folder, transform)
        if transform == "saved_splits" or not os.path.isdir(transform_dir):
            continue
        for fold in os.listdir(transform_dir):
            fold_dir = join(transform_dir, fold)
            if not os.path.isdir(fold_dir):
                continue
            for data_size in os.listdir(fold_dir):
                data_regime_dir = join(fold_dir, data_size)
                if os.path.isdir(data_regime_dir):
                    yield transform, fold, data_size, data_regime_dir


def make_df_out_of_crossval_results(
    experiement_folder,
    to_extract=None,
    pad=False,
    only_best=False
):
    """Fetches histories from each fold from training directory and returns
    DataFrame containing desired metrics

    Parameters
    ----------
    experiement_folder : str
        Path to experiment directory, containing the following structure:
        transform > fold > training size. Each training size directory must
        contain a 'train_end_history.json' file. A top-level directory named
        'saved_splits' is ignored.
    to_extract : list, optional
        List of metrics to extract from histories. If omitted, will extract
        train_bal_acc, valid_bal_acc, valid_loss and train_loss. The 'epoch'
        column is always extracted (and put first when missing from the
        list). The list passed by the caller is never modified.
    pad : bool, optional
        Whether to pad dataframe with 'edge' strategy so that all folds have
        the same number of rows (epochs). Useful for plotting when
        EarlyStopping is used (hence leading to different number of epochs
        per fold). Ignored when `only_best` is True, since every history then
        has a single row. By default False.
    only_best : bool, optional
        Whether to fetch metrics for all epochs of every fold and training
        subset (=False) or only fetch the best metrics for each fold and
        subset, i.e. the last row flagged by 'valid_bal_acc_best'. By default
        False.

    Returns
    -------
    pandas.DataFrame
        Extracted metrics plus the columns 'fold', 'subset_ratio' and
        'transform'. 'subset_ratio' is the second '_'-separated token of the
        training size directory name, kept as a string.

    Raises
    ------
    ValueError
        If `to_extract` is neither None nor a non-empty list.
    """
    best_flag = 'valid_bal_acc_best'
    if to_extract is None:
        columns = [
            'epoch',
            'valid_bal_acc',
            'train_bal_acc',
            'valid_loss',
            'train_loss'
        ]
    elif isinstance(to_extract, list) and len(to_extract) > 0:
        # Work on a copy so that the caller's list is never mutated.
        columns = list(to_extract)
        if 'epoch' not in columns:
            columns.insert(0, 'epoch')
    else:
        raise ValueError(
            "to_extract needs to be a non-empty list of strings or `None` !"
        )
    # The flag column is only a helper for row selection, unless the caller
    # explicitly asked for it: then it is kept in the output.
    drop_best_flag = only_best and best_flag not in columns
    if drop_best_flag:
        columns.append(best_flag)

    df_list = []
    max_len = 0
    for transform, fold, data_size, run_dir in _iter_history_dirs(
        experiement_folder
    ):
        history = pd.read_json(join(run_dir, 'train_end_history.json'))
        sub_df = history.loc[:, columns]
        if only_best:
            sub_df = sub_df[sub_df[best_flag]].tail(1)
            if drop_best_flag:
                sub_df = sub_df.drop(columns=best_flag)
            # Detach from the parent frame before adding columns below.
            sub_df = sub_df.copy()
        max_len = max(max_len, sub_df.shape[0])
        sub_df['fold'] = fold
        sub_df['subset_ratio'] = data_size.split('_')[1]
        sub_df['transform'] = transform
        df_list.append(sub_df)
    # Padding single-row "best" results would only overwrite their epoch
    # number, so it is skipped in that case.
    if pad and not only_best:
        df_list = pad_list_of_dfs(df_list, max_len)
    return pd.concat(df_list, ignore_index=True)


def slice_training_results(df, transform=None, subset_ratio=None, fold=None):
    """Select the rows of a broad results dataframe matching given criteria

    Parameters
    ----------
    df : pandas.DataFrame
        Results broad dataframe, with 'transform', 'fold' and 'subset_ratio'
        columns.
    transform : str | None, optional
        Value of the 'transform' column to keep. No filtering on this column
        when None. By default None.
    subset_ratio : float | None, optional
        Value of the 'subset_ratio' column to keep. No filtering on this
        column when None. By default None.
    fold : str | None, optional
        Value of the 'fold' column to keep. No filtering on this column when
        None. By default None.

    Returns
    -------
    pandas.DataFrame
        `df` itself when no criterion is given. Otherwise the matching rows
        with a fresh index, the previous index being kept as an 'index'
        column.
    """
    if all(wanted is None for wanted in (transform, subset_ratio, fold)):
        return df

    selected = df
    for column, wanted in (('transform', transform), ('fold', fold)):
        if wanted is not None:
            selected = selected[selected[column] == wanted]
    if subset_ratio is not None:
        # `@subset_ratio` is resolved by pandas from this function's locals.
        selected = selected.query("subset_ratio == @subset_ratio")
    return selected.reset_index()


def broad_to_long_table(df):
    """Makes a long DataFrame out of the make_df_out_of_crossval_results, for
    seaborn plotting.

    Every column whose name contains 'loss' or 'acc' is treated as a metric
    column named '<set>_<metric>' (e.g. 'valid_bal_acc'). Each input row is
    expanded into one output row per metric column, holding the metric value
    in 'value', the part of the column name after the first '_' in 'metric'
    and the part before it in 'set'. All other columns are repeated as is.

    Parameters
    ----------
    df : pandas.DataFrame
        Broad dataframe of merged training histories (output from
        make_df_out_of_crossval_results). Its index can be arbitrary.

    Returns
    -------
    pandas.DataFrame
        Long dataframe with the non-metric columns of `df` followed by
        'value', 'metric' and 'set'. Empty (with these columns) when `df` has
        no rows or no metric columns.
    """
    metric_cols = [c for c in df.columns if 'loss' in c or 'acc' in c]
    id_cols = [c for c in df.columns if 'loss' not in c and 'acc' not in c]

    records = []
    # Rows are visited positionally, so that dataframes with a non-default
    # index (e.g. after filtering) are handled correctly.
    for _, row in df.iterrows():
        shared = {c: row[c] for c in id_cols}
        for c in metric_cols:
            set_name, _, metric_name = c.partition('_')
            records.append({
                **shared,
                'value': row[c],
                'metric': metric_name,
                'set': set_name,
            })

    if not records:
        return pd.DataFrame(columns=id_cols + ['value', 'metric', 'set'])
    return pd.DataFrame(records)


def plot_learning_curve_from_long_df(
    long_df,
    ax_list=None,
    use_style=False,
    label=None
):
    """Plots learning curve (perf vs epoch) of a cross-validated training using
    a long histories DataFrame

    The rows with metric "bal_acc" are drawn on the first axis and the rows
    with metric "loss" on the second one.

    Parameters
    ----------
    long_df : pandas.DataFrame
        Output of broad_to_long_table.
    ax_list : sequence of 2 axes | None, optional
        When given, axis will be used to plot. Useful when several trainings
        are to be plotted on same figure/axis. Any sequence of exactly two
        axes is accepted (list, tuple, array returned by plt.subplots...).
        When None, a new figure with two axes is created. By default None.
    use_style : bool, optional
        Whether to use linestyle (instead of colors) to differentiate
        validation and training curves. Useful when plotting several trainings
        on the same axis. By default False.
    label : str | None, optional
        Label to give to the curves (instead of valid/train). Useful when
        plotting several trainings on the same axis. By default None.

    Raises
    ------
    ValueError
        If `ax_list` is given but cannot be unpacked into exactly two axes.
    """
    # Validate before touching any global plotting state.
    if ax_list is not None:
        try:
            ax1, ax2 = ax_list
        except (TypeError, ValueError) as err:
            raise ValueError(
                "ax_list should be a sequence of exactly 2 axes !"
            ) from err

    sns.set_theme(style="whitegrid", context="notebook")

    if ax_list is None:
        _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # The train/valid sets are told apart either by line style or by color.
    set_encoding = {"style" if use_style else "hue": "set"}
    for ax, metric in ((ax1, "bal_acc"), (ax2, "loss")):
        sns.lineplot(
            x="epoch", y="value",
            data=long_df[long_df['metric'] == metric],
            ax=ax, label=label,
            **set_encoding
        )
    ax1.set_ylabel("Balanced accuracy")
    ax2.set_ylabel("Balanced cross-entropy loss")


def plot_learning_curve(
    experiment_folder,
    metrics=None,
    plotting_kwargs=None,
    fetching_kwargs=None
):
    """Fetches histories from each fold from training directory and plots
    learning curves of cross-validated training.

    Parameters
    ----------
    experiment_folder : str
        Path to experiment directory, containing the following structure:
        transform > fold > training size.
    metrics : optional
        Unused. Only kept so that existing calls keep working.
    plotting_kwargs : dict, optional
        Arguments to be passed to plot_learning_curve_from_long_df. By default
        None.
    fetching_kwargs : dict, optional
        Arguments to be passed to make_df_out_of_crossval_results. Histories
        are padded (pad=True) unless this dictionary says otherwise. By
        default None.
    """
    # Build fresh dicts so that the caller's dictionaries are never modified
    # and so that a 'pad' entry overrides the default instead of clashing
    # with it.
    fetching_kwargs = {
        'pad': True,
        **({} if fetching_kwargs is None else fetching_kwargs)
    }
    if plotting_kwargs is None:
        plotting_kwargs = {}
    results_df = make_df_out_of_crossval_results(
        experiment_folder,
        **fetching_kwargs
    )
    plot_learning_curve_from_long_df(
        broad_to_long_table(results_df),
        **plotting_kwargs
    )


def format_handles_and_labels(handles, labels):
    """Post-process automatic labels generated by plot_learning_curve

    When calling plot_learning_curve many times on same axis (for comparing
    several trainings for example), this function will remove the
    duplicated train/valid labels.

    Entries whose label is neither 'train' nor 'valid' are kept as they are,
    in their original order. They are followed by a single entry per
    'train'/'valid' label, using the first handle met for that label, so
    that handles and labels always stay paired.

    Parameters
    ----------
    handles : list
        Handles from axes to update.
    labels : list
        Labels from axes to update.

    Returns
    -------
    list
        New handles.

    list
        New labels.
    """
    new_handles = []
    new_labels = []
    # label -> first handle seen; a dict keeps insertion order and the
    # handle/label pairing, which two independent sets cannot guarantee.
    set_entries = {}
    for handle, label in zip(handles, labels):
        if label in ('train', 'valid'):
            set_entries.setdefault(label, handle)
        else:
            new_handles.append(handle)
            new_labels.append(label)
    new_labels.extend(set_entries.keys())
    new_handles.extend(set_entries.values())
    return new_handles, new_labels


def _extract_test_results(test_res_path):
    """Load pickled test results and reshape them into a long table

    Parameters
    ----------
    test_res_path : str
        Path to a pickled pandas.DataFrame with at least the columns 'fold',
        'n_fold', 'subset_ratio', 'test_loss', 'test_cohen_kappa_score',
        'test_bal_acc' and 'test_confusion_matrix'.

    Returns
    -------
    pandas.DataFrame
        Long table with the identifier columns 'fold', 'n_fold' and
        'subset_ratio', the metric name in 'metric' and its value in 'value'.
        The confusion matrix is dropped.
    """
    identifier_columns = ['fold', 'n_fold', 'subset_ratio']
    metric_columns = [
        'test_loss',
        'test_cohen_kappa_score',
        'test_bal_acc'
    ]
    broad_results = pd.read_pickle(test_res_path)
    scalar_results = broad_results.drop('test_confusion_matrix', axis=1)
    return scalar_results.melt(
        id_vars=identifier_columns,
        value_vars=metric_columns,
        var_name='metric'
    )


def get_test_results_df(exp_path):
    """Loads cross-validated results on test set and stores them in a single
    dataframe

    For every transform directory (except 'saved_splits'), results are read
    from 'test_crossval_results.pkl' when this file sits directly in the
    transform directory. Otherwise each sub-directory holding such a file is
    read, and its name, expected to look like '<probability>-<magnitude>', is
    parsed to fill the 'probability' and 'magnitude' columns.

    Parameters
    ----------
    exp_path : str
        Path to experiment directory used by train.CrossvalModel (where the
        results are stored in different folders for each transform).

    Returns
    -------
    pandas.DataFrame
        Long table of test metrics with an additional 'transform' column, and
        'probability' and 'magnitude' columns for per-setting results.
    """
    results_filename = 'test_crossval_results.pkl'
    collected = []
    for transform in os.listdir(exp_path):
        transform_path = join(exp_path, transform)
        if transform == "saved_splits" or not os.path.isdir(transform_path):
            continue

        # Case 1: a single results file for the whole transform.
        direct_results_path = join(transform_path, results_filename)
        if os.path.isfile(direct_results_path):
            transform_frame = _extract_test_results(direct_results_path)
            transform_frame['transform'] = transform
            collected.append(transform_frame)
            continue

        # Case 2: one results file per setting sub-directory.
        for setting in os.listdir(transform_path):
            setting_results_path = join(
                transform_path, setting, results_filename
            )
            if not os.path.isfile(setting_results_path):
                continue
            setting_frame = _extract_test_results(setting_results_path)
            proba, mag = map(flexible_float, setting.split('-'))
            setting_frame['transform'] = transform
            setting_frame['probability'] = proba
            setting_frame['magnitude'] = mag
            collected.append(setting_frame)
    return pd.concat(collected, ignore_index=True)


def plot_data_learning_curve(
    results_df,
    relative=True,
    aggregate_folds=True,
    ref_aug="no-aug",
    log_axis=True,
    metric="test_bal_acc",
    ax_to_use=None,
):
    """Plots performance vs the size of the training set

    Parameters
    ----------
    results_df : pandas.DataFrame
        Long table containing the results for every fold and every subset,
        with 'transform', 'metric', 'value', 'subset_ratio' and 'fold'
        columns (e.g. obtained with get_test_results_df).
    relative : bool, optional
        Whether to compute the performances fold-wise relative to a reference
        transform, as (value - reference) / reference. Values are paired with
        the reference ones by position, so rows of each transform must come
        in the same order as the rows of the reference. By default True.
    aggregate_folds : bool, optional
        Whether to plot the average across folds and shaded std (True) or to
        plot one separate curve per fold. By default True.
    ref_aug : str, optional
        Reference transformation to use when `relative` is set to True. By
        default "no-aug".
    log_axis : bool, optional
        Whether to plot with training sizes in a log2 axis. By default True.
    metric : str, optional
        What performance metric to plot. By default "test_bal_acc".
    ax_to_use : matplotlib axis | sequence of axes | None, optional
        Axis to draw on. One axis is needed when `aggregate_folds` is True,
        one per plotted transform otherwise. When None, a new figure is
        created. By default None.

    Raises
    ------
    ValueError
        If `relative` is True and `ref_aug` is absent from `results_df`, if
        there is no transform to plot, or if the number of axes in
        `ax_to_use` does not match the number of plots.
    """
    all_transforms = list(np.unique(results_df['transform']))
    if relative:
        if ref_aug not in all_transforms:
            raise ValueError(
                f"Reference transform {ref_aug!r} not found in results_df."
            )
        # The reference itself is not plotted against itself.
        transforms = [t for t in all_transforms if t != ref_aug]
    else:
        transforms = all_transforms

    n_plots = 1 if aggregate_folds else len(transforms)
    if n_plots == 0:
        raise ValueError("No transform to plot.")

    if ax_to_use is None:
        # squeeze=False always returns a 2D array, whatever n_plots is.
        _, created_axes = plt.subplots(
            n_plots,
            1,
            figsize=(15, 5 * n_plots),
            squeeze=False
        )
        axes = list(created_axes.ravel())
    else:
        if isinstance(ax_to_use, Iterable):
            axes = list(ax_to_use)
        else:
            axes = [ax_to_use]
        if len(axes) != n_plots:
            raise ValueError(
                f"ax_to_use should hold {n_plots} axes, got {len(axes)}."
            )

    df = results_df[results_df["metric"] == metric]
    ylabel = f"{metric} improvement" if relative else metric
    if relative:
        ref = df.loc[df["transform"] == ref_aug, "value"].values

    if not aggregate_folds:
        for ax, transform in zip(axes, transforms):
            df_plot = df[df["transform"] == transform].copy()
            if relative:
                df_plot["value"] = (df_plot["value"].values - ref) / ref
            sns.lineplot(
                data=df_plot,
                y="value", x="subset_ratio", hue="fold",
                ax=ax
            )
            if relative:
                ax.hlines(0., *ax.get_xlim(), linestyle='--', color='k')
            ax.set_ylabel(ylabel)
            ax.set_title(f"{transform}")
            if log_axis:
                ax.set_xscale("log", base=2)
    else:
        ax = axes[0]
        if relative:
            df_plot = df[df["transform"] != ref_aug].copy()
            for transform in transforms:
                rows = df_plot["transform"] == transform
                df_plot.loc[rows, "value"] = (
                    df_plot.loc[rows, "value"].values - ref
                ) / ref
        else:
            df_plot = df
        sns.lineplot(
            data=df_plot,
            y="value", x="subset_ratio", hue="transform",
            ax=ax
        )
        if relative:
            ax.hlines(0., *ax.get_xlim(), linestyle='--', color='k')
        ax.set_ylabel(ylabel)

        if log_axis:
            ax.set_xscale("log", base=2)
        ax.set_xlabel("Training set fraction")
        # Legend of the plotted axis, which may not be the current one.
        ax.legend(loc=0)

    plt.tight_layout()


def _make_scores_relative_to_first_value(df_rel, setting_to_compare, setting_ref=0.0):
    """Express scores relatively to those obtained with a reference setting

    For every value taken by the `setting_to_compare` column, the 'value'
    column is replaced by (value - reference) / reference, where the
    reference values are the rows with `setting_to_compare == setting_ref`.
    Values and references are paired by position, so each group of rows must
    have as many rows as the reference group, in the same order.

    Rows with a missing (NaN) setting are left untouched.

    Parameters
    ----------
    df_rel : pandas.DataFrame
        Long results table with a 'value' column and a `setting_to_compare`
        column. It is not modified.
    setting_to_compare : str
        Name of the column holding the setting values.
    setting_ref : float, optional
        Setting value used as reference. By default 0.0.

    Returns
    -------
    pandas.DataFrame
        Copy of `df_rel` with relative values.
    """
    new_df = df_rel.copy()
    settings = df_rel[setting_to_compare]
    ref = df_rel.loc[settings == setting_ref, "value"].values
    # NaN never compares equal to itself: it would select no row and make
    # the subtraction below fail, hence the dropna.
    for setting in settings.dropna().unique():
        rows = settings == setting
        aug = df_rel.loc[rows, "value"].values
        new_df.loc[rows, "value"] = (aug - ref) / ref
    return new_df


def plot_setting_learning_curve(
    results_df,
    relative=True,
    aggregate_folds=True,
    subset_ratio=None,
    setting_to_compare="magnitude",
    log_axis=False,
    metric="test_bal_acc",
    ax_to_use=None,
):
    """Plots performance vs the value of a transform setting

    Parameters
    ----------
    results_df : pandas.DataFrame
        Long table containing the results for every fold and every setting,
        with 'transform', 'metric', 'value', 'fold' and `setting_to_compare`
        columns (e.g. obtained with get_test_results_df).
    relative : bool, optional
        Whether to plot performances relative to those obtained with the
        setting equal to 0.0, as (value - reference) / reference. By default
        True.
    aggregate_folds : bool, optional
        Whether to plot the average across folds and shaded std (True) or to
        plot one separate curve per fold. By default True.
    subset_ratio : float | None, optional
        When given, only rows of `results_df` with this 'subset_ratio' are
        plotted. By default None.
    setting_to_compare : str, optional
        Setting used as x-axis: either "magnitude" or "probability". By
        default "magnitude".
    log_axis : bool, optional
        Whether to plot with setting values in a log2 axis. By default False.
    metric : str, optional
        What performance metric to plot. By default "test_bal_acc".
    ax_to_use : matplotlib axis | sequence of axes | None, optional
        Axis to draw on. One axis is needed when `aggregate_folds` is True,
        one per transform otherwise. When None, a new figure is created. By
        default None.

    Raises
    ------
    ValueError
        If `setting_to_compare` is not supported, or if the number of axes
        in `ax_to_use` does not match the number of plots.
    """
    if setting_to_compare not in ('probability', 'magnitude'):
        raise ValueError(
            "setting_to_compare can be either magnitude or probability"
        )

    if subset_ratio is not None:
        results_df = results_df.query("subset_ratio == @subset_ratio")

    transforms = np.unique(results_df['transform'])
    n_plots = 1 if aggregate_folds else len(transforms)

    if ax_to_use is None:
        # squeeze=False always returns a 2D array, whatever n_plots is.
        _, created_axes = plt.subplots(
            n_plots,
            1,
            figsize=(15, 5 * n_plots),
            squeeze=False
        )
        axes = list(created_axes.ravel())
    else:
        if isinstance(ax_to_use, Iterable):
            axes = list(ax_to_use)
        else:
            axes = [ax_to_use]
        if len(axes) != n_plots:
            raise ValueError(
                f"ax_to_use should hold {n_plots} axes, got {len(axes)}."
            )

    df = results_df[results_df["metric"] == metric]
    ylabel = f"{metric} improvement" if relative else metric

    if not aggregate_folds:
        for ax, transform in zip(axes, transforms):
            df_plot = df[df["transform"] == transform].copy()
            if relative:
                df_plot = _make_scores_relative_to_first_value(
                    df_plot, setting_to_compare
                )
            sns.lineplot(
                data=df_plot,
                y="value", x=setting_to_compare, hue="fold",
                ax=ax
            )
            if relative:
                ax.hlines(0., *ax.get_xlim(), linestyle='--', color='k')
            ax.set_ylabel(ylabel)
            ax.set_title(f"{transform}")
            if log_axis:
                ax.set_xscale("log", base=2)
    else:
        ax = axes[0]
        if relative:
            df_plot = _make_scores_relative_to_first_value(
                df, setting_to_compare
            )
        else:
            df_plot = df
        sns.lineplot(
            data=df_plot,
            y="value", x=setting_to_compare, hue="transform",
            ax=ax
        )
        if relative:
            # Limits are read from the axis actually drawn on.
            ax.hlines(0., *ax.get_xlim(), linestyle='--', color='k')
        ax.set_ylabel(ylabel)

        if log_axis:
            ax.set_xscale("log", base=2)
        ax.set_xlabel(setting_to_compare.title())
        # Legend of the plotted axis, which may not be the current one.
        ax.legend(loc=0)

    plt.tight_layout()


def get_sanity_check_results(load_path, pad=False):
    """Gather sanity-check training histories into a single dataframe

    The object loaded from `load_path` is iterated fold by fold (folds being
    numbered from 1). Each fold is indexed by model name and must hold a
    'reference' entry. NOTE(review): presumably a list of dicts of
    pandas.DataFrame with a 'valid_loss' column - confirm against the code
    saving these results.

    The loaded histories are annotated in place with the columns 'fold',
    'transform' (the model name) and 'val_loss_(aug-ref)/ref' (validation
    loss relative to the one of the 'reference' model of the same fold).

    Parameters
    ----------
    load_path : str
        Path of the results, as accepted by load_obj.
    pad : bool, optional
        Whether to pad the histories of each model so that all folds have the
        same number of rows. By default False.

    Returns
    -------
    pandas.DataFrame
        Histories of all models and folds, stacked.
    """
    results = load_obj(load_path)

    per_model_frames = []
    for model in results[0]:
        fold_frames = []
        longest = 0
        for fold_number, histories in enumerate(results, start=1):
            history = histories[model]
            history['fold'] = fold_number
            history['transform'] = model
            reference_loss = histories['reference']['valid_loss']
            relative_gap = (history['valid_loss'] - reference_loss) / reference_loss
            history['val_loss_(aug-ref)/ref'] = relative_gap
            longest = max(longest, history.shape[0])
            fold_frames.append(history)
        if pad:
            fold_frames = pad_list_of_dfs(fold_frames, longest)
        per_model_frames.append(pd.concat(fold_frames, ignore_index=True))
    return pd.concat(per_model_frames, ignore_index=True)


def plot_sanity_check_curve(load_path, relative=True, save_path=None):
    """Plots validation loss vs epoch for sanity-check trainings

    Parameters
    ----------
    load_path : str
        Path of the results, passed to get_sanity_check_results.
    relative : bool, optional
        Whether to plot the validation loss of the 'augment' model relative
        to the reference one (True), or the raw validation loss of every
        model (False). By default True.
    save_path : str | None, optional
        When given, the figure is saved to this path. Missing parent
        directories are created. By default None.
    """
    df = get_sanity_check_results(load_path, pad=True)

    sns.set_theme(style="whitegrid", context="notebook")
    fig, ax = plt.subplots(figsize=(15, 5))

    if relative:
        sns.lineplot(
            x="epoch", y='val_loss_(aug-ref)/ref',
            data=df[df['transform'] == 'augment'],
            ax=ax
        )
        ax.hlines(0., *ax.get_xlim(), colors='r', linestyles='--')
    else:
        sns.lineplot(
            x="epoch", y='valid_loss',
            hue="transform",
            data=df,
            ax=ax
        )
    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        # A bare file name has no directory part, and os.makedirs('') fails.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)
