from pathlib import Path
import pdb
import pandas as pd
import pyperclip
import seaborn as sns
import numpy as np
from glob import glob
import optuna
import os
import pareto

import logging, sys
import matplotlib.pyplot as plt

# Master switch for this module's ad-hoc diagnostics: when True, `log` swallows
# every call; when False, `log` is the builtin `print`.
DISABLE_LOG = True

# Silence the standard `logging` machinery for every level.
logging.disable(sys.maxsize)


def _discard(*args, **kwargs):
    """Accept any arguments and do nothing (stand-in for `print`)."""
    return None


log = _discard if DISABLE_LOG else print

def plot_pareto(df, x_col, y_col, z_col=None, q=None, budget_col_name='budget', seed_col_name='seed'):
    """Filter each (budget, seed) group through ``pareto.eps_sort`` and bin the result.

    Args:
        df: Frame containing ``x_col``, ``y_col``, optionally ``z_col``, and the
            two grouping columns.
        x_col, y_col, z_col: Objective column names; their positions in
            ``df.columns`` are handed to ``pareto.eps_sort``.
        q: Number of quantile bins for ``pd.qcut``; ``None`` disables binning.
        budget_col_name, seed_col_name: Columns used to group ``df``.

    Returns:
        The filtered frame with an added ``f"{x_col}_mid"`` column. When ``q`` is
        given, ``f"{x_col}_q"`` is added as well; when both ``q`` and ``z_col``
        are given the frame is first reduced to the objective columns (with the
        group index reset into columns) and ``f"{z_col}_q"`` / ``f"{z_col}_mid"``
        are added too.
    """
    objective_names = [x_col, y_col] if z_col is None else [x_col, y_col, z_col]
    # `.item()` raises unless the name matches exactly one column.
    objective_positions = [np.where(df.columns == name)[0].item() for name in objective_names]
    column_names = list(df.columns.values)

    def _sorted_group(group):
        rows = list(group.itertuples(False))
        return pd.DataFrame(pareto.eps_sort([rows], objective_positions), columns=column_names)

    df = df.groupby([budget_col_name, seed_col_name]).apply(_sorted_group)

    x_mid = f'{x_col}_mid'
    if q is None:
        # No smoothing requested: the "mid" column is just the raw x values.
        df[x_mid] = df[x_col]
        return df

    if z_col is not None:
        z_bins = f'{z_col}_q'
        df = df[[x_col, y_col, z_col]].reset_index(drop=False)
        df[z_bins] = pd.qcut(df[z_col], q=q, duplicates='drop')
        df[f'{z_col}_mid'] = df[z_bins].apply(lambda x: x.mid)

    x_bins = f'{x_col}_q'
    df[x_bins] = pd.qcut(df[x_col], q=q, duplicates='drop')
    df[x_mid] = df[x_bins].apply(lambda x: x.mid)
    return df

def read_baselines():
    """Load ``Figures/baseline_results.parquet`` with metric and model labels renamed.

    The long fairness-metric names are replaced by the short identifiers used
    elsewhere in this file, and the "(2021a)" citation is replaced by "(2021)".
    """
    label_fixes = {
        "Fairness Metric": {
            "Equalized Odds": "EqualityOfOdds",
            "Demographic Parity": "DemParity",
        },
        "Model": {"Tran et al. (2021a)": "Tran et al. (2021)"},
    }
    baselines = pd.read_parquet('Figures/baseline_results.parquet')
    return baselines.replace(label_fixes)

def plot_baseline_over(ax, dataset, eps, model, fairness_metric, plot_kwargs=None):
    """Overlay one baseline curve (with its disparity band) on ``ax``.

    The baseline table is loaded with ``read_baselines`` and filtered on
    ``Epsilon``, ``Dataset``, ``Model`` and ``Fairness Metric``; the first
    matching row is plotted. If nothing matches, the miss is logged and the
    function returns without drawing.

    Args:
        ax: Matplotlib axes to draw on.
        dataset: Dataset name to match against the ``Dataset`` column.
        eps: Privacy budget to match against the ``Epsilon`` column.
        model: Baseline name to match against ``Model``; also the legend label.
        fairness_metric: Value to match against ``Fairness Metric``.
        plot_kwargs: Extra keyword arguments forwarded to both ``fill_between``
            and ``plot``. Defaults to no extra arguments.
    """
    # A `{}` default would be one dict object shared by every call.
    if plot_kwargs is None:
        plot_kwargs = {}

    baseline_results = read_baselines()

    query = f'Epsilon == {eps} and Dataset == "{dataset}" and Model == "{model}" and `Fairness Metric` == "{fairness_metric}"'
    matches = baseline_results.query(query)
    if matches.empty:
        log(f"No baseline results found for {query}")
        return
    baseline = matches.iloc[0]

    x = baseline['Misclassification Error']
    y = baseline['Disparity']
    y_upper = baseline['Disparity Upper']
    y_lower = baseline['Disparity Lower']

    # Only the band carries the label, so each baseline gets one legend entry.
    ax.fill_between(x, y_lower, y_upper, alpha = 0.5, label = model, **plot_kwargs)
    ax.plot(x, y, **plot_kwargs)


def choose_db(dataset):
    """Return the results parquet filename for ``dataset``.

    Raises:
        ValueError: If ``dataset`` is not one of the known dataset names.
    """
    # Note that 'credit-card' maps to a file spelled with an underscore.
    known_files = (
        ('adult', 'adult_results.parquet'),
        ('retired-adult', 'retired-adult_results.parquet'),
        ('parkinsons', 'parkinsons_results.parquet'),
        ('credit-card', 'credit_card_results.parquet'),
    )
    for name, filename in known_files:
        if dataset == name:
            return filename
    raise ValueError(f"Unknown dataset: {dataset}")
    


def pretty_print(fairness_metric, ignore_metric=False):
    """Return the axis label for a fairness metric.

    Args:
        fairness_metric: One of 'EqualityOfOdds', 'DemParity', 'ErrorParity'.
        ignore_metric: When truthy, return the generic label "Disparity"
            without looking at ``fairness_metric``.

    Raises:
        ValueError: If the metric is unknown and ``ignore_metric`` is falsy.
    """
    if ignore_metric:
        return "Disparity"

    labels = (
        ('EqualityOfOdds', "Disparity (EqOdds)"),
        ('DemParity', "Disparity (DemParity)"),
        ('ErrorParity', "Disparity (ErrorParity)"),
    )
    label = next((text for key, text in labels if fairness_metric == key), None)
    if label is None:
        raise ValueError(f"Unknown fairness metric: {fairness_metric}")
    return label

def plot_dataset(dataset, return_results=False, q=10, take_top=None):
    """Load a dataset's results, select rows, bin the error, and plot or return them.

    Args:
        dataset: Dataset name; selects the parquet file via ``choose_db`` and
            filters its ``dataset`` column.
        return_results: When truthy, return the prepared frame instead of plotting.
        q: Number of quantile bins for the error column.
        take_top: ``None`` keeps, per (dataset, seed, budget, fairness_threshold)
            group, the row with the highest student accuracy; otherwise keeps the
            first ``take_top`` rows after sorting by error, then disparity.

    Returns:
        The prepared frame, or the seaborn grid of per-budget line plots.
    """
    dataset_filter = f'dataset=="{dataset}"'
    results = pd.read_parquet(choose_db(dataset)).query(dataset_filter)
    results["student_model_error"] = results["student_model_accuracy"].apply(lambda x: 1-x)

    if take_top is not None:
        ranked = results.sort_values(by=['student_model_error', 'dem_disparity'], ascending=True)
        results = ranked.head(take_top)
    else:
        group_keys = ['dataset', 'seed', 'budget', 'fairness_threshold']
        best_rows = results.groupby(group_keys)['student_model_accuracy'].idxmax()
        results = results.loc[best_rows]

    # Smooth the curve: replace each error by the midpoint of its quantile bin.
    error_bins = pd.qcut(results['student_model_error'], q=q, duplicates='drop')
    results['error_q'] = error_bins
    results['error_mid'] = results['error_q'].apply(lambda x: x.mid)

    if return_results:
        return results

    g = sns.relplot(data=results.query(dataset_filter), x="error_mid", y="dem_disparity", col="budget", kind='line', col_wrap=4)
    g.tight_layout()
    return g
    
def plot_epsilon_grid(fairness_metric, dataset, epsilons, axes, q=10, take_top=None, plot_only_baselines=False):
    """Fill one axes per privacy budget with the FairPATE curve and the baselines.

    Args:
        fairness_metric: Metric name used to select baseline rows.
        dataset: Dataset name used for both FairPATE results and baselines.
        epsilons: Privacy budgets, paired positionally with ``axes``.
        axes: Iterable of matplotlib axes.
        q: Quantile-bin count forwarded to ``plot_dataset``.
        take_top: Forwarded to ``plot_dataset``.
        plot_only_baselines: When truthy, skip the FairPATE curve.
    """
    color_dict = {
        # 'DP-FERMI' : 'blue',
        'Tran et al. (2021)': 'orange',
        'Jagielski et al. (2019)' : 'red'
    }

    panels = list(zip(axes, epsilons))
    if not panels:
        return

    # Both loads are independent of the budget, so read them once instead of
    # re-reading the parquet files for every axes.
    dataset_results = None
    if not plot_only_baselines:
        dataset_results = plot_dataset(dataset, return_results=True, q=q, take_top=take_top)
    baseline_models = read_baselines().Model.value_counts().index.tolist()

    for ax, eps in panels:
        if dataset_results is not None:
            # The FairPATE results are filtered on the raw budget value.
            sns.lineplot(data=dataset_results.query(f'dataset == "{dataset}" and budget == {eps}'), x="error_mid", y="dem_disparity", ax=ax, label="FairPATE", color='magenta')

        # A budget of 10000 is looked up and titled as infinity. This is done
        # once per axes so the title is right even without baseline models.
        if eps == 10_000:
            eps = np.inf

        for m in baseline_models:
            plot_baseline_over(ax, dataset, eps, m, fairness_metric, plot_kwargs=dict(color=color_dict.get(m, None)))

        ax.set_title("$\\varepsilon=" + str(eps) + "$")


def get_optuna_results(optuna_path, dataset, fairness_metric, budget_to_ax_id, has_fairness_threshold=False, model=None):
    """Collect the trials of every matching Optuna SQLite study into one frame.

    Study files are found by globbing ``optuna_path``; seed, budget and
    (optionally) fairness threshold are parsed from the underscore-separated
    file name. Studies whose budget is not a key of ``budget_to_ax_id`` and
    studies without trials are skipped.

    Args:
        optuna_path: Directory containing the ``*.db`` study files.
        dataset: Dataset name embedded in the file names.
        fairness_metric: Fairness-metric name embedded in the file names.
        budget_to_ax_id: Mapping whose keys are the budgets to keep.
        has_fairness_threshold: Whether file names carry a
            ``fairnessThreshold_<value>`` segment.
        model: Optional model prefix of the file names. When it contains
            "IPP", coverage/rejection columns are added.

    Returns:
        ``(plot_df, budgets)``: the concatenated trials with ``test_accuracy``,
        ``test_error`` and ``test_disparity`` columns, and the set of budgets
        actually loaded.

    Raises:
        ValueError: If no study with trials was found.
    """
    if has_fairness_threshold:
        # NOTE(review): `model` is interpolated unconditionally here, so
        # model=None searches for files literally prefixed "None_" - confirm
        # whether threshold studies without a model prefix exist.
        pattern = f"{model}_{dataset}_{fairness_metric}_fairnessThreshold_*_budget_*_seed_*.db"
    elif model is None:
        pattern = f"{dataset}_{fairness_metric}_budget_*_seed_*.db"
    else:
        pattern = f"{model}_{dataset}_{fairness_metric}_budget_*_seed_*.db"

    dfs = []
    budgets = set()

    for db_path in glob(os.path.join(optuna_path, pattern)):
        # `Path.name` strips the directory with the platform's own separator.
        study_name = Path(db_path).name.replace('.db', '')
        storage_name = f"sqlite:////{Path(db_path).resolve().absolute()}"
        log(storage_name)

        # File name layout: ..._[fairnessThreshold_<t>_]budget_<b>_seed_<s>
        name_parts = study_name.split('_')
        seed = int(name_parts[-1])
        budget = float(name_parts[-3])
        extra_columns = {}
        if has_fairness_threshold:
            extra_columns['fairness_threshold'] = float(name_parts[-5])

        if budget not in budget_to_ax_id:
            log(f"Skipping {budget}")
            continue

        study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
        df = study.trials_dataframe(attrs=("number", "value", "params", "state", "user_attrs"))
        if len(df) == 0:
            log(f"{study_name} is empty")
            continue

        dfs.append(df.assign(seed=seed, budget=budget, **extra_columns))
        budgets.add(budget)

    if not dfs:
        # pd.concat would raise a ValueError that names neither path nor pattern.
        raise ValueError(
            f"No Optuna study with trials matched {pattern!r} in {optuna_path!r} "
            f"for budgets {list(budget_to_ax_id)}"
        )

    # Cells equal to 1000 / -1000 become NaN and any row with a NaN is dropped
    # (presumably these are sentinel values for failed trials - confirm).
    plot_df = pd.concat(dfs, axis=0).reset_index(drop=True).replace(1000, np.nan).replace(-1000, np.nan).dropna()
    plot_df['test_accuracy'] = plot_df['user_attrs_student_model_test_accuracy']
    plot_df['test_error'] = 1 - plot_df['user_attrs_student_model_test_accuracy']
    plot_df['test_disparity'] = plot_df['user_attrs_test_dem_parity']

    if model is not None and "IPP" in model:
        plot_df['test_coverage'] = plot_df['user_attrs_test_coverage']
        plot_df['test_rejection'] = 1 - plot_df['test_coverage']

    return plot_df, budgets


# def plot_results(dataset, fairness_metric, plot_only_baselines = False, optuna_path="."):
#     if fairness_metric == 'EqualityOfOdds':
#         fig, axes = plt.subplots(1, 3, figsize=(10, 5), sharey=True)
#         plot_epsilon_grid(fairness_metric, dataset, [0.5, 1, 3], axes.ravel(), q=20, take_top=None, plot_only_baselines=plot_only_baselines)
#         budget_to_ax_id = {0.5:0, 1.0: 1, 3.0: 2}
#     elif fairness_metric == 'DemParity':
#         fig, axes = plt.subplots(2, 2, figsize=(8, 8), sharey=True, sharex=True)
#         plot_epsilon_grid(fairness_metric, dataset, [1, 3, 9, 10_000], axes.ravel(), q=20, take_top=None, plot_only_baselines=plot_only_baselines)
#         budget_to_ax_id = {1.0: 0, 3.0: 1, 9.0: 2, 10000.0: 3}

#     fig.tight_layout()

#     plot_df, budgets = get_optuna_results(optuna_path, dataset, fairness_metric, budget_to_ax_id)

#     for b in budgets:
#         _df = plot_df.query(f'budget == {b}').copy(deep=True)
#         log(b, len(_df))
#         _df = plot_pareto(_df, 'test_error', 'test_disparity')
#         g = sns.lineplot(data=_df, x="error_mid", y="test_disparity", ax=axes.ravel()[budget_to_ax_id[b]], label="FairPATE", color='magenta', alpha=0.5)
#         g.set_xlabel("Misclassification Error")
#         g.set_ylabel(pretty_print(fairness_metric))

#     return fig, plot_df

def plot_results(dataset, fmetric, plot_only_baselines = False, results_df=None, optuna_path="./paper_results", xlims=None, ylims=None, rotate_y=None, model=None, has_fairness_threshold=False, skip_others = False, additional_specs=None, label=None, ylabel=None, budget_to_ax_id=None, q=10):
    """Plot error-vs-disparity curves per privacy budget and save the figure.

    A 1x3 grid is used for 'EqualityOfOdds' and a 2x2 grid otherwise. Each axes
    is first populated by ``plot_epsilon_grid`` (or only titled, when
    ``skip_others`` is set or the metric is neither 'EqualityOfOdds' nor
    'DemParity'), then the curves built by ``plot_pareto`` are drawn on top.
    The figure is written to ``paper_figures/{dataset}_{fmetric}.pdf``.

    Args:
        dataset: Dataset name.
        fmetric: Fairness-metric name; selects grid layout and default budgets.
        plot_only_baselines: Forwarded to ``plot_epsilon_grid``.
        results_df: Pre-computed results; takes precedence over ``optuna_path``.
        optuna_path: Directory of Optuna study files read by ``get_optuna_results``.
        xlims, ylims: Optional ``(low, high)`` axis limits.
        rotate_y: Rotation of y tick labels; ``None`` means 90.
        model: Optional model name; names containing "IPP" get a
            coverage-coloured plot with a colorbar.
        has_fairness_threshold: Forwarded to ``get_optuna_results``.
        skip_others: When truthy, only title the axes instead of drawing baselines.
        additional_specs: Optional list of dicts describing further curves.
            Required keys: 'model', 'has_fairness_threshold', 'optuna_path'.
            Optional keys: 'fairness_metric', 'label', 'kwargs'.
        label: Legend label of the main curve.
        ylabel: Y-axis label override.
        budget_to_ax_id: Mapping from budget to flattened axes index.
        q: Quantile-bin count forwarded to the grid and to ``plot_pareto``.

    Returns:
        ``fig`` when ``additional_specs`` is given, otherwise ``(fig, plot_df)``.
    """
    def title_only_grid(_fairness_metric, _dataset, epsilons, grid_axes, **_ignored):
        # Stand-in for plot_epsilon_grid that only writes the epsilon titles.
        for ax, eps in zip(grid_axes, epsilons):
            if eps == 10_000:
                eps = np.inf
            ax.set_title("$\\varepsilon=" + str(eps) + "$")

    # The title-only helper is defined unconditionally: it was previously
    # created only when skip_others was set, yet the fallback-metric branch
    # below always calls it.
    plot_grid = title_only_grid if skip_others else plot_epsilon_grid

    if fmetric == 'EqualityOfOdds':
        fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
        plot_grid(fmetric, dataset, [0.5, 1, 3], axes.ravel(), q=q, take_top=None, plot_only_baselines=plot_only_baselines)
    elif fmetric == 'DemParity':
        fig, axes = plt.subplots(2, 2, figsize=(8, 8), sharey=True, sharex=True, layout='constrained')
        plot_grid(fmetric, dataset, [1, 3, 9, 10_000], axes.ravel(), q=q, take_top=None, plot_only_baselines=plot_only_baselines)
    else:
        fig, axes = plt.subplots(2, 2, figsize=(8, 8), sharey=True, sharex=True, layout='constrained')
        title_only_grid(fmetric, dataset, [1, 3, 9, 10_000], axes.ravel(), q=q, take_top=None, plot_only_baselines=plot_only_baselines)

    if budget_to_ax_id is None:
        if fmetric == 'EqualityOfOdds':
            budget_to_ax_id = {0.5:0, 1.0: 1, 3.0: 2}
        else:
            budget_to_ax_id = {1.0: 0, 3.0: 1, 9.0: 2, 10000.0: 3}

    def plot_spec(model, has_fairness_threshold, optuna_path=None, results_df=None, fairness_metric=None, label=None, color=None, ylabel=None):
        # Draw the curves of one model onto the shared axes; returns (plot_df, sm).
        if color is None:
            color = 'magenta'

        if fairness_metric is None:
            fairness_metric = fmetric

        is_ipp = model is not None and "IPP" in model

        if results_df is not None:
            plot_df = results_df
            budgets = results_df["budget"].value_counts().sort_index().index.tolist()
        elif optuna_path is not None:
            plot_df, budgets = get_optuna_results(optuna_path, dataset, fairness_metric, budget_to_ax_id, model=model, has_fairness_threshold=has_fairness_threshold)
        else:
            raise ValueError("Either optuna_path or results_df must be provided")

        dfs_pareto = {b: plot_pareto(
            df=plot_df.query(f'budget == {b}').copy(deep=True),
            x_col='test_error',
            y_col='test_disparity',
            z_col='test_rejection' if is_ipp else None,
            q=q)  for b in budgets}

        if is_ipp:
            # plot_pareto names the binned third objective 'test_rejection_mid'
            # (and only creates it when q is set); no coverage column exists.
            # NOTE(review): the hue is derived here as 1 - rejection, matching
            # how test_rejection is defined from test_coverage - confirm.
            for frame in dfs_pareto.values():
                if 'test_rejection_mid' in frame.columns:
                    rejection = frame['test_rejection_mid']
                else:
                    rejection = frame['test_rejection']
                frame['test_coverage_mid'] = 1 - rejection.astype(float)

            palette = sns.cubehelix_palette(as_cmap=True)
            norm = plt.Normalize(0.6, 1.0)
            sm = plt.cm.ScalarMappable(cmap=palette, norm=norm)
            sm.set_array([])
        else:
            sm = None

        for b in budgets:
            target_ax = axes.ravel()[budget_to_ax_id[b]]
            if is_ipp:
                g = sns.lineplot(data=dfs_pareto[b], x="test_error_mid", y="test_disparity", hue="test_coverage_mid", hue_norm=norm, palette=palette, ax=target_ax, legend=False)
            else:
                g = sns.lineplot(data=dfs_pareto[b], x="test_error_mid", y="test_disparity", ax=target_ax, label=label, color=color, alpha=0.5)
            g.set_xlabel("Classification Error")
            g.set_ylabel(ylabel if ylabel is not None else pretty_print(fairness_metric))
            # Rotate the y-axis tick labels (90 degrees unless overridden).
            g.tick_params(axis='y', labelrotation=90 if rotate_y is None else rotate_y)
            g.tick_params(axis='x', labelrotation=0)

            if xlims is not None:
                g.set_xlim(*xlims)

            if ylims is not None:
                g.set_ylim(*ylims)

        # Keep a single legend: second axes of the 1x3 grid, fourth otherwise.
        legend_ax_id = 1 if fairness_metric == 'EqualityOfOdds' else 3
        for ax_id, ax in enumerate(axes.ravel()):
            if ax_id != legend_ax_id:
                ax.legend().remove()
            else:
                ax.legend(loc='best')
        return plot_df, sm

    plot_df, sm = plot_spec(model, has_fairness_threshold, optuna_path=optuna_path, results_df=results_df, label=label, ylabel=ylabel)

    if model is not None and "IPP" in model:
        # Place the colorbar below the bottom row; the 1x3 grid is a 1-D array
        # of axes, so it is passed as a whole.
        colorbar_axes = axes[1, :] if axes.ndim == 2 else axes
        fig.colorbar(sm, label="Coverage", ax=colorbar_axes, location='bottom', pad=0.0)

    # NOTE(review): the file is written before additional_specs are drawn, so
    # the saved PDF does not contain them - confirm this is intended.
    fig.savefig(f"paper_figures/{dataset}_{fmetric}.pdf", bbox_inches='tight')

    if additional_specs is not None:
        for spec in additional_specs:
            plot_spec(spec['model'], spec['has_fairness_threshold'], spec['optuna_path'], fairness_metric=spec.get('fairness_metric'), label=spec.get('label'), **spec.get('kwargs', {}))

        return fig
    else:
        return fig, plot_df
        

def epsilon_formater(x):
    """Format a privacy budget as inline LaTeX math.

    Strings are returned unchanged, 10000 is rendered as the infinity symbol,
    and any other value is wrapped in ``$...$``.
    """
    if isinstance(x, str):
        return x
    return "$\\infty$" if x == 10000.0 else "$" + str(x) + "$"
    
def model_formatter(x):
    """Return a LaTeX rendering of a model name.

    'FairPATE' is wrapped in ``\\textsf{}``. In any other name each underscore
    opens a ``\\textsubscript{`` group, and every opened group is closed at the
    end of the name so the braces stay balanced even with several underscores.
    Names without underscores are returned unchanged.
    """
    if x == 'FairPATE':
        return "\\textsf{" + x + "}"

    # One closing brace per opened group; a single "}" left names with more
    # than one underscore unbalanced.
    opened_groups = x.count("_")
    return x.replace("_", "\\textsubscript{") + "}" * opened_groups
    

def prepare_table(df):
    """Aggregate results per (budget, model) and copy a LaTeX table to the clipboard.

    Args:
        df: Frame with ``model``, ``budget``, ``error_mid`` and
            ``test_disparity`` columns.

    Returns:
        The aggregated frame (median/mean/std of error and disparity) with
        display-ready column and index labels. The LaTeX source is not
        returned; it is placed on the clipboard via ``pyperclip``.
    """
    formatter = {'error_mid': 'Classification Error', 'test_disparity': 'Disparity', 'model': 'Model', 'budget': '$\\varepsilon$'}

    table_df = df.copy(deep=True)
    table_df["model"] = table_df["model"].apply(model_formatter)

    table_df = (table_df
    .groupby(['budget', 'model']).agg({'error_mid': ['median', 'mean', 'std'], 'test_disparity': ['median', 'mean', 'std']})
    .rename(columns=formatter)
    .rename(columns={'median': 'Median', 'mean': 'Mean', 'std': 'Std'})
    .rename_axis(index=formatter)
    .rename(index=epsilon_formater)
    )
    latex = (table_df
    .style.format(escape="latex", precision=3)
    .to_latex(hrules=True, multicol_align='c')
    )

    # Insert a rule before every \multirow group, then collapse the doubled
    # rule this creates. Backslashes are written as "\\" because "\m" is an
    # invalid escape sequence that recent Python versions warn about.
    latex = latex.replace("\n\\multirow", "\n\\midrule\n\\multirow").replace("\\midrule\n\\midrule", "\\midrule")
    pyperclip.copy(latex)
    return table_df