import numpy as np
import matplotlib.pyplot as plt
import torch
import argparse
import os
from timeit import default_timer as timer
import time, datetime
import copy

# To read the dataset built by the google research team
import h5py
import yaml

from bar import compute
from error import eval_spread_error
from error import eval_spread_error_individual
from error import eval_spread_error_multi
from snoek import snoek_scores
from error import determine_edges

# Read the experiment settings once, at import time. main() looks up
# config['data']['path'] and config['data'][<dataset>]['filename'] from it.
# NOTE(review): yaml.FullLoader is fine for a trusted local file; consider
# yaml.safe_load if this config could ever come from an untrusted source.
with open('config.yaml') as stream:
    config = yaml.load(stream, Loader=yaml.FullLoader)


# Command-line interface of the script.
# The text is passed as ``description``: the first positional argument of
# ArgumentParser is ``prog``, which would otherwise garble the usage line.
parser = argparse.ArgumentParser(
    description='Plot a boxplot for statistical data.')
parser.add_argument('--dataset', type=str, choices=['imagenet', 'cifar10'],
                    default='cifar10')
# ``metavar`` keeps --help readable (otherwise all 100 choices are listed).
parser.add_argument('--nbins', type=int, default=30, choices=range(1, 101),
                    metavar='N',
                    help='Number of bins to use, 1-100. (Default: 30)')

# --half / --no-half toggle the same destination; the default is off.
parser.add_argument('--half', action='store_true', dest='half',
                    help="Sample half of data for calibration")
parser.add_argument('--no-half', action='store_false', dest='half',
                    help="Not sample half of data for calibration")
parser.set_defaults(half=False)

# --peak / --no-peak toggle the same destination; the default is off.
parser.add_argument('--peak', action='store_true', dest='peak',
                    help="Only peak at the test set")
parser.add_argument('--no-peak', action='store_false', dest='peak',
                    help="Look at the full test set")
parser.set_defaults(peak=False)


# Display title and plot colour of every method, keyed by the name used to
# index the HDF5 file.
_METHOD_TITLES = {
    'vanilla': 'Vanilla',
    'temp_scaling': 'Temp Scaling',
    'ensemble': 'Ensemble',
    'dropout': 'Dropout',
    'll_dropout': 'LL Dropout',
    'svi': 'SVI',
    'll_svi': 'LL SVI',
}
_METHOD_COLORS = {
    'vanilla': 'gray',
    'temp_scaling': 'brown',
    'ensemble': 'green',
    'dropout': 'steelblue',
    'll_dropout': 'lightsteelblue',
    'svi': 'orange',
    'll_svi': 'wheat',
}
# Order matters: it matches the (snoek, single, multi) result tables.
_CALIBRATION_METHODS = ['Ovadia et al', 'Single Image', 'Multi Image']
# Order matters: it matches the last axis of the result tables.
_MEASURES = ['Accuracy', 'ECE', 'Brier Score', 'NLL']


def _split_indices(n, half):
    """Return ``(cal_ix, test_ix)``, a seeded random split of ``range(n)``.

    The calibration part holds ``n // 2`` indices when ``half`` is set and
    5000 otherwise; the test part holds the rest.
    """
    np.random.seed(42)
    # The permutation is deliberately drawn twice: the second draw defines the
    # split, so dropping the first one would change every partition (and all
    # numbers) produced by earlier runs of this script.
    np.random.permutation(n)
    shuffled = np.random.permutation(n)
    n_cal = n // 2 if half else 5000
    return shuffled[:n_cal], shuffled[n_cal:]


def _output_folder(dataset, nbins, corruption_val, half, peak):
    """Return the output folder name (with trailing '/') for this run."""
    if half:
        return dataset + '_final_half/nbins' + str(nbins) + '/' + corruption_val + '/'
    suffix = 'peak' if peak else ''
    return dataset + '_final/nbins' + str(nbins) + suffix + '/' + corruption_val + '/'


def _calibration_sets(group, dataset, corruption_val, cal_ix, intensities):
    """Build the clean and the augmented calibration sets of one method.

    Returns ``(clean_pmax, clean_top1, f_cal, C_cal, f_cal_aug, C_cal_aug)``.
    The two ``*_aug`` dicts map 0 to the clean calibration set and every
    intensity ``k`` to the clean set followed by the ``corruption_val`` data of
    all intensities ``>= k``.
    """
    clean_pmax, clean_top1 = compute(group, dataset, 'test', intensity=0)
    f_cal = clean_pmax[cal_ix]
    C_cal = clean_top1[cal_ix]

    # Each corrupted level is fetched once and reused for every augmented set.
    split = '%s-%s' % ('corrupt-static', corruption_val)
    pmax_by_level = {}
    top1_by_level = {}
    for level in intensities:
        pmax, top1 = compute(group, dataset, split, intensity=level)
        pmax_by_level[level] = pmax[cal_ix]
        top1_by_level[level] = top1[cal_ix]

    f_cal_aug = {0: f_cal}
    C_cal_aug = {0: C_cal}
    for min_level in intensities:
        kept = [level for level in intensities if level >= min_level]
        f_cal_aug[min_level] = np.concatenate(
            [f_cal] + [pmax_by_level[level] for level in kept])
        C_cal_aug[min_level] = np.concatenate(
            [C_cal] + [top1_by_level[level] for level in kept])
    return clean_pmax, clean_top1, f_cal, C_cal, f_cal_aug, C_cal_aug


def _mean_scores(scores, calc_errs):
    """Return ``(ece, brier, nll)`` as the means of the returned score arrays."""
    return calc_errs[1].mean(), scores[0].mean(), scores[1].mean()


def _plot_calibration(panels, out_path):
    """Draw one calibration bar chart per panel, side by side, and save it.

    Each panel is ``(title, edges, we, pt, pe, ece, brier)``; only bins with
    ``we > 0`` are drawn.
    """
    fig, axes = plt.subplots(1, len(panels), figsize=(10, 5), dpi=500)
    for ax, (title, edges, we, pt, pe, ece, brier) in zip(axes, panels):
        widths = edges[1:] - edges[:-1]
        centres = edges[:-1] + widths / 2
        filled = we > 0
        ax.bar(centres[filled], pt[filled], width=widths[filled], align='center', alpha=0.7)
        ax.bar(centres[filled], pe[filled], width=widths[filled], align='center', alpha=0.5)
        ax.set_xlim((0, 1))
        ax.set_title(title)
        ax.set_xlabel("ECE = {0:1.3f}%\nBrier = {1:1.3f}".format(ece * 100, brier))
        ax.legend(["Calibration", "Test"], loc='upper left')
        ax.set_aspect('equal')
    fig.savefig(out_path, bbox_inches='tight')
    plt.close(fig)


def _score_and_plot(cal_sets, tc_pmax, tc_top1, test_ix, nbins, peak, plot_path):
    """Score one evaluation set with every calibration variant and plot it.

    Returns four ``(accuracy, ece, brier, nll)`` rows, in the order
    plain top-1, Ovadia et al. (snoek), single image, multi image.
    """
    f_cal, C_cal, f_cal_aug, C_cal_aug = cal_sets
    f_test = tc_pmax[test_ix]
    C_test = tc_top1[test_ix]
    # NOTE(review): accuracy is averaged over ALL examples (calibration and
    # test), while the other measures only use the test split - confirm this
    # is intended.
    accuracy = tc_top1.mean()

    # ECE with the Snoek method
    (ece_snoek, brier_snoek, nll_snoek,
     we_snoek, edges_snoek, pt_snoek, pe_snoek) = snoek_scores(
        f_test, C_test, bins=determine_edges(f_test, nbins))

    # Scores with the Top1 method
    scores, _spreads, calc_errs, _we, _edges, _pt, _pe = eval_spread_error(
        f_cal, C_cal, f_test, C_test, nbins)
    row_top1 = (accuracy,) + _mean_scores(scores, calc_errs)

    # Scores with the augmented Top1 method (single image).
    # NOTE(review): the meaning of the trailing 2 is not visible here.
    scores, _spreads, calc_errs, we_single, edges_single, pt_single, pe_single = \
        eval_spread_error_individual(f_cal, C_cal, f_cal_aug, C_cal_aug,
                                     f_test, C_test, nbins, 2)
    ece_single, brier_single, nll_single = _mean_scores(scores, calc_errs)

    # Scores with the augmented Top1 method (multi image)
    scores, _spreads, calc_errs, we_multi, edges_multi, pt_multi, pe_multi = \
        eval_spread_error_multi(f_cal, C_cal, f_cal_aug, C_cal_aug,
                                f_test, C_test, nbins, peak)
    ece_multi, brier_multi, nll_multi = _mean_scores(scores, calc_errs)

    _plot_calibration(
        [(_CALIBRATION_METHODS[0], edges_snoek, we_snoek, pt_snoek, pe_snoek,
          ece_snoek, brier_snoek),
         (_CALIBRATION_METHODS[1], edges_single, we_single, pt_single, pe_single,
          ece_single, brier_single),
         (_CALIBRATION_METHODS[2], edges_multi, we_multi, pt_multi, pe_multi,
          ece_multi, brier_multi)],
        plot_path)

    return (row_top1,
            (accuracy, ece_snoek, brier_snoek, nll_snoek),
            (accuracy, ece_single, brier_single, nll_single),
            (accuracy, ece_multi, brier_multi, nll_multi))


def _plot_boxplots(measure, measure_ix, tables, methods, n_levels, dataset, out_path):
    """Save the box plot of one measure.

    For every intensity level there is one box per method and per table in
    ``tables`` (snoek, single, multi), each summarising the values over the
    corruption axis.
    """
    box_width = .07
    hatches = ['', '///', '...']
    colors = [_METHOD_COLORS[m] for m in methods]
    titles_methods = [_METHOD_TITLES[m] for m in methods]
    offsets = box_width * np.arange(len(tables) * len(methods))

    fig, ax = plt.subplots(figsize=(22, 5))
    for level in range(n_levels):
        positions = 1.7 * level + offsets
        boxes = [ax.boxplot(table[level, :, :, measure_ix].transpose(),
                            widths=box_width,
                            positions=positions[k::len(tables)],
                            patch_artist=True, showfliers=False)
                 for k, table in enumerate(tables)]
        for hatch, box in zip(hatches, boxes):
            if level == 0:
                # Level 0 stores the same row for every corruption, so only
                # the median line is coloured.
                for median, color in zip(box['medians'], colors):
                    median.set_color(color)
                    median.set_linewidth(2)
            else:
                for median in box['medians']:
                    median.set_color("black")
                for patch, color in zip(box['boxes'], colors):
                    patch.set_facecolor(color)
                    patch.set(hatch=hatch)

    # First legend: one colour per method (handles from the last level drawn).
    legend_methods = ax.legend(boxes[0]["boxes"], titles_methods,
                               loc='upper left', fontsize=15)
    # Second legend: one white, hatched box per calibration variant.
    proxies = []
    for box in boxes:
        proxy = copy.copy(box['boxes'][0])
        proxy.set_facecolor('white')
        proxies.append(proxy)
    if dataset == 'imagenet' and measure == 'Brier Score':
        legend_loc = (.15, .71)
    else:
        legend_loc = (.15, .55)
    ax.legend(proxies, _CALIBRATION_METHODS, loc=legend_loc, fontsize=15)
    # ax.legend() replaced the first legend; add it back as a plain artist.
    ax.add_artist(legend_methods)

    ax.set_xticks(len(tables) * box_width * len(methods) / 2 + 1.7 * np.arange(n_levels))
    ax.set_xticklabels(['Test'] + [str(i) for i in range(1, n_levels)], fontsize=15)
    # Only the label size is changed: overwriting the labels with values
    # rounded to one decimal mislabels small-valued axes (e.g. ECE).
    ax.tick_params(axis='y', labelsize=15)
    ax.set_ylabel(measure, fontsize=20)
    ax.set_xlabel('Corruption intensity', fontsize=20)
    ax.set_xlim(-box_width, 10)

    fig.savefig(out_path, bbox_inches='tight')
    plt.close(fig)


def _plot_means(measure, measure_ix, tables, methods, n_levels, out_path):
    """Save one small line plot per method: mean of one measure vs. intensity."""
    fig, axes = plt.subplots(1, len(methods), sharey=True, figsize=(10, 3))
    for method_ix, (ax, m) in enumerate(zip(axes, methods)):
        for table, marker in zip(tables, ('o', '^', 's')):
            ax.plot(table[:, method_ix, :, measure_ix].mean(axis=1),
                    marker=marker, markersize=5)
        ax.set_title(_METHOD_TITLES[m])
        ax.set_xticks(range(n_levels))
    fig.subplots_adjust(wspace=1)
    fig.legend(_CALIBRATION_METHODS, bbox_to_anchor=(1., .65), loc="upper left")
    fig.text(0.5, 0, 'Corruption Intensity', ha='center')
    fig.text(0, 0.5, measure, va='center', rotation='vertical')
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches='tight')
    plt.close(fig)


def main():
    """Score all methods on clean and corrupted data and plot the results.

    Reads the command line through the module-level ``parser`` and the data
    location from the module-level ``config``. For every held-out validation
    corruption it writes, into one output folder:

    * a calibration plot per method, corruption and intensity
      (``calibration_plots/``),
    * the result tables ``ds.npy``, ``ds_snoek.npy``, ``ds_single.npy`` and
      ``ds_multi.npy`` of shape (intensity levels, methods, corruptions, 4),
      the last axis being accuracy, ECE, Brier score and NLL,
    * a box plot (``.png``) and a plot of the means (``_means.pdf``) for each
      of the four measures.
    """
    args = parser.parse_args()
    dataset = args.dataset
    nbins = args.nbins

    methods = ['vanilla', 'temp_scaling', 'ensemble', 'dropout',
               'll_dropout', 'svi', 'll_svi']
    if dataset != 'cifar10':
        # 'svi' is not evaluated for the other dataset.
        methods.remove('svi')

    corruptions = ['contrast',
                   'brightness',
                   'defocus_blur',
                   'elastic_transform',
                   'fog',
                   'frost',
                   'gaussian_blur',
                   'gaussian_noise',
                   'glass_blur',
                   'impulse_noise',
                   'pixelate',
                   'saturate',
                   'shot_noise',
                   'spatter',
                   'speckle_noise',
                   'zoom_blur']
    # Corruptions used to augment the calibration set; each one is excluded
    # from the evaluation of its own run.
    validation_corruptions = ['contrast']

    intensities = range(1, 6)
    n_levels = len(intensities) + 1  # level 0 is the clean test set

    data_file = '/'.join((config['data']['path'],
                          config['data'][args.dataset]['filename']))
    print(data_file)

    # For intensity, corruption and method: accuracy, ECE, Brier score and NLL
    table_shape = (n_levels, len(methods), len(corruptions) - 1, len(_MEASURES))
    ds = np.zeros(table_shape)
    ds_snoek = np.zeros(table_shape)
    ds_single = np.zeros(table_shape)
    ds_multi = np.zeros(table_shape)
    tables = (ds, ds_snoek, ds_single, ds_multi)

    # The seed is fixed inside _split_indices so the data is always
    # partitioned the same way into calibration set and true test set.
    n = 10000 if dataset == 'cifar10' else 49984  # number of examples
    cal_ix, test_ix = _split_indices(n, args.half)

    for corruption_val_idx, corruption_val in enumerate(validation_corruptions):
        save_to_folder_name = _output_folder(dataset, nbins, corruption_val,
                                             args.half, args.peak)
        plots_dir = save_to_folder_name + 'calibration_plots/'
        os.makedirs(plots_dir, exist_ok=True)

        # A list (not a set) keeps the corruption axis of the saved tables in
        # the same order on every run.
        test_corruptions = [c for c in corruptions if c != corruption_val]

        # The calibration sets do not depend on the test intensity, so they
        # are built once per method.
        # NOTE(review): assumes compute() is a deterministic read of stored
        # predictions and that the eval_* helpers do not modify their inputs.
        cal_cache = {}

        with h5py.File(data_file, 'r') as f:
            for intensity in range(n_levels):
                print('Intensity %d...' % intensity)
                start = timer()
                for mi, m in enumerate(methods):
                    print('\tModel %s...' % m)

                    if m not in cal_cache:
                        cal_cache[m] = _calibration_sets(
                            f[m], dataset, corruption_val, cal_ix, intensities)
                    clean_pmax, clean_top1 = cal_cache[m][:2]
                    cal_sets = cal_cache[m][2:]

                    if intensity == 0:
                        # Clean data: scored once, the row is copied to every
                        # corruption slot.
                        rows = _score_and_plot(
                            cal_sets, clean_pmax, clean_top1, test_ix, nbins,
                            args.peak, plots_dir + m + '_clean' + '.pdf')
                        for table, row in zip(tables, rows):
                            table[intensity, mi, :, :] = row
                    else:
                        for corruption_ix, corruption in enumerate(test_corruptions):
                            tc_pmax, tc_top1 = compute(
                                f[m], dataset,
                                '%s-%s' % ('corrupt-static', corruption),
                                intensity=intensity)
                            rows = _score_and_plot(
                                cal_sets, tc_pmax, tc_top1, test_ix, nbins,
                                args.peak,
                                plots_dir + m + '_' + corruption + '_' + str(intensity) + '.pdf')
                            for table, row in zip(tables, rows):
                                table[intensity, mi, corruption_ix, :] = row

                    elapsed = timer() - start
                    print('Finished in {0:1.2f} seconds.'.format(elapsed))
                    # Rough estimate: remaining (validation, intensity, method)
                    # steps times the duration of the step just finished.
                    steps_left = (
                        (len(validation_corruptions) - corruption_val_idx - 1) * n_levels * len(methods)
                        + (n_levels - intensity - 1) * len(methods)
                        + (len(methods) - mi - 1))
                    eta = datetime.datetime.now() + datetime.timedelta(seconds=steps_left * elapsed)
                    expected_completion = '{0:%H:%M %Y-%m-%d}'.format(eta)
                    print("Expected completion: %s." % expected_completion)
                    start = timer()

        np.save(save_to_folder_name + 'ds.npy', ds)
        np.save(save_to_folder_name + 'ds_snoek.npy', ds_snoek)
        np.save(save_to_folder_name + 'ds_single.npy', ds_single)
        np.save(save_to_folder_name + 'ds_multi.npy', ds_multi)
        print("Completed: {0:d}/{1:d}".format(corruption_val_idx + 1,
                                              len(validation_corruptions)))

        # Summary plots: one box plot and one plot of the means per measure.
        plotted_tables = (ds_snoek, ds_single, ds_multi)
        for measure_ix, measure in enumerate(_MEASURES):
            _plot_boxplots(measure, measure_ix, plotted_tables, methods, n_levels,
                           dataset, save_to_folder_name + dataset + measure + '.png')
            _plot_means(measure, measure_ix, plotted_tables, methods, n_levels,
                        save_to_folder_name + dataset + measure + '_means.pdf')


# Run the evaluation only when the file is executed as a script.
if __name__ == "__main__":
    main()
