import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import glob
import os



if __name__ == '__main__':

    plt.rcParams.update({'text.usetex': True})
    sns.set_style('white')

    dir = 'mvn/out/'
    in_files = ['gamma_2.csv', 'gamma_1.csv', 'gamma_02.csv']
    out_files_template = ['{}-{}-gamma-above-1.png', '{}-{}-gamma-1.png', '{}-{}-gamma-below-1.png']



    legend_order = [0, 1, 2, 3]

    for y_axis_label in ['DAMV', 'DASME']:
        for scaling in ['weight', 'bandwidth']:

            out_files = [f.format(y_axis_label.lower(), scaling) for f in out_files_template]

            if scaling == 'weight':
                method_mapping = {
                    '$k_{\\rm{RBF}}$': '$k_2 = k_1$',
                    '$\\sqrt{d} \\cdot k_{\\rm{RBF}}$': '$k_2 = \sqrt{d} \cdot k_1$',
                    '$\\log(d) \\cdot k_{\\rm{RBF}}$': '$k_2 = \log(d) \cdot k_1$'
                }
            elif scaling == 'bandwidth':
                method_mapping = {
                    '$k_{\\rm{RBF}}$': '$h_2 = h_1$',
                    '$k_{\\rm{RBF}}(\\cdot,\\cdot;\\sqrt{d})$': '$h_2 = \sqrt{d} \cdot h_1$',
                    '$k_{\\rm{RBF}}(\\cdot,\\cdot;\\log(d))$': '$h_2 = \log(d) \cdot h_1$'
                }
            
            if y_axis_label == 'DAMV':
                field = 'var_est'
                ground_truth_field = 'var'
            elif y_axis_label == 'DASME':
                field = 'abs_mean_error'
                ground_truth_field = 'zero'

            legend = [False, False, True]

            fig_ax_list_damv = [plt.subplots(figsize=(4, 2.5)) for _ in range(3)]


            for i, f in enumerate(in_files):
                
                fig_ax_list_damv[i][1].clear()
                
                print(os.path.join(dir, f))
                df = pd.read_csv(os.path.join(dir, f)).iloc[:,1:]
                df = df[df.k.isin(method_mapping.keys())]
                df['abs_mean_error'] = (df['mean']-df['mean_est'])**2

                df_agg = df.\
                    groupby(['d', 'exp', 'k', 'k_index']).\
                    agg(
                        var=('var', 'mean'),
                        var_est=('var_est', 'mean'),
                        mean=('mean', 'mean'),
                        mean_est=('mean_est', 'mean'),
                        abs_mean_error=('abs_mean_error', 'mean')
                    ).\
                    reset_index()
                df_agg['mean_error'] = df_agg['mean'] - df_agg['mean_est']
                df_agg['zero'] = 0

                df_agg['method_latex'] = df_agg.k.replace(method_mapping)

                if field=='var_est':
                    sns.lineplot(
                        data=df_agg,
                        x='d',
                        y=ground_truth_field,
                        color='black',
                        linestyle='--',
                        ax=fig_ax_list_damv[i][1],
                        label='ground truth',
                        legend=legend[i]
                    )

                sns.lineplot(
                    data=df_agg,
                    x='d',
                    y=field,
                    hue='method_latex',
                    hue_order=method_mapping.values(),
                    ci=field=='var_est',
                    ax=fig_ax_list_damv[i][1],
                    legend=legend[i]
                )

                fig_ax_list_damv[i][1].set_xlabel('Dimension')
                fig_ax_list_damv[i][1].set_ylabel(y_axis_label)

            # handles, labels = plt.gca().get_legend_handles_labels()
            # plt.legend([handles[j] for j in legend_order], [labels[j] for j in legend_order])
            # sns.move_legend(fig_ax_list_damv[-1][1], legend_position)
            fig_ax_list_damv[-1][1].get_legend().set_title('')

            for fig, ax in fig_ax_list_damv:
                ax.set_ylim(
                    min(ax.get_ylim()[0] for (fig, ax) in fig_ax_list_damv),
                    max(ax.get_ylim()[1] for (fig, ax) in fig_ax_list_damv)
                )
                # ax.set_ylim(y_lim)
                # ax.set_yticks(y_ticks, minor=False)

            for i in range(3):
                fig, ax = fig_ax_list_damv[i]
                fig.savefig(os.path.join('mvn/img', out_files[i]), bbox_inches='tight', dpi=300)










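    ### Legacy per-figure settings, kept for reference only: the plotting code above
    ### does not apply them (it sets no fixed y_lim, y_ticks or legend_position).
    ### DAMV weight scaling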


    # field = 'var_est'
    # ground_truth_field = 'var'
    # y_axis_label = 'DAMV'
    # y_lim = (0, 3.3)
    # y_ticks = [0, 1, 2, 3]
    # legend_position = 'center right'



    ### DASME weight scaling

    # in_files = ['2023-09-09 124100.csv', '2023-09-09 144640.csv', '2023-09-10 081605.csv']
    # out_files = ['dasme-weight-gamma-above-1.png', 'dasme-weight-gamma-1.png', 'dasme-weight-gamma-below-1.png']
    # method_mapping = {
    #     '$k_{\\rm{RBF}}$': '$k_2 = k_1$',
    #     '$\\sqrt{d} \\cdot k_{\\rm{RBF}}$': '$k_2 = \sqrt{d} \cdot k_1$',
    #     '$\\log(d) \\cdot k_{\\rm{RBF}}$': '$k_2 = \log(d) \cdot k_1$'
    # }
    # legend_order = [0, 1, 2]
    # field = 'abs_mean_error'
    # ground_truth_field = 'zero'
    # y_axis_label = 'DASME'
    # y_lim = (0,0.0025)
    # y_ticks = [0, 0.001, 0.002]
    # legend_position = 'lower right'



    ### DAMV bandwidth scaling

    # in_files = ['2023-09-09 144640.csv', '2023-09-10 081605.csv', '2023-09-09 124100.csv']
    # out_files = ['damv-bandwidth-gamma-1.png', 'damv-bandwidth-gamma-below-1.png', 'damv-bandwidth-gamma-above-1.png']
    # method_mapping = {
    #     '$k_{\\rm{RBF}}$': '$h_2 = h_1$',
    #     '$k_{\\rm{RBF}}(\\cdot,\\cdot;\\sqrt{d})$': '$h_2 = \sqrt{d} \cdot h_1$',
    #     '$k_{\\rm{RBF}}(\\cdot,\\cdot;\\log(d))$': '$h_2 = \log(d) \cdot h_1$'
    # }
    # legend_order = [0, 1, 2, 3]
    # field = 'var_est'
    # ground_truth_field = 'var'
    # y_axis_label = 'DAMV'
    # y_lim = (0, 3.3)
    # y_ticks = [0, 1, 2, 3, 4, 5]
    # legend_position = 'upper left'



    ### DASME bandwidth scaling

    # in_files = ['2023-09-09 144640.csv', '2023-09-10 081605.csv', '2023-09-09 124100.csv']
    # out_files = ['dasme-bandwidth-gamma-1.png', 'dasme-bandwidth-gamma-below-1.png', 'dasme-bandwidth-gamma-above-1.png']
    # method_mapping = {
    #     '$k_{\\rm{RBF}}$': '$h_2 = h_1$',
    #     '$k_{\\rm{RBF}}(\\cdot,\\cdot;\\sqrt{d})$': '$h_2 = \sqrt{d} \cdot h_1$',
    #     '$k_{\\rm{RBF}}(\\cdot,\\cdot;\\log(d))$': '$h_2 = \log(d) \cdot h_1$'
    # }
    # legend_order = [0, 1, 2]
    # field = 'abs_mean_error'
    # ground_truth_field = 'zero'
    # y_axis_label = 'DASME'
    # y_lim = (0, 0.02)
    # y_ticks = [0, 0.002, 0.004, 0.006]
    # legend_position = 'upper left'


    # field = 'var_est'
    # ground_truth_field = 'var'
    # y_axis_label = 'DAMV'



    # legend = [False, False, True]

    # fig_ax_list_damv = [plt.subplots(figsize=(4, 2.5)) for _ in range(3)]
    # # fig_ax_list_mean = [plt.subplots(figsize=(4, 3)) for _ in range(3)]


    # for i, f in enumerate(in_files):
    #     df = pd.read_csv(os.path.join(dir, f)).iloc[:,1:]
    #     df = df[df.k.isin(method_mapping.keys())]
    #     df['abs_mean_error'] = (df['mean']-df['mean_est'])**2

    #     df_agg = df.\
    #         groupby(['d', 'exp', 'k', 'k_index']).\
    #         agg(
    #             var=('var', 'mean'),
    #             var_est=('var_est', 'mean'),
    #             mean=('mean', 'mean'),
    #             mean_est=('mean_est', 'mean'),
    #             abs_mean_error=('abs_mean_error', 'mean')
    #         ).\
    #         reset_index()
    #     df_agg['mean_error'] = df_agg['mean'] - df_agg['mean_est']
    #     df_agg['zero'] = 0

    #     df_agg['method_latex'] = df_agg.k.replace(method_mapping)

    #     if field=='var_est':
    #         sns.lineplot(
    #             data=df_agg,
    #             x='d',
    #             y=ground_truth_field,
    #             color='black',
    #             linestyle='--',
    #             ax=fig_ax_list_damv[i][1],
    #             label='ground truth',
    #             legend=legend[i]
    #         )

    #     sns.lineplot(
    #         data=df_agg,
    #         x='d',
    #         y=field,
    #         hue='method_latex',
    #         hue_order=method_mapping.values(),
    #         ci=field=='var_est',
    #         ax=fig_ax_list_damv[i][1],
    #         legend=legend[i]
    #     )

    #     fig_ax_list_damv[i][1].set_xlabel('Dimension')
    #     fig_ax_list_damv[i][1].set_ylabel(y_axis_label)

    # handles, labels = plt.gca().get_legend_handles_labels()
    # plt.legend([handles[j] for j in legend_order], [labels[j] for j in legend_order])
    # sns.move_legend(fig_ax_list_damv[-1][1], legend_position)
    # fig_ax_list_damv[-1][1].get_legend().set_title('')

    # for fig, ax in fig_ax_list_damv:
    #     ax.set_ylim(
    #         min(ax.get_ylim()[0] for (fig, ax) in fig_ax_list_damv),
    #         max(ax.get_ylim()[1] for (fig, ax) in fig_ax_list_damv)
    #     )
    #     # ax.set_ylim(y_lim)
    #     # ax.set_yticks(y_ticks, minor=False)

    # for i in range(3):
    #     fig, ax = fig_ax_list_damv[i]
    #     fig.savefig(os.path.join('mvn/img', out_files[i]), bbox_inches='tight', dpi=300)