import matplotlib.pyplot as plt
import numpy as np
import h5py
import torch

import yaml
import argparse

from error import eval_spread_error

'''
TODO:
    - [ ] implement printing either a specific corruption or all of them, and output...
'''

# Load the shared experiment configuration. The path is relative to the
# current working directory, so the script must be launched from a directory
# where ../config.yaml exists.
# safe_load: the values read from this config are plain mappings/strings, and
# unlike FullLoader, safe_load never constructs arbitrary Python objects.
with open('../config.yaml', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

# The method and corruption names are hard-coded on purpose: discovering them
# at runtime would introduce a dependency. Both lists feed the argparse
# `choices` below.
methods = [
    'vanilla',
    'dropout',
    'dropout_nofirst',
    'ensemble',
    'll_dropout',
    'll_svi',
    'svi',
    'temp_scaling',
]

corruptions = [
    'gaussian_blur',
    'brightness',
    'defocus_blur',
    'gaussian_noise',
    'impulse_noise',
    'shot_noise',
    'fog',
    'pixelate',
    'elastic_transform',
    'speckle_noise',
    'glass_blur',
    'spatter',
    'saturate',
    'contrast',
    'frost',
    'zoom_blur',
]

# Command-line interface.
# The first positional argument of ArgumentParser is `prog`, so the text must
# be passed as `description` or it replaces the program name in usage lines.
parser = argparse.ArgumentParser(
    description='Plot a histogram of the model statistic data.')

# NOTE(review): --data-dir is parsed but main() reads the data path from
# config.yaml instead; confirm whether it should override the config.
parser.add_argument('--data-dir', type=str, metavar='DIR', default=None,
                    help='Directory where dataset is saved')
# Required: main() indexes the config by this name, so a missing value would
# otherwise surface as a KeyError on None.
parser.add_argument('--dataset', type=str, required=True,
                    choices=['imagenet', 'cifar10'],
                    help='Dataset whose predictions are plotted')
parser.add_argument('--method', type=str, default='vanilla',
                    choices=methods, help='Method used for the model')
# metavar keeps the 99 allowed values out of the usage/help text.
parser.add_argument('--nbins', type=int, default=30, metavar='N',
                    choices=range(1, 100),
                    help='Number of bins to use, 1-99. (Default: 30)')
# The default is the full key form while the choices are short names;
# argparse does not validate defaults against `choices`.
parser.add_argument('--corruption', type=str,
                    default='corrupt-static-gaussian_blur',
                    choices=corruptions + ['all', 'none'],
                    help='Corruption to use for calibration. (Default: gaussian_blur)')

def compute(data, dataset, subset, intensity):
    '''
        Return the max probability and top-1 correctness for one data split.

        data      : mapping (an h5py group in main) from split name to a
                    group holding 'labels' and 'probs' arrays.
        dataset   : for 'cifar10' the arrays carry an extra leading model
                    axis and only model 0 is used; otherwise the arrays are
                    read whole.
        subset    : split name, e.g. 'test' or a corruption key.
        intensity : corruption level; when > 0 the key read is
                    '<subset>-<intensity>', otherwise `subset` itself.

        Returns (pmax, top1): the largest probability along the last axis,
        and a boolean array that is True where the arg-max class equals the
        label.
    '''
    key = subset if intensity <= 0 else "%s-%d" % (subset, intensity)
    split = data[key]

    if dataset != 'cifar10':
        labels = split['labels'][:]
        probs = split['probs'][:]
    else:
        # The cifar10 file stores predictions from 5 independent models
        # (per the original note) stacked on the first axis; use model 0.
        model = 0
        labels = split['labels'][model, :]
        probs = split['probs'][model, :]

    predicted = np.argmax(probs, axis=-1)
    return (np.max(probs, axis=-1), predicted == labels)

def main():
    '''
        Plot per-bin bars for one method on the test split and one corruption.

        Six panels share a y axis: panel 0 is the clean 'test' split and
        panels 1-5 are intensities 1-5 of the chosen corruption. Each panel
        overlays the `pt` and `pe` arrays returned by `eval_spread_error`,
        drawn as bars that start at that panel's top-1 accuracy.
    '''
    args = parser.parse_args()

    # Fail with a usage message instead of a KeyError on config['data'][None].
    if args.dataset is None:
        parser.error('the following arguments are required: --dataset')

    # Aggregate modes are not implemented; fail clearly instead of with a
    # missing-key error from h5py.
    if args.corruption in ('all', 'none'):
        parser.error('--corruption %s is not implemented yet' % args.corruption)

    # The CLI choices are short names (e.g. 'fog') while the default is the
    # full key 'corrupt-static-gaussian_blur'; accept both forms.
    # NOTE(review): assumes HDF5 groups are named
    # 'corrupt-static-<name>-<intensity>', as the default implies — confirm.
    prefix = 'corrupt-static-'
    corruption = args.corruption
    if not corruption.startswith(prefix):
        corruption = prefix + corruption

    path = config['data']['path']
    filename = config['data'][args.dataset]['filename']

    x = np.arange(1, args.nbins + 1)
    # Panel 0: clean test split; panels 1-5: corruption intensities 1-5.
    _, axs = plt.subplots(1, 6, sharey=True)

    # `compute` slices the datasets into in-memory arrays, so the file can be
    # closed once every panel has been drawn.
    with h5py.File('/'.join((path, filename)), 'r') as f:
        runs = f[args.method]
        for i, ax in enumerate(axs):
            subset = 'test' if i == 0 else corruption
            tc_pmax, tc_top1 = compute(runs, args.dataset, subset, intensity=i)

            # Only the per-bin arrays are plotted; the aggregate scores,
            # spreads and calibration errors are not used here.
            _, _, _, _, pt, pe = eval_spread_error(
                f=-np.log(tc_pmax), C=tc_top1, nbins=args.nbins)

            basevalue = tc_top1.mean()

            # Both bar sets start at the panel's top-1 accuracy, so each bar
            # shows how far a bin's value deviates from it.
            ax.bar(x, height=pt - basevalue, bottom=basevalue, width=0.5,
                   color=[0.2, 0.2, 0.5])
            ax.bar(x, height=pe - basevalue, bottom=basevalue, width=0.25,
                   color=[0., 0.7, 0.7])

            # Panel i > 0 shows intensity i (not i + 1): panel 0 is the test split.
            ax.set_title('Test' if i == 0 else 'Intensity %d' % i)
            ax.set_xlabel('Bins')

    # sharey=True, so limiting one axis limits all of them.
    axs[-1].set_ylim((0, 1))
    axs[-1].legend(('\'Train\' test set', '\'Validation\' test set'))

    plt.show()

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not on import.
    main()
