import matplotlib.pyplot as plt
import numpy as np
import argparse
from prettytable import PrettyTable


# Command-line interface.  Every option below is read into a module-level
# constant in the configuration section that follows.
parser = argparse.ArgumentParser(description='Make Error Plots for Pruning Experiment.')
parser.add_argument('--model', type=str, default='LeNet5',
                    help='Deep NN model. Possible Options: "LeNet5", "CNN2D", "AlexNet", "CIFAR-VGG"')
parser.add_argument('--dataset', type=str, default='MNIST-3_5',
                    help='Dataset used. "MNIST-3_5", "MNIST-4_9", "MNIST", "Fashion-MNIST", "CIFAR10", "CIFAR100"')
parser.add_argument('--image_size', type=int, default=32,
                    help='Input image size. Default 32')
# Only 'table' and 'plot' are handled by the report code; any other value used
# to be accepted and then silently produced no output at all.
parser.add_argument('--format', type=str, default='plot', choices=['table', 'plot'],
                    help='Format of performance report: either table or plot')
# Only the exact string 'Y' enables the bound report, and it is consulted only
# for the table format on the MNIST-3_5 / MNIST-4_9 datasets.
parser.add_argument('--bound', type=str, default='N',
                    help='Whether to print theoretical upper bound of algorithms. Y/N')
args = parser.parse_args()


# Configuration: unpack the parsed command-line options into module-level
# constants used by the parsing and reporting code below.
MODEL, DATASET, IMAGE_SIZE, FORMAT = (
    args.model,
    args.dataset,
    args.image_size,
    args.format,
)
# The bound report is switched on only by the exact string 'Y'.
BOUND = (args.bound == 'Y')

# Per-method test accuracies keyed by pruning ratio.  Each value is first a
# list of accuracy samples and is later collapsed to [mean, std].
(accs_zonotope_kmeans,
 accs_neural_path_kmeans,
 accs_thinet,
 accs_rand_st,
 accs_l1_st) = ({} for _ in range(5))

# Per-method theoretical bound values ("RHS"), keyed by pruning ratio.
RHS_zonotope_kmeans, RHS_neural_path_kmeans = {}, {}

# Not read by any report below; presumably a sink for results of methods that
# are not reported — TODO confirm.
trash = {}

# Read the experiment log and file every accuracy / bound record under its
# pruning method, keyed by pruning ratio.
#
# Record layouts implied by the field indices used below:
#   accuracy:  "<tag> Test <method> <ratio> <accuracy>"
#   bound:     "<tag> <tag> bound <tag> <method>: <ratio> <value>"
# (inferred from this parser only — TODO confirm against the log writer).
_accuracy_buffers = {
    'zonotope_kmeans': accs_zonotope_kmeans,
    'neural_path_kmeans': accs_neural_path_kmeans,
    'thinet': accs_thinet,
    'random_structured': accs_rand_st,
    'l1_structured': accs_l1_st,
}
_bound_buffers = {
    'zonotope_kmeans:': RHS_zonotope_kmeans,
    'neural_path_kmeans:': RHS_neural_path_kmeans,
}

with open('results/results_{}_{}_{}.txt'.format(MODEL, DATASET, IMAGE_SIZE), 'r') as txt:
    for line in txt:
        fields = line.split()  # split once rather than once per comparison
        if len(fields) < 3:
            continue

        is_test = fields[1] == 'Test'
        is_bound = fields[2] == 'bound'
        if not (is_test or is_bound):
            continue  # neither an accuracy nor a bound record

        # Select the destination from THIS line only.  Unknown methods go to
        # ``trash``; previously they were appended to whichever method an
        # earlier line had selected (or raised NameError on the first line).
        if is_bound:
            buffer = _bound_buffers.get(fields[4], trash)
        else:
            buffer = _accuracy_buffers.get(fields[2], trash)

        if is_test:
            ratio, acc = float(fields[3]), float(fields[4])
        else:
            ratio, acc = float(fields[5]), float(fields[6])

        buffer.setdefault(ratio, []).append(acc)

        

# Collapse every ratio's list of samples into [mean, std], in place.
for buffer in (accs_zonotope_kmeans, accs_neural_path_kmeans, accs_thinet,
               accs_rand_st, accs_l1_st, RHS_zonotope_kmeans, RHS_neural_path_kmeans):
    for ratio, samples in list(buffer.items()):
        buffer[ratio] = [np.mean(samples), np.std(samples)]


def _split_mean_std(aggregated):
    """Return ``(means, stds)`` lists from a dict of ``[mean, std]`` values.

    Both lists follow the dict's insertion order, i.e. the order in which the
    ratios were first seen in the results file.
    """
    pairs = list(aggregated.values())
    return [pair[0] for pair in pairs], [pair[1] for pair in pairs]


mean_accs_zonotope_kmeans, stds_zonotope_kmeans = _split_mean_std(accs_zonotope_kmeans)
mean_accs_neural_path_kmeans, stds_neural_path_kmeans = _split_mean_std(accs_neural_path_kmeans)
mean_accs_thinet, stds_thinet = _split_mean_std(accs_thinet)
mean_accs_rand_st, stds_rand_st = _split_mean_std(accs_rand_st)
mean_accs_l1_st, stds_l1_st = _split_mean_std(accs_l1_st)

# Only the mean is reported for the bounds.
mean_RHS_zonotope_kmeans = _split_mean_std(RHS_zonotope_kmeans)[0]
mean_RHS_neural_path_kmeans = _split_mean_std(RHS_neural_path_kmeans)[0]

# Ratios at which Neural Path K-means was evaluated, in file order.  Stored as
# a real list: a dict_keys view cannot be indexed, and numpy converts it into
# a 0-d object array instead of a sequence of numbers.
RATIOS = list(accs_neural_path_kmeans.keys())
if not mean_accs_neural_path_kmeans:
    # Fail with a readable message instead of an IndexError on the next line.
    raise SystemExit('No neural_path_kmeans results found in results/results_{}_{}_{}.txt'
                     .format(MODEL, DATASET, IMAGE_SIZE))
# Accuracy at the first recorded ratio (presumably the unpruned model at ratio
# 1.0 — TODO confirm).  It determines how many 10-point y ticks are drawn and
# where the y-axis label is placed.
MAX = mean_accs_neural_path_kmeans[0]
upper_line = int(MAX // 10 + 1)

def _format_cell(aggregated, ratio, with_std=True):
    """Format the ``[mean, std]`` entry stored in *aggregated* under *ratio*.

    Values are looked up by ratio rather than by list position.  A method
    evaluated at a different set (or order) of ratios therefore cannot shift
    numbers into the wrong row.  A ratio missing from *aggregated* is shown
    as 'N/A'.
    """
    if ratio not in aggregated:
        return 'N/A'
    mean, std = aggregated[ratio]
    if with_std:
        return '{:.2f}+/-{:.2f}'.format(mean, std)
    return '{:.2f}'.format(mean)


if FORMAT == 'table':
    if DATASET in ['MNIST-3_5', 'MNIST-4_9']:
        if BOUND:
            t = PrettyTable(['Percentage of Remaining Neurons', 'Zonotope K-means Bound', 'Neural Path K-means Bound'])
            # Rows follow the bound records' own ratios.  They come from bound
            # lines, so key lookups stay within one record format.
            for f in RHS_neural_path_kmeans:
                t.add_row(['{:.0f}'.format(100 * f),
                           _format_cell(RHS_zonotope_kmeans, f, with_std=False),
                           _format_cell(RHS_neural_path_kmeans, f, with_std=False)])
        else:
            t = PrettyTable(['Percentage of Remaining Neurons', 'Zonotope K-means', 'Neural Path K-means'])
            for f in RATIOS:
                t.add_row(['{:.1f}'.format(100 * f),
                           _format_cell(accs_zonotope_kmeans, f),
                           _format_cell(accs_neural_path_kmeans, f)])
    else:
        t = PrettyTable(['Percentage of Remaining Neurons', 'Neural Path K-means'])
        for f in RATIOS:
            t.add_row(['{:.1f}'.format(100 * f), _format_cell(accs_neural_path_kmeans, f)])

    print(t)

elif FORMAT == 'plot':
    import os
    from matplotlib import font_manager

    # Register any font files found under fonts/ (for the Gill Sans family below).
    font_dirs = ['fonts/']
    font_files = font_manager.findSystemFonts(fontpaths=font_dirs)
    for font_file in font_files:
        font_manager.fontManager.addfont(font_file)

    # Plot Experimental Results
    plt.style.use('ggplot')

    plt.rcParams['font.family'] = 'gillsans'
    plt.rcParams['xtick.color'] = 'black'
    plt.rcParams['ytick.color'] = 'black'

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.tick_params(length=0, pad=10)

    # One entry per curve: (results dict, means, stds, marker, label, colour).
    # Each curve is drawn against its OWN ratios, converted to a list (a
    # dict_keys view is not a usable x sequence for numpy/matplotlib), so
    # methods evaluated at different ratios are still plotted correctly.
    # Methods with no results in the file (e.g. ThiNet) are skipped.
    series = [
        (accs_neural_path_kmeans, mean_accs_neural_path_kmeans, stds_neural_path_kmeans, 'p', 'Neural Path \nK-means', 'C1'),
        (accs_rand_st, mean_accs_rand_st, stds_rand_st, 'o', 'Random Structured', 'C3'),
        (accs_l1_st, mean_accs_l1_st, stds_l1_st, '^', '$L1$ Structured', 'C4'),
        (accs_thinet, mean_accs_thinet, stds_thinet, 'v', 'ThiNet', 'C7'),
    ]
    for results, means, stds, marker, label, color in series:
        if means:
            plt.errorbar(list(results), means, yerr=stds, marker=marker,
                         label=label, color=color, linewidth=2)

    plt.grid(axis='x', linewidth=0)
    plt.xlim(1, 0)  # reversed axis: the ratio decreases from left to right
    plt.xticks([1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0], fontsize=16)
    plt.yticks(10 * np.arange(upper_line + 1), fontsize=16)
    ax.text(1.07, 10 * (upper_line) + 5, 'Accuracy (%)', rotation='horizontal', color='black', fontsize=16)
    plt.xlabel('Ratio of Remaining Neurons', color='black', fontsize=16)
    plt.legend(loc='lower left', frameon=False, fontsize=18)
    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/{}_trained_at_{}_imsize_{}.png'.format(MODEL, DATASET, IMAGE_SIZE))

else:
    # Unknown formats used to fall through silently with no output.
    parser.error('unknown --format {!r}; expected "table" or "plot"'.format(FORMAT))

