import matplotlib.pyplot as plt
import numpy as np
from .utils import compute_cdf

# Module-wide default font size for figures; plot_calibration overrides it.
plt.rcParams['font.size'] = 20

# Line styles indexed by series position: the four named matplotlib styles,
# followed by the (3, 5, 1, 5) dash sequence at offsets 0, 1 and 2.
linestyles = ['-', '--', '-.', ':'] + [(offset, (3, 5, 1, 5)) for offset in range(3)]


# # # CALIBRATION # # #

def _calibration_curve(y, pred, std, is_binary, num_groups):
    """Return matched lists (mean predicted std, empirical std), one entry per bin.

    Samples are split into ``num_groups`` bins by percentile of ``std``. Bins
    that end up empty (possible when many samples share the same ``std``) or
    for which no empirical spread can be computed are skipped.
    """
    # linspace pins the last edge at exactly 100; 100 / num_groups * num_groups
    # can overshoot by a rounding error, which np.percentile rejects.
    edges = np.percentile(std, np.linspace(0, 100, num_groups + 1))
    # Hard labels are only needed for the binary case; round once, not per bin.
    labels = np.round(pred) if is_binary else None

    pred_std, empirical_std = [], []
    for k in range(num_groups):
        lower, upper = edges[k], edges[k + 1]
        if k == num_groups - 1:
            # Close the last bin on the right so the largest-std samples
            # are not silently dropped.
            in_bin = (std >= lower) & (std <= upper)
        else:
            in_bin = (std >= lower) & (std < upper)
        if not in_bin.any():
            continue

        if is_binary:
            # Spread of the true labels within each predicted class,
            # averaged over the classes actually present in this bin.
            y_bin, labels_bin = y[in_bin], labels[in_bin]
            spreads = [
                np.std(y_bin[labels_bin == label])
                for label in (0, 1)
                if (labels_bin == label).any()
            ]
            if not spreads:
                continue
            spread = np.mean(spreads)
        else:
            # Spread of the residuals.
            spread = np.std(y[in_bin] - pred[in_bin])

        pred_std.append(np.mean(std[in_bin]))
        empirical_std.append(spread)
    return pred_std, empirical_std


def plot_calibration(results, is_binary, algorithms, datasets, folder='figures', num_groups=10):
    """Save a calibration figure: predicted vs. empirical std, one panel per dataset.

    For every dataset and algorithm the test samples are binned by predicted
    std; each bin contributes one point (mean predicted std, empirical std).
    A dashed diagonal marks perfect calibration. The figure is written to
    ``{folder}/calibration_{binary|regression}_{num_groups}_groups.pdf``.

    Args:
        results: Nested mapping. ``results[dataset]['instance']['y_test']``
            holds the targets; ``results[dataset][algo]['std']`` and
            ``['pred']`` hold per-sample arrays that support boolean-mask
            indexing.
        is_binary: If true, the empirical std is the std of ``y`` within each
            rounded predicted class, averaged over classes; otherwise it is
            the std of the residuals ``y - pred``.
        algorithms: Iterable of algorithm names (keys into ``results[dataset]``).
        datasets: Sequence of dataset names (keys into ``results``).
        folder: Output directory; must already exist.
        num_groups: Number of percentile bins.

    Note:
        Sets the global matplotlib font size to 14 and does not restore it.
    """
    plt.rcParams.update({'font.size': 14})
    # squeeze=False keeps axs two-dimensional even for a single dataset.
    fig, axs = plt.subplots(1, len(datasets), figsize=(20, 3), squeeze=False)

    markers = ['o', 's', 'v', 'D', 'P', 'X', 'H']

    for i, dataset in enumerate(datasets):
        ax = axs[0, i]
        y = results[dataset]['instance']['y_test']
        for j, algo_name in enumerate(algorithms):
            output = results[dataset][algo_name]
            pred_std, empirical_std = _calibration_curve(
                y, output['pred'], output['std'], is_binary, num_groups
            )
            # Cycle markers so more algorithms than markers does not fail.
            ax.plot(pred_std, empirical_std, marker=markers[j % len(markers)], label=algo_name)

        # Diagonal spanning the union of the current x and y ranges.
        low = min(ax.get_xlim()[0], ax.get_ylim()[0])
        high = max(ax.get_xlim()[1], ax.get_ylim()[1])
        ax.plot([low, high], [low, high], alpha=.5, linestyle='--', color='black')
        ax.set_title(dataset)
        ax.set_xlabel('Predicted')
        if i == 0:
            ax.set_ylabel('Empirical')

    plt.legend(fancybox=True, bbox_to_anchor=(1, -.2), ncol=4)
    plt.suptitle(f'Calibration for {num_groups} Groups', y=1.1, fontsize=25)

    task = 'binary' if is_binary else 'regression'
    filename = f'{folder}/calibration_{task}_{num_groups}_groups.pdf'
    plt.savefig(filename, dpi=1000, bbox_inches="tight")
    # Close rather than clear so figures do not pile up across calls.
    plt.close(fig)

# # # CONSISTENCY # # #

def plot_consistency(results, is_binary, algorithms, datasets, parameter_name, folder='figures'):
    """Save box plots of how much the uncertainty varies across a parameter sweep.

    One panel per dataset, one box per algorithm. For each algorithm,
    element ``[0]`` of every per-parameter result (presumably the per-sample
    uncertainty vector -- confirm against the producer of ``results``) is
    stacked over the sorted parameter values, and the box shows the
    per-sample standard deviation across those values. The figure is written
    to ``{folder}/consistency_boxplot_{parameter_name}_{binary|regression}.pdf``.

    Args:
        results: Nested mapping indexed as ``results[dataset][algo][param][0]``.
        is_binary: Only selects the ``binary``/``regression`` filename suffix.
        algorithms: Mapping whose keys are the algorithm names.
        datasets: Sequence of dataset names.
        parameter_name: Name of the swept parameter; ``'max_depth'`` and
            ``'gamma'`` get a friendly title, any other name is used verbatim.
        folder: Output directory; must already exist.
    """
    from matplotlib import patches as mpatches

    # Colorblind-friendly palette, cycled if there are more algorithms.
    colors = ['#1b9e77', '#d95f02', '#7570b3', '#e7298a', '#66a61e', '#e6ab02', '#a6761d', '#666666']
    names = list(algorithms.keys())

    # squeeze=False keeps axs two-dimensional even for a single dataset.
    fig, axs = plt.subplots(
        1, len(datasets), figsize=(6*len(datasets), len(datasets)), squeeze=False
    )
    for i, dataset in enumerate(datasets):
        ax = axs[0, i]
        # The first algorithm's keys define the parameter grid for all of them.
        parameters = sorted(results[dataset][names[0]].keys())
        for j, name in enumerate(names):
            uncertainties = np.array([results[dataset][name][param][0] for param in parameters])
            ax.boxplot(
                np.std(uncertainties, axis=0),
                positions=[j],
                patch_artist=True,
                boxprops=dict(facecolor=colors[j % len(colors)]),
            )
        # Boxes are identified by the legend, not by x tick labels.
        ax.set_xticks([])
        if i == 0:
            ax.set_ylabel('Standard Deviation')
            ax.tick_params(axis='both', which='major', labelsize=10)
        ax.set_title(dataset)

    lookup = {'max_depth' : 'Depths', 'gamma' : 'Reduction Thresholds'}
    # Fall back to the raw parameter name instead of raising KeyError.
    swept = lookup.get(parameter_name, parameter_name)
    # Title is raised to avoid overlapping the per-axes titles.
    fig.suptitle(f'Standard Deviation of Uncertainty Across {swept}', y=1.05)
    handles = [
        mpatches.Patch(color=colors[j % len(colors)], label=name)
        for j, name in enumerate(names)
    ]
    plt.legend(handles=handles, bbox_to_anchor=(-len(datasets)/3, -0.1), ncol=len(algorithms))
    task = 'binary' if is_binary else 'regression'
    filename = f'{folder}/consistency_boxplot_{parameter_name}_{task}.pdf'
    plt.savefig(filename, dpi=1000, bbox_inches="tight")
    # Close rather than clear so figures do not pile up across calls.
    plt.close(fig)

def plot_consistency_with_range(results, is_binary, algorithms, datasets, parameter_name, folder='figures'):
    """Save a grid showing each sampled individual's uncertainty range over a sweep.

    Rows are datasets, columns are algorithms. Up to 100 individuals are
    drawn at random (without replacement, unseeded) per dataset. For each one,
    the values ``results[dataset][algo][param][0][idx]`` over all parameter
    values give a min, a max and a standard deviation; the individual is
    drawn at x = std as a vertical line from min to max, or as a single dot
    when the range is at most 0.001. The figure is written to
    ``{folder}/consistency_range_{parameter_name}_{binary|regression}.pdf``.

    Args:
        results: Nested mapping indexed as ``results[dataset][algo][param][0]``,
            where element ``[0]`` is a per-sample sequence.
        is_binary: Only selects the ``binary``/``regression`` filename suffix.
        algorithms: Mapping whose keys are the algorithm names.
        datasets: Sequence of dataset names.
        parameter_name: Name of the swept parameter (used in the filename).
        folder: Output directory; must already exist.
    """
    num_datasets = len(datasets)
    num_algos = len(algorithms)
    names = list(algorithms.keys())
    fig, axs = plt.subplots(num_datasets, num_algos, figsize=(17, 2*num_datasets), squeeze=False)
    colors = ['blue', 'red', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']

    for i, dataset in enumerate(datasets):
        # The first algorithm's keys define the parameter grid for all of them.
        parameters = sorted(results[dataset][names[0]].keys())
        num_samples = len(results[dataset][names[0]][parameters[0]][0])
        # Cap the sample size so datasets with fewer than 100 rows do not
        # make choice(..., replace=False) raise.
        individuals = np.random.choice(num_samples, min(100, num_samples), replace=False)

        for j, name in enumerate(names):
            ax = axs[i, j]
            color = colors[j % len(colors)]

            # Shape (num_parameters, num_individuals): one row per parameter value.
            sweep = np.array([
                np.asarray(results[dataset][name][param][0])[individuals]
                for param in parameters
            ])
            lows = sweep.min(axis=0)
            highs = sweep.max(axis=0)
            spreads = sweep.std(axis=0)

            # A visible range gets a vertical line; a negligible one gets a dot.
            wide = highs - lows > 0.001
            if wide.any():
                ax.vlines(spreads[wide], lows[wide], highs[wide], color=color, alpha=0.4)
            if not wide.all():
                ax.scatter(spreads[~wide], lows[~wide], color=color, alpha=0.4)

            ax.set_ylim([0, 1])
            ax.set_xlim([0, .5])
            if i == 0:
                ax.set_title(name)

            if j == 0:
                ax.set_ylabel(dataset, fontsize=12)
            else:
                ax.set_yticks([])

            if i != num_datasets-1:
                ax.set_xticks([])

            ax.tick_params(axis='both', which='major', labelsize=10)

    # Shared axis labels for the whole grid.
    fig.text(0.04, 0.5, 'Min/Max Uncertainty', va='center', rotation='vertical')
    fig.text(0.5, 0.04, 'SD of Uncertainty', ha='center')

    # No legend: each column is titled with its algorithm and no artist
    # carries a label, so a legend would be empty.
    plot_type = 'binary' if is_binary else 'regression'
    filename = folder + f'/consistency_range_{parameter_name}_{plot_type}.pdf'
    plt.savefig(filename, dpi=1000, bbox_inches="tight")
    # Close rather than clear so figures do not pile up across calls.
    plt.close(fig)

def plot_consistency_old(results, is_binary, algorithms, datasets, folder='figures', idx=42):
    """Save line plots of one individual's uncertainty as a function of depth.

    One panel per dataset, one line per algorithm. The y values are
    ``results[dataset][algo][depth][0][idx]`` over the sorted depth keys.
    The figure is written to ``{folder}/consistency_{binary|regression}.pdf``.

    Args:
        results: Nested mapping indexed as ``results[dataset][algo][depth][0]``,
            where element ``[0]`` is a per-sample sequence.
        is_binary: Only selects the ``binary``/``regression`` filename suffix.
        algorithms: Mapping whose keys are the algorithm names.
        datasets: Sequence of dataset names.
        folder: Output directory; must already exist.
        idx: Position of the individual to plot.
    """
    # squeeze=False keeps axs two-dimensional even for a single dataset.
    fig, axs = plt.subplots(1, len(datasets), figsize=(20, 3), squeeze=False)
    names = list(algorithms.keys())

    for i, dataset in enumerate(datasets):
        ax = axs[0, i]
        # The first algorithm's keys define the depth grid for all of them.
        max_depths = sorted(results[dataset][names[0]].keys())
        for j, name in enumerate(names):
            uncertainty = np.array(
                [results[dataset][name][max_depth][0][idx] for max_depth in max_depths]
            )
            # Cycle line styles so more algorithms than styles does not fail.
            ax.plot(
                max_depths,
                uncertainty,
                label=name,
                linestyle=linestyles[j % len(linestyles)],
                linewidth=5,
            )
        if i == 0:
            ax.set_ylabel('Uncertainty')

        ax.tick_params(axis='both', which='major', labelsize=10)
        ax.set_xlabel('Depth')
        ax.set_title(dataset)

    plt.legend(fancybox=True, bbox_to_anchor=(1, -.2), ncol=4)
    task = 'binary' if is_binary else 'regression'
    filename = f'{folder}/consistency_{task}.pdf'
    plt.savefig(filename, dpi=1000, bbox_inches="tight")
    # Close rather than clear so figures do not pile up across calls.
    plt.close(fig)

# # # ABSTENTION # # #

def plot_abstention(results, algorithms, datasets, percentiles, metrics, metric_name='Statistical Parity', folder='figure'):
    """Save curves of a metric as the most uncertain samples are abstained on.

    One panel per dataset, one line per algorithm. For each percentile ``p``
    only the samples whose ``std`` is at or below the ``p``-th percentile are
    kept, and ``metrics[metric_name]`` is evaluated on them. The x axis is the
    abstention rate ``1 - p / 100``. The figure is written to
    ``{folder}/abstention_{metric_name}.pdf``.

    Args:
        results: Nested mapping. ``results[dataset]['instance']`` provides
            ``'y_test'`` and ``'group_test'``; ``results[dataset][algo]``
            provides ``'std'`` and ``'pred'``. All are per-sample arrays that
            support boolean-mask indexing.
        algorithms: Iterable of algorithm names.
        datasets: Sequence of dataset names.
        percentiles: Percentile values in [0, 100]; any array-like.
        metrics: Mapping from metric name to a callable invoked as
            ``metric(pred, y, group == 1, run_checks=False)``.
        metric_name: Key into ``metrics``; also the y label and filename part.
        folder: Output directory; must already exist. Note the default is
            ``'figure'``, unlike the ``'figures'`` default used by the other
            plotting functions in this file.
    """
    # squeeze=False keeps axs two-dimensional even for a single dataset.
    fig, axs = plt.subplots(1, len(datasets), figsize=(len(datasets)*4, 4), squeeze=False)

    # Accept plain lists as well as arrays for the arithmetic below.
    percentiles = np.asarray(percentiles)
    abstention_rates = 1 - percentiles / 100
    metric = metrics[metric_name]

    for i, dataset in enumerate(datasets):
        ax = axs[0, i]
        instance = results[dataset]['instance']
        for j, algo_name in enumerate(algorithms):
            output = results[dataset][algo_name]
            std = output['std']
            metric_values = []
            for percentile in percentiles:
                # Keep only the samples at or below this uncertainty threshold.
                include = std <= np.percentile(std, percentile)
                metric_values.append(metric(
                    output['pred'][include],
                    instance['y_test'][include],
                    instance['group_test'][include] == 1,
                    run_checks=False,
                ))
            # Cycle line styles so more algorithms than styles does not fail.
            ax.plot(
                abstention_rates,
                metric_values,
                label=algo_name,
                linestyle=linestyles[j % len(linestyles)],
                linewidth=3,
            )
        if i == 0:
            ax.set_ylabel(metric_name)
        ax.tick_params(axis='both', which='major', labelsize=10)
        ax.set_title(dataset)
        ax.set_xlabel('Abstention Rate')

    plt.legend(fancybox=True, bbox_to_anchor=(1, -.2), ncol=4)
    filename = f'{folder}/abstention_{metric_name}.pdf'
    plt.savefig(filename, dpi=1000, bbox_inches="tight")
    # Close rather than clear so figures do not pile up across calls.
    plt.close(fig)

# # # REGRESSION FAIRNESS # # #

def plot_regression_fairness(results, datasets, algorithms, folder='figures'):
    """Save per-group CDF curves for every algorithm, one panel per dataset.

    For each dataset a grid of 1000 evenly spaced values between the minimum
    and maximum of ``y_test`` is built, and ``compute_cdf`` is evaluated on it
    for every (algorithm, group) pair. Colour identifies the algorithm and
    line style identifies the group; only the first group's line of each
    algorithm is labelled so the legend has one entry per algorithm. The
    figure is written to ``{folder}/regression_fairness.pdf``.

    Args:
        results: Nested mapping. ``results[dataset]['instance']`` provides
            ``'y_test'`` and ``'group_test'``; ``results[dataset][algo]`` is
            passed unchanged to ``compute_cdf``.
        datasets: Sequence of dataset names.
        algorithms: Sized iterable of algorithm names.
        folder: Output directory; must already exist.
    """
    # squeeze=False keeps axs two-dimensional even for a single dataset.
    fig, axs = plt.subplots(1, len(datasets), figsize=(20, 3), squeeze=False)
    colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']

    for dataset_num, dataset in enumerate(datasets):
        ax = axs[0, dataset_num]
        instance = results[dataset]['instance']
        group_test = instance['group_test']
        z = np.linspace(np.min(instance['y_test']), np.max(instance['y_test']), 1000)
        # The set of groups does not depend on the algorithm; compute it once.
        groups = np.unique(group_test)

        for i, algorithm in enumerate(algorithms):
            output = results[dataset][algorithm]
            color = colors[i % len(colors)]
            for j, group in enumerate(groups):
                cdf = compute_cdf(output, z, group_test == group)
                ax.plot(
                    z,
                    cdf,
                    label=algorithm if j == 0 else None,
                    color=color,
                    linestyle=linestyles[j % len(linestyles)],
                )
        ax.tick_params(axis='both', which='major', labelsize=10)
        ax.set_title(dataset)
        ax.set_xlabel(r'$y$')
        if dataset_num == 0:
            ax.set_ylabel('CDF')

    plt.legend(fancybox=True, bbox_to_anchor=(1, -.2), ncol=len(algorithms))
    plt.savefig(f'{folder}/regression_fairness.pdf', dpi=1000, bbox_inches='tight')
    # Close rather than clear so figures do not pile up across calls.
    plt.close(fig)