import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import re


def process_list_params(list_params):
    """Collapse a list of parameters into one underscore-separated string.

    Args:
        list_params: Value to serialise. Only a non-empty ``list`` is joined.

    Returns:
        The ``str()`` of every element joined with ``'_'``, or ``'0'`` when
        ``list_params`` is not a list or is an empty list.
    """
    # Guard clause: anything that is not a populated list maps to the '0' marker.
    if not isinstance(list_params, list) or not list_params:
        return '0'
    return '_'.join(map(str, list_params))


def get_elements(my_list, my_and_filter, my_not_filter=None):
    """Return the elements that contain every AND filter and no NOT filter.

    Membership is tested with the ``in`` operator, so for string elements the
    filters act as substring matches.

    Args:
        my_list: Iterable of candidate elements.
        my_and_filter: Collection of values that must all be ``in`` an element
            for it to be kept. An empty collection keeps every element.
        my_not_filter: Collection of values none of which may be ``in`` an
            element. ``None`` (the default) or an empty collection excludes
            nothing.

    Returns:
        A new list with the matching elements, in their original order.
    """
    # A None default avoids sharing one mutable list object across calls.
    excluded = () if my_not_filter is None else my_not_filter

    return [
        e for e in my_list
        if all(f in e for f in my_and_filter)
        and not any(f in e for f in excluded)
    ]


def box_plots(df_i, cols_metrics, col, x, hue, model_name, dataset_name, showfliers=False, order=None, show=False,
              folder='images', ylim=None):
    """Draw and save one box-plot figure per metric, optionally one panel per ``col`` value.

    Args:
        df_i: DataFrame holding the ``x``, ``hue``, ``col`` and metric columns.
        cols_metrics: Sequence of metric column names; one figure is saved per entry.
        col: Column whose distinct values (NaN included) each get their own
            panel, or ``None`` for a single panel over the whole frame.
        x: Column plotted on the x axis.
        hue: Column used for the hue grouping, or ``None``.
        model_name: Inserted into the output file name.
        dataset_name: Inserted into the output file name.
        showfliers: Forwarded to ``seaborn.boxplot``.
        order: Forwarded to ``seaborn.boxplot`` as the x-axis order.
        show: If true, call ``plt.show()``; otherwise close all open figures
            after saving.
        folder: Sub-directory of ``images/`` the figure is written to.
        ylim: ``None`` (no limits), or a sequence indexed like ``cols_metrics``
            whose entries are either ``None`` or a value passed to
            ``Axes.set_ylim``. Only used in the panelled layout
            (``col is not None``).

    Side effects:
        Sets the global seaborn context/style and writes
        ``images/<folder>/<dataset_name>_<model_name>_<metric>_catplot.jpg``
        for every metric (``/`` in the metric name is replaced by ``_``).
    """
    import seaborn as sns
    sns.set_context("paper", rc={"font.size": 16,
                                 "axes.titlesize": 16,
                                 "axes.labelsize": 16,
                                 'legend.fontsize': 16.0})

    sns.set_style("whitegrid")

    # Neither value depends on the metric being drawn, so compute them once.
    hue_order = sorted(df_i[hue].unique(), reverse=False) if hue is not None else None
    n_cols = len(df_i[col].unique()) if col is not None else 1

    for idx_m, m in enumerate(cols_metrics):
        fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 5, 5))
        # ylim defaults to None, so it must not be indexed unconditionally.
        metric_ylim = ylim[idx_m] if ylim is not None else None

        if col is not None:
            # dropna=False keeps the NaN group so the groups line up with n_cols,
            # which was counted with unique() (NaN included).
            for idx, (label_g, df_g) in enumerate(df_i.groupby(col, dropna=False)):
                # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
                ax = axes[idx] if n_cols > 1 else axes
                if metric_ylim is not None:
                    ax.set_ylim(metric_ylim)
                sns.boxplot(x=x, y=m, hue=hue,
                            order=order,
                            data=df_g, showfliers=showfliers,
                            hue_order=hue_order,
                            ax=ax)
                ax.set_title(label_g)
                if idx == (n_cols - 1):
                    # Keep a single legend, anchored just outside the last panel.
                    ax.legend(bbox_to_anchor=(1.01, 1), borderaxespad=0)
                elif hue is not None:
                    legend = ax.get_legend()
                    if legend is not None:
                        legend.remove()
        else:
            # NOTE(review): ylim is not applied and the legend is left untouched
            # in the single-panel layout, matching the previous behaviour.
            sns.boxplot(x=x, y=m, hue=hue,
                        order=order,
                        data=df_i, showfliers=showfliers,
                        hue_order=hue_order,
                        ax=axes)

        plt.tight_layout()
        plt.savefig(f"images/{folder}/{dataset_name}_{model_name}_{m.replace('/', '_')}_catplot.jpg")
        if show:
            plt.show()
        else:
            plt.close('all')

def plot_cf_line(my_df, metrics, hue, col, x_label, y_label, name, dataset_name, id_ticks=2, show=False,
                 folder='images'):
    """Draw and save a row of line plots, one panel per ``col`` value and one line per ``hue`` value.

    For every (``col``, ``hue``) group the ``metrics`` columns are stacked into
    long form: the metric name becomes the x value and the metric value the y
    value. Each metric name is shortened to the ``id_ticks``-th piece of the
    name split on ``'_'``.

    Args:
        my_df: DataFrame holding the ``hue``, ``col`` and ``metrics`` columns.
        metrics: List of metric column names; every name must contain at least
            ``id_ticks + 1`` underscore-separated pieces.
        hue: Column whose values each produce one line per panel.
        col: Column whose values each produce one panel (NaN included).
        x_label: x-axis label, also used (alphanumerics only) in the file name.
        y_label: y-axis label, also used (alphanumerics only) in the file name.
        name: Inserted into the output file name.
        dataset_name: Inserted into the output file name.
        id_ticks: Index of the ``'_'``-separated piece of each metric name used
            as its x tick label.
        show: If true, call ``plt.show()``; otherwise close all open figures
            after saving.
        folder: Sub-directory of ``images/`` the figure is written to.

    Side effects:
        Writes ``images/<folder>/<dataset_name>_<name>_<x>_<y>_lineplot.jpg``.
    """
    n_hue = len(my_df[hue].unique())

    # Count the panels from the groups that are actually drawn. dropna=False
    # keeps a NaN group (as box_plots does) so the panel count, the panel
    # indices and the "last panel" legend test always agree.
    col_groups = list(my_df.groupby(col, dropna=False))
    n_cols = len(col_groups)
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 5, 5))

    for idx_c, (label_c, df_col) in enumerate(col_groups):
        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
        ax = axes[idx_c] if n_cols > 1 else axes

        for label, df_g in df_col.groupby(hue):
            df_metrics = df_g[metrics]
            # Metric column name -> short tick label.
            tick_labels = {i: f"{i.split('_')[id_ticks]}" for i in df_metrics.columns}

            # Long form: level 1 of the stacked index is the metric name.
            df_s = df_metrics.stack()
            df_tmp = pd.DataFrame()
            df_tmp[x_label] = df_s.index.get_level_values(1)
            df_tmp[y_label] = df_s.values
            df_tmp[x_label] = df_tmp[x_label].replace(tick_labels)
            sns.lineplot(data=df_tmp, x=x_label, y=y_label, label=label, ax=ax)

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        if n_hue > 1 and idx_c == (n_cols - 1):
            # Keep a single legend, anchored just outside the last panel.
            ax.legend(bbox_to_anchor=(1.01, 1), borderaxespad=0)
        else:
            # A panel with no drawn line has no legend to remove.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

        ax.set_title(label_c)

    tmp = re.sub(r"[^a-zA-Z0-9]", "", x_label)
    tmp_y = re.sub(r"[^a-zA-Z0-9]", "", y_label)
    plt.tight_layout()
    plt.savefig(f"images/{folder}/{dataset_name}_{name}_{tmp}_{tmp_y}_lineplot.jpg")
    if show:
        plt.show()
    else:
        plt.close('all')


def get_unique_parameteres(columns_list, df_i, type_list=None):
    """Print the distinct values of every selected column that is not constant.

    Columns are selected with ``get_elements(columns_list, type_list)``. A
    column with exactly one distinct value is skipped; any other column has
    its name printed followed by its distinct values, one indexed line each.

    Args:
        columns_list: Candidate column names.
        df_i: DataFrame containing the selected columns.
        type_list: Values that must all occur in a column name for it to be
            selected. Defaults to ``['model']``.
    """
    # A None default avoids sharing one mutable list object across calls.
    if type_list is None:
        type_list = ['model']

    for c in get_elements(columns_list, type_list):
        unique_values = df_i[c].unique()
        if len(unique_values) == 1:
            continue
        print(f"{c}")
        for i, u in enumerate(unique_values):
            print(f"\t[{i}] {u}")


def get_best_models(df, objective_metric,  cv_dict, columns_list):
    """Keep, for every dataset / model / equations-type group, the rows of its best configuration.

    Rows are grouped by ``dataset_name``, ``model_name`` and
    ``dataset_params_equations_type``. Within each group the rows are grouped
    again by the columns in ``cv_dict[model_name]``, and the mean and std of
    ``objective_metric[model_name]`` are computed per configuration. Among the
    configurations whose std is below 100000, the one with the highest mean is
    selected and only its rows are kept.

    Args:
        df: DataFrame with the columns ``dataset_name``, ``model_name``,
            ``dataset_params_equations_type``, ``json_filename``, the metric
            columns and the configuration columns.
        objective_metric: Mapping from model name to the metric column to maximise.
        cv_dict: Mapping from model name to the configuration column(s) — a
            list of column names, or a single column name.
        columns_list: Column names forwarded to ``get_unique_parameteres``.

    Returns:
        The concatenation of the selected rows of every group, or an empty
        frame with the columns of ``df`` when there is no group at all.

    Raises:
        ValueError: If no configuration of a group has a std below the
            threshold (e.g. every configuration has a single row, so its std
            is NaN).

    Side effects:
        Prints a report to stdout and appends the selected configuration and
        the ``json_filename`` of its best row to ``best_models.txt``.
    """
    print('\n\nComputing best configurations for each model and SEM:')

    max_std = 100000

    best_models_list = []
    for dataset_name, df_dataset in df.groupby('dataset_name'):
        for m_name, df_m in df_dataset.groupby('model_name'):
            print('--------')
            for d_name, df_md in df_m.groupby('dataset_params_equations_type'):
                metric = objective_metric[m_name]
                cv_keys = cv_dict[m_name]
                cv_keys = [cv_keys] if isinstance(cv_keys, str) else list(cv_keys)

                header = f'{dataset_name} : {m_name} : {d_name}'
                print(header)
                # Lines for best_models.txt, written in one go once the group is resolved.
                log_lines = [f'{header}\n']

                # Select the metric before aggregating so that non-numeric
                # columns (e.g. json_filename) are never passed to mean/std.
                df_md_g = df_md.groupby(cv_keys, dropna=False)[metric].agg(['mean', 'std'])
                # The comparison is False for a NaN std, so such configurations are dropped.
                df_md_ok = df_md_g[df_md_g['std'] < max_std]
                if df_md_ok.empty:
                    raise ValueError(
                        f"No configuration of '{m_name}' on '{dataset_name}' / '{d_name}' "
                        f"has a std of '{metric}' below {max_std}"
                    )

                # Positional lookup avoids label-based access on keys that may contain NaN.
                best_pos = df_md_ok['mean'].argmax()
                best_config = df_md_ok.index[best_pos]
                my_mean, my_std = df_md_ok.iloc[best_pos]
                # A single grouping key gives a flat index, hence a scalar label.
                best_values = (best_config,) if len(cv_keys) == 1 else best_config

                df_best_md = df_md.copy()
                for k, v in zip(cv_keys, best_values):
                    log_lines.append(f'\t{k}: {v}\n')
                    print(f'\t{k}: {v}')
                    # NaN never compares equal, so a NaN key is matched with isna().
                    matches = df_best_md[k].isna() if pd.isna(v) else df_best_md[k] == v
                    df_best_md = df_best_md[matches]

                print(f"Num of entries: {len(df_best_md)}")
                f_name = df_best_md.loc[df_best_md[metric].idxmax()]['json_filename']
                log_lines.append(f'\t{f_name}\n')
                with open('best_models.txt', 'a') as f:
                    f.writelines(log_lines)
                print(f_name)
                get_unique_parameteres(columns_list,
                                       df_i=df_best_md,
                                       type_list=['model'])

                print(f"{metric}: {my_mean:.3f} +- {my_std:.3f}\n")
                best_models_list.append(df_best_md)

    # pd.concat raises on an empty list; return an empty frame with the same columns instead.
    df_best = pd.concat(best_models_list) if best_models_list else df.iloc[0:0].copy()

    print('\n\nModels we are comparing:')

    for m in df_best['model_name'].unique():
        print(f"\t{m}")

    return df_best
