import matplotlib.pylab as plt
import numpy as np
import pandas as pd

# Directory containing the experiment result CSVs. File names are appended
# directly with "+", so the trailing slash is required.
result_root = ('~/result/'
               'over-parameterized-oed/')


def plot_mean_acc_with_std(df, n_std=1, label=None):
    """Plot the mean accuracy curve with a symmetric +/- ``n_std`` std band.

    Args:
        df: table whose index supplies the x values and which has the
            columns ``mean_max_acc`` and ``std_max_acc``.
        n_std: half-width of the shaded band in standard deviations; the
            band extends equally far above and below the mean.
        label: optional legend label for the mean curve. The shaded band is
            always excluded from the legend.
    """
    mean = df['mean_max_acc']
    std = df['std_max_acc']
    (line,) = plt.plot(df.index, mean, label=label)
    # Draw the band in the line's own colour so the curve and its spread
    # always read as one series; '_nolegend_' keeps it out of auto legends.
    plt.fill_between(df.index, mean - n_std * std, mean + n_std * std,
                     color=line.get_color(), alpha=0.5, label='_nolegend_')


def plot_file(fname):
    """Load one result CSV and draw its mean-accuracy curve with std band.

    The file's ``mean_design_size`` column becomes the index (the x axis).
    """
    results = pd.read_csv(fname, index_col='mean_design_size')
    plot_mean_acc_with_std(results)


# Curves to draw, in plotting and legend order:
# (result CSV name under result_root, legend label).
# width_factor 1 is labelled LeNet5 and width_factor 8 Wide-LeNet5.
curves = [
    ('{width_factor: 1, design: {sigma: 0, lambda: 0, theta: 1}, kernel: ntk, loss: square}.csv',
     'Our algorithm - LeNet5'),
    ('{width_factor: 8, design: {sigma: 0, lambda: 0, theta: 1}, kernel: ntk, loss: square}.csv',
     'Our algorithm - Wide-LeNet5'),
    ('{width_factor: 1, design: uncertainty_sampling, loss: square}.csv',
     'Uncertainty sampling - LeNet5'),
    ('{width_factor: 8, design: uncertainty_sampling, loss: square}.csv',
     'Uncertainty sampling - Wide-LeNet5'),
    ('{width_factor: 1, design: coreset, kernel: ntk, loss: square}.csv',
     'Coreset (k-centers) - LeNet5'),
    ('{width_factor: 8, design: coreset, kernel: ntk, loss: square}.csv',
     'Coreset (k-centers) - Wide-LeNet5'),
    ('{width_factor: 1, design: random, loss: square}.csv',
     'Random design - LeNet5'),
    ('{width_factor: 8, design: random, loss: square}.csv',
     'Random design - Wide-LeNet5'),
]

for csv_name, _ in curves:
    plot_file(result_root + csv_name)

# Give the legend the mean-accuracy lines explicitly. A labels-only
# plt.legend([...]) pairs labels with axes artists in creation order; in
# recent Matplotlib that sequence also contains each fill_between band, so
# the labels would drift onto the shaded regions.
plt.legend(plt.gca().get_lines(), [label for _, label in curves])

plt.xlabel('training set size')
plt.ylabel('accuracy')

# linspace fixes the tick count (0.78, 0.80, ..., 0.96); a float-step
# arange can gain or lose its final element to rounding.
plt.yticks(np.linspace(0.78, 0.96, 10))

# To save the figure, call plt.savefig('result.pdf') here, before
# plt.show(); saving after the window is closed writes an empty figure.
plt.show()
