import numpy as np
import matplotlib.pyplot as plt

# Run directories to compare; each is expected to contain train_accs.txt and
# test_accs.txt. (The 'sgd_small' and 'sgd_large' runs are currently left out.)
dirs = ['diag_small', 'diag_large', 'full_small', 'full_large']
task_labels = ['Task A', 'Task B', 'Task C']

# Read both accuracy files for every run, in run order (train file first).
per_run = [
    (np.genfromtxt(d + '/train_accs.txt'), np.genfromtxt(d + '/test_accs.txt'))
    for d in dirs
]

# Stack per-run arrays; layout is (algo, task track, values).
trains = np.array([train_acc for train_acc, _ in per_run])
tests = np.array([test_acc for _, test_acc in per_run])
    
def _plot_task_accuracies(accs, run_labels, row_labels, title, out_path,
                          ylim=(0.65, 1.0)):
    """Plot one accuracy curve per run on stacked per-task subplots and save.

    Args:
        accs: array indexed as ``accs[run, task]`` -> per-epoch accuracy
            series (assumed (algo, task track, values) layout; confirm
            against the data files).
        run_labels: legend label for each run, in the same order as axis 0.
        row_labels: y-axis label for each task row; its length sets the
            number of stacked subplots.
        title: title drawn above the top subplot.
        out_path: file path the figure is written to.
        ylim: shared (bottom, top) y-axis limits.
    """
    # squeeze=False keeps a 2-D axes array even for a single row, so the
    # indexing below works for any number of tasks.
    fig, axes = plt.subplots(len(row_labels), sharex=True, sharey=True,
                             squeeze=False)
    axes = axes[:, 0]
    for task_idx, (ax, row_label) in enumerate(zip(axes, row_labels)):
        for run_idx, run_label in enumerate(run_labels):
            ax.plot(accs[run_idx, task_idx], label=run_label)
        # Per-axis settings only need to be applied once, not once per run.
        ax.set_ylabel(row_label)
        ax.set_ylim(ylim)
    axes[0].set_title(title)
    # The x label belongs under the bottom subplot: it is the only one whose
    # tick labels stay visible, and with hspace=0 a label on an upper axis
    # would overlap the subplot beneath it.
    axes[-1].set_xlabel("Epochs")
    axes[-1].legend(loc='upper left', bbox_to_anchor=(0.0, 0.9, 0.5, 0.5))
    # Fine-tune figure; make subplots touch and hide x tick labels for all
    # but the bottom plot.
    fig.subplots_adjust(hspace=0)
    plt.setp([ax.get_xticklabels() for ax in axes[:-1]], visible=False)
    fig.savefig(out_path)
    plt.close(fig)


_plot_task_accuracies(trains, dirs, task_labels,
                      'EWC - Training Accuracies', "train.png")
_plot_task_accuracies(tests, dirs, task_labels,
                      'EWC - Test Accuracy', "test.png")
