
import matplotlib.pyplot as plt
import os
import pickle
from matplotlib import cm
import matplotlib.transforms as mtrans

import sys
import os

import dataclasses

# Get absolute path of the project root
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    
from config_generator import generate_configurations, generate_string_from_params, generate_name_from_params, load_configurations

import numpy as np

from plot_helpers import plotting_params, metric_clean_names, method_names, colors, children_labels, linestyles, nodestyles, vary_labels, save_legend, convert_to_float_list, convert_to_int_list, convert_to_string_list, convert_to_children_known



def get_metrics_file_path(metrics_folder, config, by_key=None, by_value=None, regime_children_value=None, prefix=None, suffix=None):
    """Build the path of the metrics file for one variant of ``config``.

    A copy of ``config`` is made with ``by_key`` set to ``by_value``. When
    ``regime_children_value`` is given, ``reg_children_known`` is overridden
    as well. The file name is derived from that modified copy via
    ``generate_name_from_params``.

    Args:
        metrics_folder: Directory containing the metrics files.
        config: Dataclass configuration instance used as the template.
        by_key: Name of the configuration field to override. ``'N_values'``
            is accepted as an alias for the ``N`` field when ``config`` has no
            ``N_values`` field.
        by_value: Value assigned to the ``by_key`` field.
        regime_children_value: Optional override for ``reg_children_known``.
        prefix: Text placed right after the ``_`` separator. ``None`` is
            treated as an empty string (it used to produce ``"_None"``).
        suffix: Text appended after ``prefix``. ``None`` is treated as an
            empty string.

    Returns:
        ``<metrics_folder>/<generated name>_<prefix><suffix>.dat``

    Raises:
        AttributeError: If ``by_key`` is not a field of ``config`` and is not
            the ``'N_values'`` alias.
    """
    if regime_children_value is not None:
        config = dataclasses.replace(config, reg_children_known=regime_children_value)

    # str() normalizes key types such as numpy string scalars (np.str_).
    key = str(by_key)
    field_names = {f.name for f in dataclasses.fields(config)}
    if key not in field_names:
        if key != 'N_values':
            raise AttributeError(f"{type(config).__name__} has no field {key!r}")
        key = 'N'
    config = dataclasses.replace(config, **{key: by_value})

    file_name = generate_name_from_params(config)
    pre = '' if prefix is None else prefix
    suf = '' if suffix is None else suffix
    return os.path.join(metrics_folder, f"{file_name}_{pre}{suf}.dat")

def load_avg_metrics_regime_by(metric, method, by_key, by_values, metrics_folder, config, regime_children_value):
    """Load per-regime averaged metric values and their errors for each ``by_value``.

    For every value in ``by_values``, the pickle located by
    ``get_metrics_file_path(..., prefix='avg_', suffix='regimes')`` is opened
    and the entry ``metric + method`` is read. Element 0 of that entry is
    appended to the value list and element 1 to the error list.

    Returns:
        ``(metric values, metric errors, failed_files)``. ``failed_files``
        lists the paths of files that do not contain ``metric + method``.
        Such files contribute no point, so the returned lists may be shorter
        than ``by_values``.

    Raises:
        OSError: If a metrics file cannot be opened.
    """
    metric_pp_list, metric_errors_list = [], []
    failed_files = []
    metrics_key = metric + method

    for by_value in by_values:
        # NOTE(review): the converted value is unused. The call is kept in case
        # check_type validates/raises on bad values; confirm.
        config.check_type(by_key, by_value)
        metrics_file_path = get_metrics_file_path(metrics_folder, config, by_key, by_value, regime_children_value, prefix='avg_', suffix='regimes')
        # pickle.load can execute arbitrary code; only use on trusted metrics files.
        with open(metrics_file_path, 'rb') as file:
            metrics_dict = pickle.load(file, encoding='latin1')
        if metrics_key in metrics_dict:
            value, error = metrics_dict[metrics_key][0], metrics_dict[metrics_key][1]
            metric_pp_list.append(value)
            metric_errors_list.append(error)
        else:
            failed_files.append(metrics_file_path)
    return metric_pp_list, metric_errors_list, failed_files
    
def generate_time_conv_plot_regimes(all_configurations, vary_key, vary_values, regime_children_known=None, metrics_folder=None, axs=None, plot_avg='False', metrics_list=('tpr', 'fpr'), show_xticks=False):
    """Plot per-regime averaged metrics against ``vary_key``.

    For each metric in ``metrics_list`` (drawn on its own Axes), each value in
    ``regime_children_known`` and each regime method below, the values from
    ``load_avg_metrics_regime_by`` are drawn as a marker line plus error bars.
    Each Axes is then decorated once: ticks, grid, and a 0.05 reference line
    for FPR metrics.

    Args:
        all_configurations: Configurations. Only the first one is used as the
            template for file lookup.
        vary_key: Configuration field varied along the x-axis.
        vary_values: Values of ``vary_key`` (the x positions).
        regime_children_known: Iterable of ``reg_children_known`` values. One
            set of curves is drawn per value; ``None`` draws no curves.
        metrics_folder: Folder holding the metrics files.
        axs: A single Axes (shared by all metrics) or an indexable collection
            with at least one Axes per metric.
        plot_avg: Unused; kept for interface compatibility.
        metrics_list: Metric names to plot.
        show_xticks: If true, place x ticks and labels at ``vary_values``.
    """
    config = all_configurations[0]
    if regime_children_known is None:
        regime_children_known = []
    methods = ['_persistent_regimes', '_sparse_regimes', '_ymask', '_ymask_naive', '_intersection']
    single_ax = isinstance(axs, plt.Axes)

    for j, metric in enumerate(metrics_list):
        ax = axs if single_ax else axs[j]

        for regime_children_value in regime_children_known:
            linestyle = linestyles.get(str(regime_children_value))
            for method in methods:
                metric_pp_list, metric_errors_list, _ = load_avg_metrics_regime_by(
                    metric, method, vary_key, vary_values, metrics_folder, config, regime_children_value)
                # Values missing from a file are skipped, so only a prefix of x is used.
                x_values = vary_values[:len(metric_pp_list)]
                ax.plot(x_values, metric_pp_list, nodestyles[method],
                        markersize=plotting_params['lines.markersize'],
                        label=f"{method_names[method]} adj. {children_labels[str(regime_children_value)]}",
                        linestyle=linestyle, color=colors.get(method))
                ax.errorbar(x_values, metric_pp_list, yerr=metric_errors_list,
                            color=colors.get(method), linestyle=linestyle)

        # Decorate once per Axes after all curves are drawn. Previously this
        # ran once per link assumption and duplicated the reference line.
        if show_xticks:
            ax.set_xticks(vary_values)
            ax.set_xticklabels(vary_values, rotation=40, ha='right', fontsize=plotting_params['xtick.labelsize'])

        if 'tpr' in metric:
            yticks = [0.0, 0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1.]
            ax.set_yticks(yticks)
            ax.set_yticklabels([f"{i:.2f}" for i in yticks],
                               rotation=0, ha='right', fontsize=plotting_params['ytick.labelsize'])
        elif 'fpr' in metric:
            yticks = [0., 0.01, 0.03, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]
            ax.set_yticks(yticks)
            ax.set_yticklabels([f"{i:.2f}" for i in yticks],
                               rotation=0, ha='right', fontsize=plotting_params['ytick.labelsize'])
            ax.set_ylim([0, 0.1])

        ax.grid(color='lightgray')

        if 'fpr' in metric and len(vary_values) > 0:
            # Horizontal reference line at 0.05 (presumably the nominal
            # significance level); min/max matches the union plot's extent.
            ax.hlines(y=0.05, xmin=min(vary_values), xmax=max(vary_values), colors='black', linestyles='-', lw=1)

        if j == 0:
            ax.set_ylabel('Avg. Contexts', fontsize=plotting_params['title.size'])
        ax.set_xlabel(vary_labels[vary_key], fontsize=plotting_params['title.size'])


def load_metrics_file_union(metric, method, by_key, by_values, metrics_folder, config, regime_children_value, suffix=None):
    """Load union metric values and their errors for each ``by_value``.

    For every value in ``by_values``, the pickle located by
    ``get_metrics_file_path(..., 'union', suffix=suffix)`` is opened and the
    entry ``metric + method`` is read. Element 0 of that entry is appended to
    the value list and element 1 to the error list.

    Returns:
        ``(metric values, metric errors, failed_files)``. ``failed_files``
        lists the paths of files that do not contain ``metric + method``.
        Such files contribute no point, so the returned lists may be shorter
        than ``by_values``.

    Raises:
        OSError: If a metrics file cannot be opened.
    """
    metric_pp_list, metric_errors_list = [], []
    failed_files = []
    metric_key = metric + method

    for by_value in by_values:
        # NOTE(review): the converted value is unused. The call is kept in case
        # check_type validates/raises on bad values; confirm.
        config.check_type(by_key, by_value)
        metrics_file_path = get_metrics_file_path(metrics_folder, config, by_key, by_value, regime_children_value, 'union', suffix=suffix)
        # pickle.load can execute arbitrary code; only use on trusted metrics files.
        with open(metrics_file_path, 'rb') as file:
            metrics_dict = pickle.load(file, encoding='latin1')
        if metric_key in metrics_dict:
            value, error = metrics_dict[metric_key][0], metrics_dict[metric_key][1]
            metric_pp_list.append(value)
            metric_errors_list.append(error)
        else:
            failed_files.append(metrics_file_path)
    return metric_pp_list, metric_errors_list, failed_files
                

def generate_time_conv_plot_union(all_configurations, vary_key, vary_values, regime_children_known=None, metrics_folder=None, metrics_folder_fci=None, axs=None, plot_avg='False', metrics_list=('tpr', 'fpr'), plot_fci='False', show_xticks=False):
    """Plot union metrics against ``vary_key`` for every method and link assumption.

    Each entry of ``metrics_list`` is prefixed with ``'union_'``. For every
    such metric, every value in ``regime_children_known`` and every method
    below, the values from ``load_metrics_file_union`` are drawn as a marker
    line plus error bars. Each Axes is then decorated once.

    Args:
        all_configurations: Configurations. Only the first one is used as the
            template for file lookup.
        vary_key: Configuration field varied along the x-axis.
        vary_values: Values of ``vary_key`` (the x positions).
        regime_children_known: Iterable of ``reg_children_known`` values. One
            set of curves is drawn per value; ``None`` draws no curves.
        metrics_folder: Folder holding the metrics files.
        metrics_folder_fci: Unused; kept for interface compatibility.
        axs: A single Axes (shared by all metrics) or an indexable collection
            with one Axes per metric.
        plot_avg: Unused; kept for interface compatibility.
        metrics_list: Metric names without the ``'union_'`` prefix.
        plot_fci: Unused; kept for interface compatibility.
        show_xticks: If true, place x ticks and labels at ``vary_values``.
    """
    metrics_list = ['union_' + metric for metric in metrics_list]
    config = all_configurations[0]
    if regime_children_known is None:
        regime_children_known = []

    # NOTE(review): nothing is saved into this folder here. The directory is
    # still created so any caller relying on that side effect keeps working.
    base_folder = metrics_folder if metrics_folder else config.save_folder
    os.makedirs(os.path.join(base_folder, 'plot_figures_fci'), exist_ok=True)

    methods = ['_persistent_regimes', '_sparse_regimes', '_ymask', '_ymask_naive', '_intersection', '_pcmci']
    single_ax = isinstance(axs, plt.Axes)

    for j, metric in enumerate(metrics_list):
        ax = axs if single_ax else axs[j]

        for regime_children_value in regime_children_known:
            linestyle = linestyles.get(str(regime_children_value))
            for method in methods:
                metric_pp_list, metric_errors_list, _ = load_metrics_file_union(
                    metric, method, vary_key, vary_values, metrics_folder, config, regime_children_value)
                # Values missing from a file are skipped, so only a prefix of x is used.
                x_values = vary_values[:len(metric_pp_list)]
                ax.plot(x_values, metric_pp_list, nodestyles[method],
                        markersize=plotting_params['lines.markersize'],
                        label=f"{method_names[method]} adj. {children_labels[str(regime_children_value)]}",
                        linestyle=linestyle, color=colors.get(method))
                ax.errorbar(x_values, metric_pp_list, yerr=metric_errors_list,
                            color=colors.get(method), linestyle=linestyle)

        # Decorate once per Axes after all curves are drawn. Previously this
        # ran once per link assumption and duplicated the reference line.
        if show_xticks:
            ax.set_xticks(vary_values)
            ax.set_xticklabels(vary_values, rotation=40, ha='right', fontsize=plotting_params['xtick.labelsize'])

        if 'tpr' in metric:
            yticks = [0.0, 0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1.]
            ax.set_yticks(yticks)
            ax.set_yticklabels([f"{i:.2f}" for i in yticks],
                               rotation=0, ha='right', fontsize=plotting_params['ytick.labelsize'])
        elif 'fpr' in metric:
            yticks = [0., 0.01, 0.03, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]
            ax.set_yticks(yticks)
            ax.set_yticklabels([f"{i:.2f}" for i in yticks],
                               rotation=0, ha='right', fontsize=plotting_params['ytick.labelsize'])
            ax.set_ylim([0, 0.1])

        if metric == 'union_fpr' and len(vary_values) > 0:
            # Horizontal reference line at 0.05 across the x range.
            ax.hlines(y=0.05, xmin=min(vary_values), xmax=max(vary_values), colors='black', linestyles='-', lw=1)
        if j == 0:
            ax.set_ylabel('Union', fontsize=plotting_params['title.size'])

        ax.set_title(metric_clean_names[metric], fontsize=plotting_params['title.size'])
        ax.grid(color='lightgray')




def generate_fci_plots(all_configurations, sample_sizes, regime_children_known, metrics_folder=None, metrics_folder_fci=None, metrics_list=('tpr', 'fpr'), plot_avg='False', plot_fci='False', vary_key=None, vary_values=None):
    """Draw the union-metric figure for a set of configurations and save it with a separate legend.

    The old body referenced the undefined names ``remove_only``, ``vary_key``
    and ``vary_values``, and passed an unsupported ``sample_sizes=`` keyword
    to the union plotter, so it could never run. The x-axis key is now an
    explicit keyword argument.

    Args:
        all_configurations: Configurations. The last one supplies the title
            and file name.
        sample_sizes: Default x-axis values, used when ``vary_values`` is None.
        regime_children_known: Iterable of ``reg_children_known`` values to plot.
        metrics_folder: Folder holding the metrics files. Figures are written
            to its ``plot_figures_fci/`` subfolder.
        metrics_folder_fci: Forwarded to the union plotter.
        metrics_list: Metric names to plot (one subplot each).
        plot_avg: Forwarded to the union plotter.
        plot_fci: Forwarded to the union plotter.
        vary_key: Configuration field varied along the x-axis (required).
        vary_values: x-axis values; defaults to ``sample_sizes``.

    Raises:
        ValueError: If ``vary_key`` is not given.
    """
    if vary_key is None:
        raise ValueError("generate_fci_plots requires vary_key (the configuration field varied along the x-axis)")
    if vary_values is None:
        vary_values = sample_sizes

    config = all_configurations[-1]

    figure_path = os.path.join(metrics_folder, 'plot_figures_fci/')
    os.makedirs(figure_path, exist_ok=True)

    file_name = generate_name_from_params(config)
    if len(metrics_list) == 1:
        fig, axs = plt.subplots(1, len(metrics_list), figsize=(plotting_params['common_fig_width'] // 2, plotting_params['common_fig_height'] // 2 + 2))
    else:
        fig, axs = plt.subplots(1, len(metrics_list), figsize=(plotting_params['common_fig_width'], plotting_params['common_fig_height'] // 2 + 1))
        fig.subplots_adjust(hspace=0.25, wspace=0.2, top=0.79)

    fig.suptitle(f"Nodes: {config.N}, Contexts: {config.nb_regimes}, Changed Links: {config.nb_changed_links}, Density: {config.density},\nCycles: {config.cycles_only}, Endo regime: {config.endo_regime}", fontsize=plotting_params['title.size'])

    # Plot Union Metrics
    generate_time_conv_plot_union(
        all_configurations, vary_key, vary_values, regime_children_known=regime_children_known,
        metrics_folder=metrics_folder, metrics_folder_fci=metrics_folder_fci, axs=axs, metrics_list=metrics_list, plot_avg=plot_avg, plot_fci=plot_fci
    )

    # plt.subplots returns a bare Axes (not an array) when there is one subplot.
    ax_for_handles = axs if isinstance(axs, plt.Axes) else axs[0]
    handles, labels = ax_for_handles.get_legend_handles_labels()
    known_tag = '_'.join(str(known) for known in regime_children_known)

    # Save the legend as a separate file; figures are addressed explicitly
    # rather than through pyplot's "current figure".
    fig_legend = plt.figure(figsize=(4, 3))
    fig_legend.legend(handles, labels, loc='lower center', ncol=5, bbox_to_anchor=(0.5, -0.05), fontsize=plotting_params['legend.fontsize'])
    fig_legend.tight_layout()
    fig_legend.savefig(os.path.join(figure_path, 'legend_fci' + known_tag + '.png'), bbox_inches='tight')
    plt.close(fig_legend)

    fig.savefig(os.path.join(figure_path, file_name + '_fci_' + '_'.join(metrics_list) + '_' + known_tag + '.png'), bbox_inches='tight')
    plt.close(fig)

def generate_combined_plots(all_configurations,
                            vary_key,
                            vary_values,
                            regime_children_known,
                            metrics_folder=None,
                            metrics_folder_fci=None,
                            metrics_list=('tpr', 'fpr'),
                            plot_avg='False',
                            plot_fci='False',
                            suffix=''):
    """Draw union metrics (first row) and per-regime metrics (remaining rows) and save the figure plus a separate legend.

    Args:
        all_configurations: Configurations. The first one supplies the title
            and file name.
        vary_key: Configuration field varied along the x-axis.
        vary_values: x-axis values.
        regime_children_known: Iterable of ``reg_children_known`` values to plot.
        metrics_folder: Folder holding the metrics files. Figures are written
            to its ``plot_figures_fci/`` subfolder.
        metrics_folder_fci: Forwarded to the union plotter.
        metrics_list: Metric names to plot (one column each).
        plot_avg: The string ``'False'`` selects a ``1 + nb_regimes`` row
            layout; any other value selects a 2-row layout.
        plot_fci: Forwarded to the union plotter.
        suffix: Appended to the legend and figure file names.
    """
    config = all_configurations[0]

    figure_path = os.path.join(metrics_folder, 'plot_figures_fci/')
    os.makedirs(figure_path, exist_ok=True)

    file_name = generate_name_from_params(config)

    if plot_avg == 'False':
        # NOTE(review): the per-regime plotter only draws on the first
        # len(metrics_list) Axes after row 0, so extra rows may stay empty.
        fig, axs = plt.subplots(1 + config.nb_regimes,
                                len(metrics_list),
                                figsize=(plotting_params['common_fig_width'],
                                         plotting_params['common_fig_height'] + 2 * config.nb_regimes))
        fig.subplots_adjust(hspace=0.25, wspace=0.2, top=0.92)
    else:
        if len(metrics_list) == 1:
            figsize = (plotting_params['common_fig_width'] // 2,
                       plotting_params['common_fig_height'] // 2 + 2)
        else:
            figsize = (plotting_params['common_fig_width'],
                       plotting_params['common_fig_height'] + 2)
        fig, axs = plt.subplots(2, len(metrics_list), figsize=figsize)
        fig.subplots_adjust(hspace=0.25, wspace=0.2, top=0.87)

    fig.suptitle(f"$D$={config.N - 1}, $n_{{\\text{{contexts}}}}$={config.nb_regimes}, "
    f"$n_{{\\text{{change}}}}$={config.nb_changed_links}, $s$={config.density}, $f_c$={config.contemp_fraction}\n$endo_c$={config.endo_regime}, $contemp_c$={config.contemp_context}",
    fontsize=plotting_params['title.size'])

    # Row 0: union metrics.
    generate_time_conv_plot_union(
        all_configurations, vary_key=vary_key, vary_values=vary_values, regime_children_known=regime_children_known,
        metrics_folder=metrics_folder, metrics_folder_fci=metrics_folder_fci, axs=axs[0], metrics_list=metrics_list, plot_avg=plot_avg, plot_fci=plot_fci,
        show_xticks=True
    )

    # Remaining rows: per-regime metrics.
    generate_time_conv_plot_regimes(
        all_configurations, vary_key=vary_key, vary_values=vary_values, regime_children_known=regime_children_known,
        metrics_folder=metrics_folder, axs=axs[1:].flatten(), metrics_list=metrics_list, plot_avg=plot_avg,
        show_xticks=True
    )

    ax_for_handles = axs[0, 0] if len(metrics_list) > 1 else axs[0]
    handles, labels = ax_for_handles.get_legend_handles_labels()
    known_tag = '_'.join(str(known) for known in regime_children_known)

    # Save the legend as a separate file; figures are addressed explicitly
    # rather than through pyplot's "current figure".
    fig_legend = plt.figure(figsize=(4, 3))
    fig_legend.legend(handles, labels, loc='lower center', ncol=1, bbox_to_anchor=(0.5, -0.05), fontsize=plotting_params['legend.fontsize'])
    fig_legend.tight_layout()
    fig_legend.savefig(os.path.join(figure_path, 'legend_h' + known_tag + suffix + '.png'), bbox_inches='tight')
    plt.close(fig_legend)

    filename = os.path.join(figure_path, file_name + '_combined_' + '_'.join(metrics_list) + '_' + known_tag + suffix + '.png')
    fig.savefig(filename, bbox_inches='tight')
    plt.close(fig)


    
if __name__ == '__main__':
    import argparse
    import itertools

    parser = argparse.ArgumentParser(description="Generate configurations from a YAML file.")
    parser.add_argument('--yaml_path', type=str, help='Path to the YAML configuration file.')
    parser.add_argument('--metrics_list', type=convert_to_string_list, help='List of metrics to plot.')
    parser.add_argument('--vary_key', type=str, required=False, help='Key for which the X axis should vary.')
    parser.add_argument('--vary_values', type=convert_to_float_list, required=False, help='Values for which the X axis should vary.')
    parser.add_argument('--metrics_folder', type=str, help='Folder where the metrics can be found.')
    parser.add_argument('--metrics_folder_fci', type=str, help='Folder where the metrics for FCI and CD-NOD can be found.')
    parser.add_argument('--plot_avg', type=str, required=False, help='Whether to plot averages.')
    parser.add_argument('--plot_known', type=convert_to_children_known, help='Which link assumptions to plot.')
    parser.add_argument('--plot_fci', type=str, help='Whether to plot fci.')

    args = parser.parse_args()
    config_path = args.yaml_path
    results_folder, all_configurations = generate_configurations(config_path)

    if args.plot_avg is None:
        args.plot_avg = 'True'
    if args.plot_fci is None:
        args.plot_fci = 'False'
    if args.metrics_list is None:
        args.metrics_list = ['tpr', 'fpr']
    if args.vary_key is None:
        # Every metrics-file lookup needs a configuration field to vary.
        parser.error('--vary_key is required.')

    def _distinct_values(key):
        """Return the distinct values of configuration field ``key`` across all configurations.

        ``'N_values'`` maps to the ``N`` field. Values are sorted when they are
        mutually comparable; otherwise first-seen order is kept.
        NOTE(review): assumes the field values are hashable scalars.
        """
        field = 'N' if key == 'N_values' else key
        values = list(dict.fromkeys(getattr(config, field) for config in all_configurations))
        try:
            return sorted(values)
        except TypeError:
            return values

    # The old code read these from an undefined ``config_parameters``; they
    # are now derived from the generated configurations.
    if args.vary_values is None:
        vary_values = _distinct_values(args.vary_key)
        print('vary values', vary_values)
    else:
        vary_values = args.vary_values

    for node, factor, nb_links, endo_regime in itertools.product(
            _distinct_values('N_values'),
            _distinct_values('imbalance_factor'),
            _distinct_values('nb_changed_links'),
            _distinct_values('endo_regime')):
        configs_for_nodes = [
            config for config in all_configurations
            if config.N == node
            and config.imbalance_factor == factor
            and config.nb_changed_links == nb_links
            and config.endo_regime == endo_regime
        ]
        if not configs_for_nodes:
            # Not every combination of values needs to exist among the configurations.
            continue

        suffix = f"_N{node}_imb{factor}_links{nb_links}_endo{endo_regime}"
        generate_combined_plots(
            configs_for_nodes,
            vary_key=args.vary_key,
            vary_values=vary_values,
            regime_children_known=args.plot_known,
            metrics_folder=args.metrics_folder,
            metrics_folder_fci=args.metrics_folder_fci,
            metrics_list=args.metrics_list,
            plot_avg=args.plot_avg,
            plot_fci=args.plot_fci,
            suffix=suffix
        )