import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import pickle

from exp_settings import all_exp_settings
from data_settings import all_settings

matplotlib.rcParams.update({'font.size': 20})

from directories import results_dir, plots_dir

# Human-readable display name for every approach key; used as legend labels.
approach_names = {
    'baseline_plain_clean': 'Clean Labels',
    'baseline_plain': 'Standard',
    'baseline_filt': 'Filter',
    'baseline_sln': 'SLN',
    'baseline_sln_filt': 'SLN + Filter',
    'baseline_transition': 'Transition',
    'baseline_fair_reweight': 'Fair Reweight',
    'baseline_js_loss': 'JS Loss',
    'baseline_transit_conf': 'CSIDN',
    'baseline_fair_gpl': 'Fair GPL',
    'anchor': 'Anchor Only',
    'proposed1': 'Proposed',
}

# Display (title) name for every dataset variant. Each base dataset exists once per
# noise setting ('random', 'feat1', 'feat2') and all variants share one title.
# The synthetic dataset joins its suffix with '_', the others with '-'.
data_names_official = {
    prefix + noise_kind: title
    for noise_kind in ('random', 'feat1', 'feat2')
    for prefix, title in (('synth_', 'Synthetic'),
                          ('MIMIC-ARF-', 'MIMIC-ARF'),
                          ('MIMIC-Shock-', 'MIMIC-Shock'),
                          ('adult-', 'Adult'),
                          ('compas-', 'COMPAS'))
}


#########################################################################################
'''
print out results nicely
'''
def postprocess_results(dataset, approaches, date, val_gt, show=True):
    """Load each approach's saved (non-bootstrapped) metrics and optionally print them.

    For every approach the pickle at
    ``results_dir + dataset + '/' + date + '_' + approach + '_' + str(val_gt) + '.pkl'``
    is loaded. Its 'auroc', 'aupr', 'aueo' and 'hm' entries are printed when ``show`` is True.

    Args:
        dataset: dataset key (sub-directory of ``results_dir``).
        approaches: list of approach keys whose result files are loaded.
        date: prefix of the result file names.
        val_gt: flag embedded in the file name via ``str(val_gt)``.
        show: if True, print the dataset name and one metrics line per approach.

    Returns:
        A list of four arrays, one per metric in ('auroc', 'aupr', 'aueo', 'hm'). Each has
        shape (len(approaches), 3). Nothing in this function writes into them, so they are
        returned all zeros. The bootstrapped values (``res['boot_res']``) are not read here.
    """
    if show:
        print(dataset)

    boot_keys = ['auroc', 'aupr', 'aueo', 'hm']
    # One zero-filled (approach x 3) table per metric; see "Returns" above.
    results = [np.zeros((len(approaches), 3)) for _ in boot_keys]

    for approach in approaches:
        file_name = results_dir + dataset + '/' + date + '_' + approach + '_' + str(val_gt) + '.pkl'
        # `with` closes the handle even if unpickling raises.
        with open(file_name, 'rb') as file_handle:
            res = pickle.load(file_handle)
        if show:
            print(approach, 'auroc', res['auroc'], 'aupr', res['aupr'], 'aueo', res['aueo'], 'hm', res['hm'])

    return results
    
    
'''
make bar graph out of results
'''
def make_bar_graph(dataset, approaches, date, val_gt, show=True):
    """Draw one grouped bar chart of bootstrapped AUROC, AUEO and HM for all approaches.

    For each approach the result pickle is loaded and, for each metric,
    ``res['boot_res'][1][metric]`` is flattened and read as [lower, estimate, upper].
    The bar height is the estimate and the asymmetric error bar reaches the two bounds.
    Metrics form groups along the x-axis, with one bar per approach in each group.

    Args:
        dataset: dataset key (result sub-directory and key of data_names_official).
        approaches: approach keys, drawn in this order inside each group.
        date: prefix of the result file names.
        val_gt: flag embedded in the file name via ``str(val_gt)``.
        show: unused; kept for interface compatibility.

    The figure is saved as ``dataset + '_overall.png'`` relative to the current working
    directory (it is not prefixed with plots_dir).
    """
    boot_keys = ['auroc', 'aueo', 'hm']
    xtick_labs = ['AUROC', 'AUEO', 'HM']
    # Each metric group takes len(approaches) bar slots plus one empty gap slot.
    group_stride = len(approaches) + 1

    plt.figure(figsize=(16, 8))
    for i, approach in enumerate(approaches):
        file_name = results_dir + dataset + '/' + date + '_' + approach + '_' + str(val_gt) + '.pkl'
        with open(file_name, 'rb') as file_handle:
            res = pickle.load(file_handle)
        boot_res = res['boot_res'][1]

        bar_pos = np.arange(len(boot_keys)) * group_stride + i
        bar_heights, error_lower, error_upper = [], [], []
        for key in boot_keys:
            res_j = boot_res[key].reshape(-1)
            bar_heights.append(res_j[1])
            error_lower.append(res_j[1] - res_j[0])
            error_upper.append(res_j[2] - res_j[1])
        plt.bar(bar_pos, bar_heights, yerr=np.array([error_lower, error_upper]), capsize=2, label=approach_names[approach])

    # Centre each metric label under its group of bars.
    xtick_pos = np.arange(len(boot_keys)) * group_stride + ((len(approaches) - 1) / 2)
    plt.xticks(xtick_pos, xtick_labs)
    plt.title(data_names_official[dataset])
    # Fixed y-range: values below 0.5 fall outside the visible area.
    plt.ylim(0.5, 1)
    plt.ylabel('Value')
    plt.legend(loc='upper left')
    plt.savefig(dataset + '_overall.png')
                

'''
make bar graph out of results, with one subplot per metric
'''
def make_bar_graph_split(dataset, approaches, date, val_gt, show=True):
    """Draw bootstrapped AUROC, AUEO and HM for all approaches, one subplot per metric.

    For each approach the result pickle is loaded and each metric's
    ``res['boot_res'][1][metric]`` is flattened and read as [lower, estimate, upper].
    Approach i becomes bar i in every subplot. 'proposed1' and 'baseline_plain_clean' are
    solid; every other approach is hatched. Each subplot's y-range is the observed range
    of estimates padded by 0.1 and clipped to [0, 1]. Only datasets whose name contains
    'synth_random' get a legend (on the first subplot).

    Args:
        dataset: dataset key (result sub-directory).
        approaches: approach keys, drawn left to right.
        date: prefix of the result file names.
        val_gt: flag embedded in the file name via ``str(val_gt)``.
        show: unused; kept for interface compatibility.

    Side effects: saves ``plots_dir + dataset + '_overall.png'``. It also updates
    matplotlib.rcParams['font.size'] globally (30, or 24 after a legend is drawn), and that
    value persists after the call.
    """
    boot_keys = ['auroc', 'aueo', 'hm']
    metric_labs = ['AUROC', 'AUEO', 'HM']
    solid_approaches = ['proposed1', 'baseline_plain_clean']
    matplotlib.rcParams.update({'font.size': 30})

    plt.figure(figsize=(16, 8))
    # Per metric: [smallest, largest] point estimate seen, used for the y-limits below.
    bounds = {'auroc': [1000, -1000], 'aueo': [1000, -1000], 'hm': [1000, -1000]}
    for i, approach in enumerate(approaches):
        file_name = results_dir + dataset + '/' + date + '_' + approach + '_' + str(val_gt) + '.pkl'
        with open(file_name, 'rb') as file_handle:
            res = pickle.load(file_handle)
        boot_res = res['boot_res'][1]
        appr_name = approach_names[approach]

        for j, key in enumerate(boot_keys):
            res_j = boot_res[key].reshape(-1)
            if res_j[1] < bounds[key][0]:
                bounds[key][0] = res_j[1]
            if res_j[1] > bounds[key][1]:
                bounds[key][1] = res_j[1]
            plt.subplot(1, len(boot_keys), j+1)
            # Asymmetric error bar from the estimate down to lower and up to upper.
            yerr = np.array([[res_j[1] - res_j[0]], [res_j[2] - res_j[1]]])
            if approach in solid_approaches:
                plt.bar(i, res_j[1], yerr=yerr, capsize=2, label=appr_name)
            else:
                plt.bar(i, res_j[1], yerr=yerr, capsize=2, label=appr_name, hatch='/')

    for i, key in enumerate(boot_keys):
        matplotlib.rcParams.update({'font.size': 30})
        ax = plt.subplot(1, len(boot_keys), i+1)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        plt.ylim(max(bounds[key][0] - 0.1, 0), min(bounds[key][1] + 0.1, 1))
        plt.ylabel(metric_labs[i])
        plt.xticks([], [])
        if i == 0 and 'synth_random' in dataset:
            matplotlib.rcParams.update({'font.size': 24})
            plt.legend(loc='upper left', frameon=False)

    # Margins tuned for font size 30. Alternatives kept for other font sizes:
    #plt.subplots_adjust(left=0.08, right=0.98, top=0.98, bottom=0.02, wspace=0.35) #20 font
    #plt.subplots_adjust(left=0.13, right=0.98, top=0.95, bottom=0.02, wspace=0.65) #40 font
    plt.subplots_adjust(left=0.1, right=0.98, top=0.95, bottom=0.03, wspace=0.45)
    plt.savefig(plots_dir + dataset + '_overall.png')
    

def make_test_plot1(dataset, approaches, date, val_gt, show=True):
    """Variant of make_bar_graph_split that draws the reference baselines as bands.

    One subplot per metric (AUROC, AUEO, HM). Each metric's
    ``res['boot_res'][1][metric]`` is flattened and read as [lower, estimate, upper].
    - 'proposed1' is a solid bar at x = i - 1.
    - 'baseline_plain' and 'baseline_plain_clean' are drawn as a horizontal line at the
      estimate, with a shaded band from lower to upper.
    - Every other approach is a hatched bar at x = i - 1.
    Approach i is drawn in colour 'C{i}', so at most 10 approaches are supported.

    NOTE(review): the x = i - 1 bar positions, the band's x-extent and the x-limits assume
    the two line-drawn baselines are first and last in ``approaches``, as in the approach
    list under ``__main__``. Confirm before reordering.

    Args:
        dataset: dataset key (result sub-directory).
        approaches: approach keys.
        date: prefix of the result file names.
        val_gt: flag embedded in the file name via ``str(val_gt)``.
        show: unused; kept for interface compatibility.

    Side effects: saves ``plots_dir + dataset + '_overall_test1.png'`` and updates
    matplotlib.rcParams['font.size'] globally.
    """
    boot_keys = ['auroc', 'aueo', 'hm']
    metric_labs = ['AUROC', 'AUEO', 'HM']
    matplotlib.rcParams.update({'font.size': 30})
    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']

    plt.figure(figsize=(16, 8))
    # Per metric: [smallest, largest] point estimate seen, used for the y-limits below.
    bounds = {'auroc': [1000, -1000], 'aueo': [1000, -1000], 'hm': [1000, -1000]}
    for i, approach in enumerate(approaches):
        file_name = results_dir + dataset + '/' + date + '_' + approach + '_' + str(val_gt) + '.pkl'
        with open(file_name, 'rb') as file_handle:
            res = pickle.load(file_handle)
        boot_res = res['boot_res'][1]

        for j, key in enumerate(boot_keys):
            res_j = boot_res[key].reshape(-1)
            if res_j[1] < bounds[key][0]:
                bounds[key][0] = res_j[1]
            if res_j[1] > bounds[key][1]:
                bounds[key][1] = res_j[1]
            plt.subplot(1, len(boot_keys), j+1)
            appr_name = approach_names[approach]
            yerr = np.array([[res_j[1] - res_j[0]], [res_j[2] - res_j[1]]])
            if approach in ['proposed1']:
                plt.bar(i-1, res_j[1], yerr=yerr, capsize=2, label=appr_name, color=colors[i])
            elif approach in ['baseline_plain', 'baseline_plain_clean']:
                plt.axhline(res_j[1], label=appr_name, color=colors[i], linewidth=2)
                plt.fill_between([-1, len(approaches)-1.5], [res_j[2], res_j[2]], [res_j[0], res_j[0]], color=colors[i], alpha=0.2)
            else:
                plt.bar(i-1, res_j[1], yerr=yerr, capsize=2, label=appr_name, hatch='/', color=colors[i])

    for i, key in enumerate(boot_keys):
        matplotlib.rcParams.update({'font.size': 30})
        ax = plt.subplot(1, len(boot_keys), i+1)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        plt.ylim(max(bounds[key][0] - 0.1, 0), min(bounds[key][1] + 0.1, 1))
        plt.ylabel(metric_labs[i])
        plt.xlim(-0.7, len(approaches)-2.3)
        plt.xticks([], [])
        if i == 0 and 'synth_random' in dataset:
            matplotlib.rcParams.update({'font.size': 24})
            plt.legend(loc='upper left', frameon=False)

    # Margins tuned for font size 30. Alternatives kept for other font sizes:
    #plt.subplots_adjust(left=0.08, right=0.98, top=0.98, bottom=0.02, wspace=0.35) #20 font
    #plt.subplots_adjust(left=0.13, right=0.98, top=0.95, bottom=0.02, wspace=0.65) #40 font
    plt.subplots_adjust(left=0.1, right=0.98, top=0.95, bottom=0.03, wspace=0.45)
    plt.savefig(plots_dir + dataset + '_overall_test1.png')
    

def make_test_plot2(dataset, approaches, date, val_gt, show=True):
    """Scatter view: each approach is one point at (AUROC, AUEO) sized by HM.

    Each metric's ``res['boot_res'][1][metric]`` is flattened and read as
    [lower, estimate, upper]. The point sits at the AUROC and AUEO estimates, with
    horizontal (AUROC) and vertical (AUEO) error bars reaching the bounds. The marker size
    is 20 * HM estimate. Approach i is drawn in 'C{i}' (at most 10 approaches).

    Args:
        dataset: dataset key (result sub-directory).
        approaches: approach keys.
        date: prefix of the result file names.
        val_gt: flag embedded in the file name via ``str(val_gt)``.
        show: unused; kept for interface compatibility.

    Side effects: saves ``plots_dir + dataset + '_overall_test2.png'`` and updates
    matplotlib.rcParams['font.size'] globally.
    """
    boot_keys = ['auroc', 'aueo', 'hm']
    metric_labs = ['AUROC', 'AUEO', 'HM']
    matplotlib.rcParams.update({'font.size': 30})
    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']

    plt.figure(figsize=(16, 8))
    # Per metric: [smallest, largest] point estimate seen, used for the axis limits below.
    bounds = {'auroc': [1000, -1000], 'aueo': [1000, -1000], 'hm': [1000, -1000]}
    for i, approach in enumerate(approaches):
        file_name = results_dir + dataset + '/' + date + '_' + approach + '_' + str(val_gt) + '.pkl'
        with open(file_name, 'rb') as file_handle:
            res = pickle.load(file_handle)
        boot_res = res['boot_res'][1]

        error_bars = []
        points = []
        for key in boot_keys:
            res_j = boot_res[key].reshape(-1)
            if res_j[1] < bounds[key][0]:
                bounds[key][0] = res_j[1]
            if res_j[1] > bounds[key][1]:
                bounds[key][1] = res_j[1]
            error_bars.append(np.array([[res_j[1] - res_j[0]], [res_j[2] - res_j[1]]]))
            points.append(res_j[1])
        appr_name = approach_names[approach]
        plt.errorbar(points[0], points[1], xerr=error_bars[0], yerr=error_bars[1], marker='o', color=colors[i], label=appr_name, markersize=points[2]*20)

    for i, key in enumerate(boot_keys):
        matplotlib.rcParams.update({'font.size': 30})
        ax = plt.subplot(1, 1, 1)
        if i == 0:
            # AUROC on the x-axis.
            ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
            plt.xlim(max(bounds[key][0] - 0.05, 0), min(bounds[key][1] + 0.05, 1))
            plt.xlabel(metric_labs[i])
        elif i == 1:
            # AUEO on the y-axis.
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
            plt.ylim(max(bounds[key][0] - 0.05, 0), min(bounds[key][1] + 0.05, 1))
            plt.ylabel(metric_labs[i])
        # The legend is re-created on each pass; only the last one remains on the axes.
        if 'synth_random' in dataset:
            matplotlib.rcParams.update({'font.size': 24})
            plt.legend(loc='upper left', frameon=False)

    # Margins tuned for font size 30. Alternatives kept for other font sizes:
    #plt.subplots_adjust(left=0.08, right=0.98, top=0.98, bottom=0.02, wspace=0.35) #20 font
    #plt.subplots_adjust(left=0.13, right=0.98, top=0.95, bottom=0.02, wspace=0.65) #40 font
    plt.subplots_adjust(left=0.1, right=0.98, top=0.95, bottom=0.03, wspace=0.45)
    plt.savefig(plots_dir + dataset + '_overall_test2.png')


def make_test_plot3(dataset, approaches, date, val_gt, show=True):
    """Interval plot with one horizontal panel per metric (AUROC, AUEO, HM), stacked.

    Each metric's ``res['boot_res'][1][metric]`` is flattened and read as
    [lower, estimate, upper]. In every panel approach i occupies the row y = i/10 + 0.1.
    - A thick translucent horizontal bar spans lower..upper.
    - A dot marks the estimate.
    - 'baseline_plain' and 'baseline_plain_clean' also get a dashed black vertical line
      at their estimate.
    The fixed y-limits (-0.03, 0.93) leave room for about nine rows. Approach i is drawn in
    'C{i}'. For 'synth_random' datasets the legend is attached to the bottom panel and
    lifted above the panels with bbox_to_anchor.

    Args:
        dataset: dataset key (result sub-directory).
        approaches: approach keys.
        date: prefix of the result file names.
        val_gt: flag embedded in the file name via ``str(val_gt)``.
        show: unused; kept for interface compatibility.

    Side effects: saves ``plots_dir + dataset + '_overall_test3.png'`` and updates
    matplotlib.rcParams['font.size'] globally.
    """
    boot_keys = ['auroc', 'aueo', 'hm']
    metric_labs = ['AUROC', 'AUEO', 'HM']
    matplotlib.rcParams.update({'font.size': 30})
    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']

    plt.figure(figsize=(16, 8))
    # Per metric: [smallest, largest] point estimate seen, used for the x-limits below.
    bounds = {'auroc': [1000, -1000], 'aueo': [1000, -1000], 'hm': [1000, -1000]}
    for i, approach in enumerate(approaches):
        file_name = results_dir + dataset + '/' + date + '_' + approach + '_' + str(val_gt) + '.pkl'
        with open(file_name, 'rb') as file_handle:
            res = pickle.load(file_handle)
        boot_res = res['boot_res'][1]
        row_y = i/10 + 0.1

        for j, key in enumerate(boot_keys):
            res_j = boot_res[key].reshape(-1)
            if res_j[1] < bounds[key][0]:
                bounds[key][0] = res_j[1]
            if res_j[1] > bounds[key][1]:
                bounds[key][1] = res_j[1]
            plt.subplot(len(boot_keys), 1, j+1)
            appr_name = approach_names[approach]
            # Interval only (no marker), drawn as a thick semi-transparent bar.
            plt.errorbar(res_j[1], row_y, marker='None', xerr=np.array([[res_j[1] - res_j[0]], [res_j[2] - res_j[1]]]), color=colors[i], elinewidth=8, alpha=0.5)
            # Estimate dot on top; this artist carries the legend label.
            plt.errorbar(res_j[1], row_y, marker='o', color=colors[i], label=appr_name, markersize=9)
            if approach in ['baseline_plain', 'baseline_plain_clean']:
                plt.axvline(res_j[1], color='k', linestyle='dashed')

    for i, key in enumerate(boot_keys):
        matplotlib.rcParams.update({'font.size': 30})
        ax = plt.subplot(len(boot_keys), 1, i+1)
        ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        plt.xlim(max(bounds[key][0] - 0.07, 0), min(bounds[key][1] + 0.07, 1))
        plt.xlabel(metric_labs[i])
        plt.yticks([], [])
        plt.ylim(-0.03, 0.93)

    if 'synth_random' in dataset:
        plt.subplot(len(boot_keys), 1, 3)
        matplotlib.rcParams.update({'font.size': 24})
        plt.legend(loc='upper left', frameon=True, framealpha=0.7, bbox_to_anchor=(0, 4.3))
    plt.subplots_adjust(left=0.04, right=0.96, top=0.97, bottom=0.12, hspace=0.6)
    plt.savefig(plots_dir + dataset + '_overall_test3.png')
    
    
def make_test_plot4(dataset, approaches, date, val_gt, show=True):
    """Single-panel interval plot with all three metrics on one shared x-axis.

    Each metric's ``res['boot_res'][1][metric]`` is flattened and read as
    [lower, estimate, upper]. Approach i occupies the row y = i/10 + 0.1. For each metric a
    translucent horizontal bar spans lower..upper, and a marker sits at the estimate:
    'o' for AUROC, '^' for AUEO, '*' for HM.

    Legend entries:
    - The first approach labels all three markers as '<name> <METRIC>', which explains
      the marker key.
    - Every later approach is labelled once, on its AUROC marker.
    Only 'synth_random' datasets draw the legend. Approach i is drawn in 'C{i}'.

    Args:
        dataset: dataset key (result sub-directory).
        approaches: approach keys.
        date: prefix of the result file names.
        val_gt: flag embedded in the file name via ``str(val_gt)``.
        show: unused; kept for interface compatibility.

    Side effects: saves ``plots_dir + dataset + '_overall_test4.png'`` and updates
    matplotlib.rcParams['font.size'] globally.
    """
    boot_keys = ['auroc', 'aueo', 'hm']
    metric_labs = ['AUROC', 'AUEO', 'HM']
    matplotlib.rcParams.update({'font.size': 30})
    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
    markers = ['o', '^', '*']

    plt.figure(figsize=(16, 6))
    bar_width, dot_width = 20, 22
    for i, approach in enumerate(approaches):
        file_name = results_dir + dataset + '/' + date + '_' + approach + '_' + str(val_gt) + '.pkl'
        with open(file_name, 'rb') as file_handle:
            res = pickle.load(file_handle)
        boot_res = res['boot_res'][1]
        row_y = i/10 + 0.1

        for j, key in enumerate(boot_keys):
            res_j = boot_res[key].reshape(-1)
            appr_name = approach_names[approach]
            plt.errorbar(res_j[1], row_y, marker='None', xerr=np.array([[res_j[1] - res_j[0]], [res_j[2] - res_j[1]]]), color=colors[i], elinewidth=bar_width, alpha=0.5)
            if i == 0:
                plt.errorbar(res_j[1], row_y, marker=markers[j], color=colors[i], label=appr_name + ' ' + metric_labs[j], markersize=dot_width)
            elif j == 0:
                plt.errorbar(res_j[1], row_y, marker=markers[j], color=colors[i], label=appr_name, markersize=dot_width)
            else:
                plt.errorbar(res_j[1], row_y, marker=markers[j], color=colors[i], markersize=dot_width)

    matplotlib.rcParams.update({'font.size': 30})
    # Fetch the figure's single axes to format it.
    ax = plt.subplot(1, 1, 1)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    plt.yticks([], [])
    plt.ylim(0, 0.9)

    if 'synth_random' in dataset:
        matplotlib.rcParams.update({'font.size': 26})
        plt.legend(loc='upper left', frameon=False, framealpha=0.7, labelspacing=0.25, bbox_to_anchor=(0.54, 0.91))
    plt.subplots_adjust(left=0.04, right=0.96, top=0.97, bottom=0.12, hspace=0.6)
    plt.savefig(plots_dir + dataset + '_overall_test4.png')
    

#########################################################################################
'''
plot performance over different conditions (overall noise rate)
'''
def plot_noise_rate(dataset, approaches, date, conditions, cond_lab, val_gt, results, offset, plot_num, exp_name, colour=None):
    """Plot HM (mean +/- std over seeds) against the noise rate, one line per approach.

    Args:
        dataset: dataset key into all_exp_settings and data_names_official.
        approaches: approach keys; approach i's values come from results[i].
        date: unused here; kept for a signature parallel to the other plot_* helpers.
        conditions: condition indices; row i of each results array belongs to conditions[i].
        cond_lab: key of the noise setting in all_exp_settings[dataset][cond + offset].
        val_gt: flag embedded in the output file name.
        results: one array per approach, indexed [condition, seed].
        offset: added to each condition to index all_exp_settings[dataset].
        plot_num: [subplot position (1-based), number of subplots] in a one-row layout.
        exp_name: 'noise_rate', 'noise_rate_rand', or any other name (treated as noise
            disparity). Selects the x value, the x-axis label and the output file name.
        colour: if not None, every approach is drawn in this colour; otherwise approach i
            is drawn in 'C{i}'.

    The figure is saved on every call, to plots_dir + ('noise_rate_' | 'random_' |
    'noise_disp_') + str(val_gt) + '.png'. The call for the last subplot therefore
    writes the complete figure.
    """
    # Column 0: mean over seeds; columns 1 and 2: std, used as symmetric error bars.
    evals = [np.zeros((len(conditions), 3)) for _ in approaches]

    xpoints = np.zeros((len(conditions),))
    for i, cond in enumerate(conditions):
        for j in range(len(approaches)):
            evals[j][i, 0] = np.mean(results[j][i, :])
            evals[j][i, 1:] = np.std(results[j][i, :])
        setting = all_exp_settings[dataset][cond+offset][cond_lab]
        # 'rate' experiments plot the setting's first entry, all others its second.
        # NOTE(review): presumably (majority rate, minority rate), judging by the axis labels below; confirm.
        xpoints[i] = setting[0] if 'rate' in exp_name else setting[1]

    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
    lines = ['-']
    markers = ['o', 'v', '^', '>', '<', 's', 'x', '*', 'D']
    matplotlib.rcParams.update({'font.size': 16})
    ax = plt.subplot(1, plot_num[1], plot_num[0])
    for i, approach in enumerate(approaches):
        # Use a distinct colour per approach unless one colour is forced for all.
        use_color = colors[i] if colour is None else colour
        plt.errorbar(xpoints, evals[i][:, 0], yerr=evals[i][:, 1:].T, marker=markers[i], markersize=10, linestyle=lines[0], color=use_color, label=approach_names[approach])
    # Legend to the right of the last panel (multi-panel) or of the only panel.
    if plot_num[0] == plot_num[1] and plot_num[1] > 1:
        plt.legend(loc='lower right', frameon=False, bbox_to_anchor=(1.65, 0.1))
    elif plot_num[1] == 1:
        plt.legend(loc='lower right', frameon=False, bbox_to_anchor=(1.51, 0.1))
    matplotlib.rcParams.update({'font.size': 20})
    # Only the middle panel carries the x-axis label.
    if plot_num[0] == (plot_num[1] // 2) + 1:
        if exp_name == 'noise_rate':
            plt.xlabel('Majority Noise Rate')
        elif exp_name == 'noise_rate_rand':
            plt.xlabel('Noise Rate')
        else:
            plt.xlabel('Minority Noise Rate (Majority Rate Fixed at 20%)')
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    # Only the first panel carries the y-axis label.
    if plot_num[0] == 1:
        plt.ylabel('Harmonic Mean of AUROC and AUEOC')
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    if plot_num[1] > 1:
        plt.subplots_adjust(right=0.9, top=0.92, left=0.05, bottom=0.15, wspace=0.2)
        plt.title(data_names_official[dataset])
    else:
        # Single panel: reserve space on the right for the legend.
        plt.subplots_adjust(right=0.71, top=0.92, left=0.13, bottom=0.15)
    if exp_name == 'noise_rate':
        plt.savefig(plots_dir + 'noise_rate_' + str(val_gt) + '.png')
    elif exp_name == 'noise_rate_rand':
        plt.savefig(plots_dir + 'random_' + str(val_gt) + '.png')
    else:
        plt.savefig(plots_dir + 'noise_disp_' + str(val_gt) + '.png')
    
    
'''
plot performance over different conditions (distribution within anchor set)
'''
def plot_anc_distr(dataset, approaches, date, conditions, cond_lab, val_gt, results, offset, plot_num, colour=None):
    """Plot HM (mean +/- std over seeds) against the anchor set's minority proportion.

    Args:
        dataset: dataset key into all_exp_settings, all_settings and data_names_official.
        approaches: approach keys; approach i's values come from results[i].
        date: unused here; kept for a signature parallel to the other plot_* helpers.
        conditions: condition indices; row i of each results array belongs to conditions[i].
        cond_lab: key of the anchor setting in all_exp_settings[dataset][cond + offset].
        val_gt: flag embedded in the output file name.
        results: one array per approach, indexed [condition, seed].
        offset: added to each condition to index all_exp_settings[dataset].
        plot_num: [subplot position (1-based), number of subplots] in a one-row layout.
        colour: accepted for interface compatibility but not used. Approach i is always
            drawn in 'C{i}', so approaches stay distinguishable within one panel.

    A dotted vertical line marks all_settings[dataset]['min_prop']. The figure is saved on
    every call to plots_dir + 'anc_distr_' + str(val_gt) + '.png'.
    """
    # Column 0: mean over seeds; columns 1 and 2: std, used as symmetric error bars.
    evals = [np.zeros((len(conditions), 3)) for _ in approaches]

    # Multiplier turning setting[1] * min_prop into the plotted proportion.
    # NOTE(review): per-dataset constants taken as-is (10 default, 50 for MIMIC, 20 for adult); confirm against exp_settings.
    scale = 10
    if 'MIMIC' in dataset:
        scale = 50
    if 'adult' in dataset:
        scale = 20

    xpoints = np.zeros((len(conditions),))
    for i, cond in enumerate(conditions):
        for j in range(len(approaches)):
            evals[j][i, 0] = np.mean(results[j][i, :])
            evals[j][i, 1:] = np.std(results[j][i, :])
        anchor_setting = all_exp_settings[dataset][cond+offset][cond_lab]
        # Capped at 1 so the x value stays a valid proportion.
        xpoints[i] = min(anchor_setting[1]*all_settings[dataset]['min_prop']*scale, 1)

    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
    lines = ['--', '-']
    markers = ['o', 'v', '^', '>', '<', 's', 'x', '*', 'D']
    matplotlib.rcParams.update({'font.size': 16})
    ax = plt.subplot(1, plot_num[1], plot_num[0])
    for i, approach in enumerate(approaches):
        plt.errorbar(xpoints, evals[i][:, 0], yerr=evals[i][:, 1:].T, marker=markers[i], markersize=10, linestyle=lines[0], color=colors[i], label=approach_names[approach])
    # Legend to the right of the last panel.
    if plot_num[0] == plot_num[1]:
        plt.legend(loc='lower right', frameon=False, bbox_to_anchor=(1.65, 0.1))
    matplotlib.rcParams.update({'font.size': 20})
    plt.axvline(all_settings[dataset]['min_prop'], linestyle='dotted', color='k')
    # Only the middle panel carries the x-axis label.
    if plot_num[0] == (plot_num[1] // 2) + 1:
        plt.xlabel('Proportion Minority')
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    # Only the first panel carries the y-axis label.
    if plot_num[0] == 1:
        plt.ylabel('Harmonic Mean of AUROC and AUEOC')
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    plt.title(data_names_official[dataset])
    plt.subplots_adjust(right=0.9, top=0.92, left=0.05, bottom=0.15, wspace=0.2)
    plt.savefig(plots_dir + 'anc_distr_' + str(val_gt) + '.png')


'''
plot performance over different conditions (size of anchor set)
'''
def plot_anc_size(dataset, approaches, date, conditions, cond_lab, val_gt, res, plot_num, colour=None):
    """Plot HM (mean +/- std over seeds) against the anchor-set size.

    Args:
        dataset: dataset key into all_exp_settings and data_names_official.
        approaches: approach keys; approach i's values come from res[i].
        date: unused here; kept for a signature parallel to the other plot_* helpers.
        conditions: condition indices; row i of each res array belongs to conditions[i].
        cond_lab: key of the anchor setting in all_exp_settings[dataset][cond]. Its
            second entry is the x value.
        val_gt: flag embedded in the output file name.
        res: one array per approach, indexed [condition, seed].
        plot_num: [subplot position (1-based), number of subplots] in a one-row layout.
        colour: accepted for interface compatibility but not used. Approach i is always
            drawn in 'C{i}', so approaches stay distinguishable within one panel.

    The figure is saved on every call to plots_dir + 'anc_size_' + str(val_gt) + '.png'.
    """
    # Column 0: mean over seeds; columns 1 and 2: std, used as symmetric error bars.
    evals = [np.zeros((len(conditions), 3)) for _ in approaches]

    xpoints = np.zeros((len(conditions),))
    for i, cond in enumerate(conditions):
        for j in range(len(approaches)):
            evals[j][i, 0] = np.mean(res[j][i, :])
            evals[j][i, 1:] = np.std(res[j][i, :])
        xpoints[i] = all_exp_settings[dataset][cond][cond_lab][1]

    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
    lines = ['--', '-']
    markers = ['o', 'v', '^', '>', '<', 's', 'x', '*', 'D']
    matplotlib.rcParams.update({'font.size': 16})
    ax = plt.subplot(1, plot_num[1], plot_num[0])
    for i, approach in enumerate(approaches):
        plt.errorbar(xpoints, evals[i][:, 0], yerr=evals[i][:, 1:].T, marker=markers[i], markersize=10, linestyle=lines[0], color=colors[i], label=approach_names[approach])
    # Legend to the right of the last panel.
    if plot_num[0] == plot_num[1]:
        plt.legend(loc='lower right', frameon=False, bbox_to_anchor=(1.65, 0.1))
    matplotlib.rcParams.update({'font.size': 20})
    # Only the middle panel carries the x-axis label.
    if plot_num[0] == (plot_num[1] // 2) + 1:
        plt.xlabel('Proportion of Training Data in Anchor Set')
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    # Only the first panel carries the y-axis label.
    if plot_num[0] == 1:
        plt.ylabel('Harmonic Mean of AUROC and AUEOC')
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    plt.title(data_names_official[dataset])
    plt.subplots_adjust(right=0.9, top=0.92, left=0.05, bottom=0.15, wspace=0.2)
    plt.savefig(plots_dir + 'anc_size_' + str(val_gt) + '.png')


'''
plot performance over different conditions (size of minority population)
'''
def plot_minprop(dataset, approaches, date, conditions, cond_lab, val_gt, res, offset, plot_num, colour=None):
    """Plot HM (mean +/- std over seeds) against the size of the minority group.

    Args:
        dataset: dataset key into all_exp_settings and data_names_official.
        approaches: approach keys; approach i's values come from res[i].
        date: unused here; kept for a signature parallel to the other plot_* helpers.
        conditions: condition indices; row i of each res array belongs to conditions[i].
        cond_lab: key in all_exp_settings[dataset][cond]. Its value is used directly
            as the x value.
        val_gt: flag embedded in the output file name.
        res: one array per approach, indexed [condition, seed].
        offset: added to each condition (then cast to int) to index all_exp_settings[dataset].
        plot_num: [subplot position (1-based), number of subplots] in a one-row layout.
        colour: accepted for interface compatibility but not used. Approach i is always
            drawn in 'C{i}'.

    Prints each looked-up setting as a trace. The figure is saved on every call to
    plots_dir + 'minprop_' + str(val_gt) + '.png'.
    """
    # Column 0: mean over seeds; columns 1 and 2: std, used as symmetric error bars.
    evals = [np.zeros((len(conditions), 3)) for _ in approaches]

    xpoints = np.zeros((len(conditions),))
    for i in range(len(conditions)):
        cond = int(conditions[i] + offset)
        for j in range(len(approaches)):
            evals[j][i, 0] = np.mean(res[j][i, :])
            evals[j][i, 1:] = np.std(res[j][i, :])
        print(dataset, cond, cond_lab, conditions, all_exp_settings[dataset][cond][cond_lab])
        xpoints[i] = all_exp_settings[dataset][cond][cond_lab]

    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
    lines = ['--', '-']
    markers = ['o', 'v', '^', '>', '<', 's', 'x', '*', 'D']
    matplotlib.rcParams.update({'font.size': 16})
    ax = plt.subplot(1, plot_num[1], plot_num[0])
    for i, approach in enumerate(approaches):
        plt.errorbar(xpoints, evals[i][:, 0], yerr=evals[i][:, 1:].T, marker=markers[i], markersize=10, linestyle=lines[0], color=colors[i], label=approach_names[approach])
    # Legend to the right of the last panel (multi-panel) or of the only panel.
    if plot_num[0] == plot_num[1] and plot_num[1] > 1:
        plt.legend(loc='lower right', frameon=False, bbox_to_anchor=(1.65, 0.1))
    elif plot_num[1] == 1:
        plt.legend(loc='lower right', frameon=False, bbox_to_anchor=(1.51, 0.1))
    matplotlib.rcParams.update({'font.size': 20})
    # Only the middle panel carries the x-axis label.
    if plot_num[0] == (plot_num[1] // 2) + 1:
        plt.xlabel('Size of Minority Group (Proportion)')
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    # Unlike the other plot_* helpers, every panel carries the y-axis label.
    plt.ylabel('Harmonic Mean of AUROC and AUEOC')
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    if plot_num[1] > 1:
        plt.subplots_adjust(right=0.9, top=0.92, left=0.05, bottom=0.15, wspace=0.2)
        plt.title(data_names_official[dataset])
    else:
        # Single panel: reserve space on the right for the legend.
        plt.subplots_adjust(right=0.71, top=0.92, left=0.13, bottom=0.15)
    plt.savefig(plots_dir + 'minprop_' + str(val_gt) + '.png')

def postprocess_anc_repl(dataset_name, approaches, date, seeds, conditions, cond_lab, val_gt, exp_name, plot_num, color=None):
    """Collect per-seed HM results for one experiment on one dataset and plot them.

    For every approach, condition and seed, the file
    ``results_dir + dataset_name + '/' + date + 's' + str(seed) + 'sa_' + exp_name + '-'
    + str(cond) + '_' + approach + '_' + str(val_gt) + '.pkl'`` is loaded and its 'hm'
    entry is kept. Condition -3 instead loads the 'size' experiment's condition-7 file.

    When plot_num[0] == 1 (first panel of a row), a new figure is started. It is 25 wide
    for a multi-panel row and 8 wide for a single panel.

    Dispatch on exp_name (condition offsets in brackets):
    - 'size' -> plot_anc_size
    - 'bias' -> plot_anc_distr (10)
    - 'minprop' -> plot_minprop (40)
    - contains 'noise_disp' -> plot_noise_rate (30)
    - any other name containing 'noise' -> plot_noise_rate (20)
    Any other exp_name loads the results but plots nothing. ``color`` is forwarded to the
    size/bias/minprop plotters only.

    Args:
        dataset_name: dataset key (result sub-directory).
        approaches: approach keys.
        date: prefix of the result file names.
        seeds: seed values embedded in the file names.
        conditions: condition indices embedded in the file names.
        cond_lab: settings key forwarded to the plotter.
        val_gt: flag embedded in the file and plot names.
        exp_name: experiment name; see the dispatch list above.
        plot_num: [subplot position (1-based), number of subplots].
        color: forwarded to the size/bias/minprop plotters.
    """
    num_cond, num_seeds = len(conditions), len(seeds)
    res_all = []

    for approach in approaches:
        # HM indexed [condition, seed].
        seed_res = np.zeros((num_cond, num_seeds))
        for j, cond in enumerate(conditions):
            for k, seed in enumerate(seeds):
                if cond == -3:
                    # Special condition: reuse the anchor-size experiment's condition 7.
                    new_date = date + 's' + str(seed) + 'sa_' + 'size' + '-7'
                else:
                    new_date = date + 's' + str(seed) + 'sa_' + exp_name + '-' + str(cond)
                file_name = results_dir + dataset_name + '/' + new_date + '_' + approach + '_' + str(val_gt) + '.pkl'
                with open(file_name, 'rb') as file_handle:
                    res = pickle.load(file_handle)
                seed_res[j, k] = res['hm']
        res_all.append(seed_res)

    if plot_num[0] == 1:
        plt.clf()
        width = 8 if plot_num[1] == 1 else 25
        plt.figure(figsize=(width, 5))
    if exp_name == 'size':
        plot_anc_size(dataset_name, approaches, date, conditions, cond_lab, val_gt, res_all, plot_num, color)
    elif exp_name == 'bias':
        plot_anc_distr(dataset_name, approaches, date, conditions, cond_lab, val_gt, res_all, 10, plot_num, color)
    elif exp_name == 'minprop':
        plot_minprop(dataset_name, approaches, date, conditions, cond_lab, val_gt, res_all, 40, plot_num, color)
    # colour is deliberately None for the noise plots: a non-None colour makes
    # plot_noise_rate draw every approach in that single colour.
    elif 'noise_disp' in exp_name:
        plot_noise_rate(dataset_name, approaches, date, conditions, cond_lab, val_gt, res_all, 30, plot_num, exp_name, colour=None)
    elif 'noise' in exp_name:
        plot_noise_rate(dataset_name, approaches, date, conditions, cond_lab, val_gt, res_all, 20, plot_num, exp_name, colour=None)


#########################################################################################
'''
plot one experiment (anchor set size or bias, minority group size, or noise) for multiple datasets at once, one subplot per dataset
'''
def plot_exp(dataset_names, approaches, date, seeds, conditions, cond_lab, val_gt, exp_name):
    """Plot one experiment for several datasets, one subplot per dataset in a single row.

    Dataset i is drawn in subplot i + 1 of len(dataset_names) and is given the colour
    'C{i mod 10}'. Cycling through matplotlib's ten default colours means any number of
    datasets is supported.
    """
    for i, dataset_name in enumerate(dataset_names):
        plot_num = [i+1, len(dataset_names)]
        postprocess_anc_repl(dataset_name, approaches, date, seeds, conditions, cond_lab, val_gt, exp_name, plot_num, color='C%d' % (i % 10))


'''
plot general results
'''
def plot_general(dataset_names, approaches, date, val_gt):
    """Draw the overall-results figures (split bars, test plots 1 and 4) for each dataset.

    ``dataset_names`` may be a dict, whose keys are the dataset keys, or any iterable of
    dataset keys such as the lists used by plot_exp.
    """
    # Iterating a dict yields its keys, so dicts and lists are handled alike.
    for dataset in dataset_names:
        make_bar_graph_split(dataset, approaches, date, val_gt, show=True)
        make_test_plot1(dataset, approaches, date, val_gt, show=True)
        make_test_plot4(dataset, approaches, date, val_gt, show=True)


######################################################################################
'''
main block
'''
#rate_cond = [1,6,10,13] 
#disp_cond = [5,6,7,8]
#rand_cond = [0,1,2,3,4]

if __name__ == '__main__':
    print(':)')
    date = '0520'
    val_gt = True
    approaches = ['baseline_plain', 'baseline_sln_filt', 'baseline_js_loss', 'baseline_transition', \
                  'baseline_transit_conf', 'baseline_fair_gpl', 'proposed1', 'baseline_plain_clean']
    #approaches = ['baseline_transition', 'proposed1']
    dataset_names = ['synth_feat2', 'MIMIC-ARF-feat2', 'MIMIC-Shock-feat2', 'adult-feat2', 'compas-feat2']
    #dataset_names = ['synth_feat2', 'adult-feat2', 'compas-feat2']
    seeds = np.arange(10)

    # Anchor-set experiments: vary the anchor set's size and its composition.
    cond_lab = 'anchor_props'
    size_cond = np.arange(1, 10, 1)
    # Alternative subset: [0, 1, 3, 5, 7, 9, 10, 11]
    distr_cond = np.arange(0, 10)

    print('varying size of anchor set')
    plot_exp(dataset_names, approaches, date, seeds, size_cond, cond_lab, val_gt, 'size')
    print('varying bias in anchor set')
    plot_exp(dataset_names, approaches, date, seeds, distr_cond, cond_lab, val_gt, 'bias')

    # Minority-group-size experiment, synthetic data only.
    print('varying size of minority group')
    # Alternative subset: [0, 2, 5, 7, 9]
    minprop_cond = np.arange(10)
    cond_lab = 'prop_min'
    dataset_names = ['synth_feat2']
    plot_exp(dataset_names, approaches, date, seeds, minprop_cond, cond_lab, val_gt, 'minprop')

    # Noise experiments, synthetic data only.
    cond_lab = 'noise_rate'
    dataset_names = ['synth_feat2']
    print('varying noise rate')
    rate_cond = np.arange(0, 10, 1)
    plot_exp(dataset_names, approaches, date, seeds, rate_cond, cond_lab, val_gt, 'noise_rate')
    print('varying noise disparity')
    disp_cond = np.arange(0, 10, 1)
    plot_exp(dataset_names, approaches, date, seeds, disp_cond, cond_lab, val_gt, 'noise_disp')

    # Random-noise experiment. Full dataset list:
    # ['synth_random', 'MIMIC-ARF-random', 'MIMIC-Shock-random', 'adult-random', 'compas-random']
    dataset_names = ['synth_random']
    print('random noise experiments')
    rand_cond = np.arange(0, 7, 1)
    plot_exp(dataset_names, approaches, date, seeds, rand_cond, cond_lab, val_gt, 'noise_rate_rand')
    
