import os
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import argparse

# Command-line interface: selects which experiment directory is read
# (via --resnet / --batch), which result files are expected (--epochs,
# --runs), and how strongly the per-step training curves are smoothed (--decay).
parser = argparse.ArgumentParser(description="plots the results")
parser.add_argument(
    "-d", "--decay",
    type=float,
    default=.99,
    help="Decay value for running averages",
)
parser.add_argument(
    "-b", "--batch",
    type=int,
    default=100,
    help="Batch size",
)
parser.add_argument(
    "-e", "--epochs",
    type=int,
    default=5,
    help="Number of epochs",
)
parser.add_argument(
    "-r", "--runs",
    type=int,
    default=5,
    help="Number of runs to average over",
)
parser.add_argument(
    "-n", "--resnet",
    type=int,
    default=18,
    help="which resnet?",
)

args = parser.parse_args()

# Expose the parsed options under the short module-level names that the rest
# of the script refers to.
batch, epochs, runs = args.batch, args.epochs, args.runs
decay, resnet = args.decay, args.resnet

# Presumably the number of training samples per epoch -- TODO confirm against
# the training script. Kept as a float, so the step counts below are floats.
train_size = 5e4

# Step counts derived from the sample count: train_size/batch per epoch, and
# that figure times the number of epochs for the whole run.
epoch_step = train_size/batch
total_steps = epoch_step*epochs

# Directory the result files are read from and the figures are written to,
# plus the common prefix of every figure title and output file name.
dir_name = f'exp_results/rn{resnet}_batch{batch}/'
title = f'ResNet{resnet}, Batch Size = {batch}, '

# data[name][metric] will hold one array per loaded run.
data = {}

# Metric keys used both as dictionary keys and as title / file name suffixes.
metrics = [
    'Test Accuracy',
    'Training Accuracy',
    'Test Loss',
    'Training Loss',
]

# Optimizer identifiers as they appear in the result file names.
names = [
    'NAG, m=.99',
    'ADAM',
    'SGD, m=.99',
    'AGNES, eta=.01 fixed',
]

# Legend text shown for each optimizer identifier.
labels = {
    'NAG, m=.99': 'NAG',
    'SGD, m=.99': 'SGD+momentum',
    'AGNES, eta=.01 fixed': 'AGNES',
    'ADAM': 'ADAM',
}

max_acc = {}


def _running_average(values, decay):
    """Return the exponentially smoothed sequence for *values*.

    The result is seeded with ``values[0]`` and then folds in every element
    of *values* (the first one included), so it has ``len(values) + 1``
    entries. *values* must be non-empty; an empty sequence raises
    ``IndexError``.
    """
    smoothed = [values[0]]
    for value in values:
        smoothed.append(decay*smoothed[-1] + (1-decay)*value)
    return smoothed


# Load every available run of every optimizer into data[name][metric].
for name in names:
    data[name] = {metric: [] for metric in metrics}

    for i in range(runs):
        filename = f'{name}_r{i}_{epochs}.pth'
        # Only opening the file is expected to fail; keep the try body minimal
        # so that errors in the processing below are not attributed to it.
        try:
            with open(os.path.join(dir_name, filename), 'rb') as file:
                # NOTE(review): torch.load unpickles the file contents; only
                # point this script at result files you produced yourself.
                temp = torch.load(file, map_location=torch.device('cpu'))
        except FileNotFoundError:
            # Missing runs are tolerated: report and average over the rest.
            print(filename, "does not exist.")
            continue

        run_data = data[name]
        # Test metrics are stored as recorded, without smoothing.
        run_data['Test Loss'].append(np.array(temp['test_losses']))
        run_data['Test Accuracy'].append(np.array(temp['test_accuracies']))
        # Training metrics are smoothed with a running average before plotting.
        run_data['Training Loss'].append(
            np.array(_running_average(temp['train_losses'], decay)))
        run_data['Training Accuracy'].append(
            np.array(_running_average(temp['train_accuracies'], decay)))

    if not data[name]['Test Accuracy']:
        # Make the cause visible up front: with no runs at all there is
        # nothing to average for this optimizer.
        print(f"No runs could be loaded for '{name}' from {dir_name}")



# Test accuracy: mean over runs with a +/- one standard deviation band.
metric='Test Accuracy'
fig = plt.figure()
# x positions in units of training steps, spaced one epoch apart.
# Assumes the test metrics hold one entry per grid point -- TODO confirm
# against the training script.
eval_steps = np.arange(0,total_steps+1,epoch_step)
for name in names:
    runs_for_metric = data[name][metric]
    if len(runs_for_metric) == 0:
        # Missing result files are tolerated when loading, so an optimizer may
        # have no runs at all; averaging an empty list would yield NaN and
        # make the plot call fail, so skip it instead.
        print(f"Skipping '{name}' in {metric} plot: no runs loaded.")
        continue
    mean = np.mean(runs_for_metric, axis=0)
    std = np.std(runs_for_metric, axis=0)

    plt.plot(eval_steps, mean, label=labels[name])
    plt.fill_between(eval_steps, mean+std, mean-std, alpha=0.2)

plt.title(title+metric)
plt.legend()
# Restrict the y-axis to the range where the curves are compared.
plt.ylim([.8,.95])
plt.savefig(os.path.join(dir_name,title+metric))
# Release the figure once it is on disk.
plt.close(fig)

# Smoothed training accuracy: mean over runs with a +/- one standard
# deviation band.
metric='Training Accuracy'
fig = plt.figure()
# One x position per entry of the smoothed series (0, 1, ..., total_steps);
# used for both the line and the band so they always share the same grid.
train_steps = np.arange(0,total_steps+1)
for name in names:
    runs_for_metric = data[name][metric]
    if len(runs_for_metric) == 0:
        # Missing result files are tolerated when loading, so an optimizer may
        # have no runs at all; averaging an empty list would yield NaN and
        # make the band plot fail, so skip it instead.
        print(f"Skipping '{name}' in {metric} plot: no runs loaded.")
        continue
    mean = np.mean(runs_for_metric, axis=0)
    std = np.std(runs_for_metric, axis=0)

    plt.plot(train_steps, mean, label=labels[name])
    plt.fill_between(train_steps, mean+std, mean-std, alpha=0.2)

plt.title(title+metric)
plt.legend()
plt.ylim([.7,1])
plt.savefig(os.path.join(dir_name,title+metric))
# Release the figure once it is on disk.
plt.close(fig)


# Test loss on a logarithmic y-axis: mean over runs with a +/- one standard
# deviation band.
metric='Test Loss'
fig = plt.figure()
# x positions in units of training steps, spaced one epoch apart.
# Assumes the test metrics hold one entry per grid point -- TODO confirm
# against the training script.
eval_steps = np.arange(0,total_steps+1,epoch_step)
for name in names:
    runs_for_metric = data[name][metric]
    if len(runs_for_metric) == 0:
        # Missing result files are tolerated when loading, so an optimizer may
        # have no runs at all; averaging an empty list would yield NaN and
        # make the plot call fail, so skip it instead.
        print(f"Skipping '{name}' in {metric} plot: no runs loaded.")
        continue
    mean = np.mean(runs_for_metric, axis=0)
    std = np.std(runs_for_metric, axis=0)

    plt.semilogy(eval_steps, mean, label=labels[name])
    plt.fill_between(eval_steps, mean+std, mean-std, alpha=0.2)

plt.title(title+metric)
plt.legend()
plt.savefig(os.path.join(dir_name,title+metric))
# Release the figure once it is on disk.
plt.close(fig)

# Smoothed training loss on a logarithmic y-axis: mean over runs with a
# +/- one standard deviation band.
metric='Training Loss'
fig = plt.figure()
# One x position per entry of the smoothed series (0, 1, ..., total_steps);
# used for both the line and the band so they always share the same grid.
train_steps = np.arange(0,total_steps+1)
for name in names:
    runs_for_metric = data[name][metric]
    if len(runs_for_metric) == 0:
        # Missing result files are tolerated when loading, so an optimizer may
        # have no runs at all; averaging an empty list would yield NaN and
        # make the band plot fail, so skip it instead.
        print(f"Skipping '{name}' in {metric} plot: no runs loaded.")
        continue
    mean = np.mean(runs_for_metric, axis=0)
    std = np.std(runs_for_metric, axis=0)

    plt.semilogy(train_steps, mean, label=labels[name])
    plt.fill_between(train_steps, mean+std, mean-std, alpha=0.2)

plt.title(title+metric)
plt.legend()
plt.savefig(os.path.join(dir_name,title+metric))
# Release the figure once it is on disk.
plt.close(fig)
