
import os
import dill
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import argparse

# Resolution (dots per inch) handed to savefig when a figure is written.
dpi = 300

# Global matplotlib styling, applied once at import time: bold serif text,
# black 1.5pt axes frames, extra-large tick labels, and STIX fonts for the
# italic / bold-italic math text used in titles.
matplotlib.rcParams.update({
    'font.family': 'serif',
    'font.serif': 'Times New Roman',
    "axes.edgecolor": "black",
    'font.weight': 'bold',
    'axes.labelweight': 'bold',
    'axes.titleweight': 'bold',
    'axes.linewidth': 1.5,
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'mathtext.fontset': 'custom',
    'mathtext.it': 'STIXGeneral:italic',
    'mathtext.bf': 'STIXGeneral:italic:bold',
})



# Maps the method identifier parsed from a result file name to the label shown
# in the figure legend.  Methods trained for a specific fairness criterion come
# in "-sp" and "-eo" variants that share one display name.
label_mapping = {
    f"{method}-{criterion}": display_name
    for method, display_name in (
        ("fairproj-kl", "FairProj (KL)"),
        ("fairproj-ce", "FairProj (CE)"),
        ("reduction", "Reduction"),
        ("rejection", "Rejection"),
        ("postprocess", "LinearPost"),
    )
    for criterion in ("sp", "eo")
}
# Methods without a per-criterion variant.
label_mapping.update({
    "eqodds": "EqOdds",
    "caleqodds": "CalEqOdds",
    "Base": "Base",
    "postprocess-reg": "LinearPost",
    "wasserstein": "WassBC",
    "fairreg": "FairReg",
})

# Maps the fairness-metric key used in the result tracker to the figure title.
# Values are matplotlib mathtext: every "$" must be paired, otherwise
# matplotlib renders the whole string literally instead of as math.
metric_mapping = {
    'SP': 'SP',
    'EO': 'EO',
    'SP (pqi, weak)': r'$S$-SP',
    'EO (pqi)': r'$S$-EO',
    'EO (pqi, pos)': r'$S$-EO$^{+}$',
    'SP (gini, weak)': r'$S$-SP (Gini)',
    'EO (gini)': r'$S$-EO (Gini)',
    # The superscript needs its own opening "$"; without it the string had an
    # odd number of dollar signs and was not parsed as math at all.
    'EO (gini, pos)': r'$S$-EO (Gini)$^{+}$',
    'SP (pqi, max)': r'$S$-SP ($\mathit{KS}$)',
    'SP (weak)': r'SP ($\mathit{weak}$)',
    'SP (ks)': r'SP ($\mathit{KS}$)',
    'SP (w1)': r'SP ($\mathit{W}$)',
    'SP (pqi, int)': r'$S$-SP ($\mathit{W}$)',
}

# Font sizes and legend placement shared by every figure.
plt_dict = {
    "label.size": 35,           # y-axis label
    "title.size": 40,           # figure title
    "legend.size": 28,          # legend entries
    "xtick.size": 30,           # tick labels (applied to both axes)
    "legend.loc": "upper right",
}

def make_plot(dataset, metric, mode = 'c'):
    """Collect baseline results for one dataset/metric into a DataFrame.

    Every file in ``output/result`` whose name contains both *dataset* and
    ``'baseline'`` is loaded.  File names are split on ``'_'``: token 0 is the
    integer seed, token 3 the method name and the last token the parameter
    setting (taken verbatim, so it keeps any file extension).

    Args:
        dataset: Dataset name; matched as a substring of the file name.
        metric: Fairness metric key, read from the tracker as ``test/<metric>``.
        mode: ``'c'`` for classification (model metric ``Accuracy``) or
            ``'r'`` for regression (model metric ``RMSE``).

    Returns:
        A DataFrame with columns ``methods``, the model metric name,
        ``metric``, ``seed`` and ``param``; one row per result file.  The
        method ``'none'`` is renamed to ``'Base'`` and methods trained for the
        other fairness criterion (SP vs. EO) are dropped.

    Raises:
        ValueError: If *mode* is neither ``'c'`` nor ``'r'``.
    """
    model_met_by_mode = {'c': 'Accuracy', 'r': 'RMSE'}
    if mode not in model_met_by_mode:
        raise ValueError(f"mode must be 'c' or 'r', got {mode!r}")
    model_met = model_met_by_mode[mode]
    columns = ['methods', model_met, 'metric', 'seed', 'param']

    result_path = os.path.join('output', 'result')
    # Sorted so the row order does not depend on the file system.
    experiments = sorted(
        e for e in os.listdir(result_path) if (dataset in e) and ('baseline' in e)
    )

    frames = []
    for exp in experiments:
        # NOTE(security): dill.load can execute arbitrary code embedded in the
        # file; only load result files produced by trusted runs.
        with open(os.path.join(result_path, exp), 'rb') as f:
            results = dill.load(f)
        results_mets = results['logger']['tracker']
        name_parts = exp.split('_')
        res_row = {
            'methods': name_parts[3],
            model_met: results_mets[f'test/{model_met}'],
            'metric': results_mets[f'test/{metric}'],
            'seed': int(name_parts[0]),
            'param': name_parts[-1],
        }
        # Accuracies stored as fractions are rescaled to percentages.
        if mode == 'c' and res_row[model_met] < 1:
            res_row[model_met] *= 100
        frames.append(pd.DataFrame(res_row, index=[0]))

    # Concatenate once: avoids quadratic copying and never concatenates onto an
    # empty frame (which would degrade the column dtypes).
    if frames:
        plotdf = pd.concat(frames, axis=0, ignore_index=True)
    else:
        plotdf = pd.DataFrame(columns=columns)

    # The unconstrained model is stored under the method name "none".
    plotdf['methods'] = plotdf['methods'].replace('none', 'Base')

    # Drop the method variants that were trained for the other criterion.
    if 'SP' in metric:
        excluded_methods = ['fairproj-kl-eo', 'fairproj-ce-eo', 'reduction-eo', 'rejection-eo', 'postprocess-eo']
    elif 'EO' in metric:
        excluded_methods = ['fairproj-kl-sp', 'fairproj-ce-sp', 'reduction-sp', 'rejection-sp', 'postprocess-sp']
    else:
        # Metric belongs to neither family: keep every method.
        excluded_methods = []

    return plotdf[~plotdf['methods'].isin(excluded_methods)]


def _standard_error(values):
    """Return the standard error of the mean of *values* (0 if they do not vary)."""
    if np.unique(values).shape[0] <= 1:
        return 0.0
    # np.std defaults to the population standard deviation (ddof=0).
    return np.std(values) / np.sqrt(len(values))


def make_figure(plotdf, metric, dataset, mode = 'c'):
    """Plot model performance against a fairness metric and save it as a PDF.

    Rows of *plotdf* are averaged over seeds per ``(param, methods)`` pair.
    Points whose mean fairness metric exceeds the ``Base`` mean plus three
    standard errors are dropped (skipped when there is no ``Base`` row).  One
    line per method is drawn with standard-error bars on both axes, and the
    figure is written to ``output/figs/<dataset>_<metric>.pdf``.

    Args:
        plotdf: DataFrame with columns ``param``, ``methods``, ``metric`` and
            the model metric column (``Accuracy`` or ``RMSE``).
        metric: Fairness metric key; used for the title and the file name.
        dataset: Dataset name; used for the file name.
        mode: ``'c'`` selects ``Accuracy`` as the model metric, anything else
            selects ``RMSE``.
    """
    model_met = 'Accuracy' if mode == 'c' else 'RMSE'

    grouped_plotdf = (
        plotdf.groupby(['param', 'methods'])
        .agg(
            mean_model_met=(model_met, 'mean'),
            se_model_met=(model_met, _standard_error),
            mean_metric=('metric', 'mean'),
            se_metric=('metric', _standard_error),
        )
        .reset_index()
    )

    # Filter out data points with a fairness metric greater than baseline + 3 se.
    base_rows = grouped_plotdf[grouped_plotdf['methods'] == 'Base']
    if not base_rows.empty:
        base_threshold = (base_rows['mean_metric'] + 3 * base_rows['se_metric']).iloc[0]
        grouped_plotdf = grouped_plotdf[grouped_plotdf['mean_metric'] <= base_threshold]
    # Work on an independent copy so the assignments below do not write into a
    # slice of the aggregated frame.
    grouped_plotdf = grouped_plotdf.copy()

    # Reorder methods so that "Base" (when present) is drawn last.
    present_methods = list(grouped_plotdf['methods'].unique())
    non_base_methods = [m for m in present_methods if m != 'Base']
    method_order = non_base_methods + (['Base'] if 'Base' in present_methods else [])
    grouped_plotdf['methods'] = pd.Categorical(grouped_plotdf['methods'], categories=method_order, ordered=True)
    grouped_plotdf = grouped_plotdf.sort_values('methods')

    # Markers: cycle through the list for regular methods, a star for Base.
    available_markers = ['o', 's', 'D', '^', 'v', '<', '>', 'p', 'h', '+', 'x']
    custom_markers = {
        method: available_markers[i % len(available_markers)]
        for i, method in enumerate(non_base_methods)
    }
    custom_markers['Base'] = '*'

    custom_size = {method: 15 for method in non_base_methods}
    custom_size['Base'] = 30

    # Colors: evenly spaced samples of tab10 for regular methods, dark green for Base.
    colors = plt.cm.tab10(np.linspace(0, 1, len(non_base_methods)))
    color_dict = dict(zip(non_base_methods, colors))
    color_dict['Base'] = "#006400"

    plt.figure(figsize=(10, 8))
    ax = plt.gca()
    # observed=True: iterate only over methods that actually have rows.
    for method, method_data in grouped_plotdf.groupby('methods', observed=True):
        sns.lineplot(
            data=method_data,
            x='mean_metric',
            y='mean_model_met',
            marker=custom_markers[method],
            markeredgecolor='none',
            markersize=custom_size[method],
            label=method,
            err_style='bars',
            errorbar=None,
            linewidth=1.5,
            color=color_dict[method],
            ax=ax,
        )

        # Add error bars manually, in the same color as the line.
        ax.errorbar(
            x=method_data['mean_metric'],
            y=method_data['mean_model_met'],
            xerr=method_data['se_metric'],
            yerr=method_data['se_model_met'],
            fmt='none',  # No marker here
            ecolor=color_dict[method],
            elinewidth=1.,
            capsize=0,
            alpha=0.7
        )

    # Translate method identifiers to display names; unmapped ones are kept as is.
    handles, labels = ax.get_legend_handles_labels()
    new_labels = [label_mapping.get(label, label) for label in labels]

    # Add plot labels and title
    plt.title(f'{metric_mapping.get(metric, metric)}', fontsize=plt_dict['title.size'])
    plt.locator_params(axis='both', nbins=6)
    plt.xlabel('', fontsize=16)
    plt.ylabel(f'{model_met}', fontsize=plt_dict['label.size'])
    plt.xticks(fontsize=plt_dict['xtick.size'])
    plt.yticks(fontsize=plt_dict['xtick.size'])
    plt.legend(handles, new_labels, title='', fontsize=plt_dict['legend.size'], loc=plt_dict['legend.loc'])

    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()

    # Make sure the output directory exists before saving.
    fig_dir = os.path.join('output', 'figs')
    os.makedirs(fig_dir, exist_ok=True)
    print(f'Saving {dataset}_{metric}...')
    plt.savefig(os.path.join(fig_dir, f'{dataset}_{metric}.pdf'),
                dpi=dpi, format='pdf', bbox_inches='tight')
    plt.close()


def main():
    """Generate every figure for the task selected with ``--task``.

    ``clf`` (the default) plots the classification datasets and ``reg`` the
    regression dataset; any other value is rejected by the argument parser.
    For each dataset/metric pair the results are collected with ``make_plot``
    and drawn with ``make_figure``.
    """
    # task -> (mode flag for make_plot/make_figure, datasets, fairness metrics)
    experiment_plan = {
        'clf': (
            'c',
            ['Adult', 'SimulateC'],
            ['SP', 'EO', 'SP (pqi, weak)', 'EO (pqi)', 'EO (pqi, pos)',
             'SP (gini, weak)', 'EO (gini)', 'EO (gini, pos)'],
        ),
        'reg': (
            'r',
            ['SimulateR'],
            ['SP (pqi, weak)', 'SP (ks)', 'SP (w1)', 'SP (weak)', 'SP (pqi, int)',
             'SP (pqi, max)', 'EO', 'EO (pqi)', 'EO (gini)', 'EO (pqi, pos)',
             'EO (gini, pos)'],
        ),
    }

    parser = argparse.ArgumentParser(description='make figures for classification or regression')
    # choices makes an unknown task an explicit error instead of a silent no-op.
    parser.add_argument('--task', default='clf', type=str, choices=sorted(experiment_plan))
    args = parser.parse_args()

    mode, datasets, fair_metrics = experiment_plan[args.task]
    for dataset in datasets:
        for metric in fair_metrics:
            plotdf = make_plot(dataset, metric, mode = mode)
            make_figure(plotdf, metric, dataset, mode = mode)
    
    
# Generate the figures only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()