import numpy as np
import logging
import matplotlib.pyplot as plt
import json
import os
import argparse


# Plot mean +/- std training-loss and test-accuracy curves for the
# "vary batch size" microbatch online-learning experiments.
#
# For every loss type one loss figure and one accuracy figure are produced,
# each with one subplot per noise rate and one curve per value of T
# (2**T microbatches). Results are read from
#   <dataset>/vary_b_T<T>/vary_b_<dataset>_<T>_n<noise_rate>_<loss_type>_<suffix>.json
# and figures are written to
#   <dataset>/vary_b_<dataset>_<loss_type>_trainLosses.jpg
#   <dataset>/vary_b_<dataset>_<loss_type>_accTest.jpg

logger = logging.getLogger()

# File-name suffixes of the four JSON result arrays stored for every run.
RESULT_SUFFIXES = ('trainLossesMean', 'trainLossesStd', 'accTestMean', 'accTestStd')


def build_parser():
    """Return the command-line parser for this plotting script."""
    parser = argparse.ArgumentParser(description='Plotting microbatch online learning results')
    parser.add_argument('--loss_types', nargs='+', type=str, default=['mse'], help='List of loss functions used for training')
    # Noise rates stay strings (defaults included) so they are formatted into
    # file names exactly as given; '0'/'0.2'/... match str() of the old
    # numeric defaults, so existing result files are still found.
    parser.add_argument('--noise_rates', nargs='+', type=str, default=['0', '0.2', '0.4', '0.6'], help='List of symmetric noise rates')
    # Required: without it there is nothing to plot, and a missing value used
    # to crash later with "'NoneType' object is not iterable".
    parser.add_argument('--T', nargs='+', type=int, required=True, help='2**T microbatches')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='Experimenting dataset')
    parser.add_argument('--n_samples', type=int, default=25000,
                        help='Number of training samples split into 2**T microbatches (only used for legend labels)')
    return parser


def load_json_array(path):
    """Load a JSON-encoded list from ``path`` and return it as a numpy array."""
    with open(path, "r") as fp:
        return np.array(json.load(fp))


def load_run(dataset, T, noise_rate, loss_type):
    """Load the four result arrays of one run.

    Returns a dict keyed by the entries of ``RESULT_SUFFIXES``.
    Raises FileNotFoundError if a result file is missing. Nothing is created
    on disk: the run directory is only read from.
    """
    run_dir = os.path.join(dataset, f'vary_b_T{T}')
    name = f'vary_b_{dataset}_{T}_n{noise_rate}_{loss_type}'
    return {
        suffix: load_json_array(os.path.join(run_dir, f'{name}_{suffix}.json'))
        for suffix in RESULT_SUFFIXES
    }


def batch_label(T, n_samples=25000):
    """Legend label giving the per-microbatch sample count for 2**T microbatches."""
    return f'{int(n_samples / 2**T)} samples/batch'


def make_figure(n_cols):
    """Create a 1 x ``n_cols`` figure with a shared y axis.

    Returns ``(fig, axes)``. ``axes`` is always a 1-D array of Axes;
    ``squeeze=False`` keeps indexing valid even when ``n_cols == 1``.
    """
    fig, axes = plt.subplots(1, n_cols, sharey=True, figsize=(22, 6.5), dpi=720, squeeze=False)
    fig.supxlabel('Epoch')
    return fig, axes[0]


def plot_mean_std(ax, mean, std, label):
    """Plot ``mean`` as a line on ``ax`` with a shaded +/- ``std`` band."""
    ax.plot(mean, label=label)
    ax.fill_between(range(len(mean)), mean - std, mean + std, alpha=0.2)


def finalize_and_save(fig, axes, ylabel, path):
    """Label, square up and save ``fig`` to ``path``, then release it.

    Only the first subplot gets the y label and legend; the other subplots
    share its y axis.
    """
    first = axes[0]
    first.set_ylabel(ylabel)
    first.set_aspect(1. / first.get_data_ratio(), share=True)
    first.legend()
    fig.tight_layout()
    fig.savefig(path)
    # Close the figure: at dpi=720 each open figure holds a lot of memory.
    plt.close(fig)


def main(argv=None):
    """Parse ``argv`` (default: ``sys.argv``) and write the loss/accuracy figures."""
    args = build_parser().parse_args(argv)
    # Without this the root logger's default WARNING level hides the
    # progress messages below.
    logging.basicConfig(level=logging.INFO)

    dataset = args.dataset
    noise_rate_list = args.noise_rates
    plt.rcParams.update({'font.size': 18})

    for loss_type in args.loss_types:
        # Fresh figures per loss type, so curves from one loss type do not
        # leak into the figures saved for the next one.
        loss_fig, loss_axes = make_figure(len(noise_rate_list))
        acc_fig, acc_axes = make_figure(len(noise_rate_list))

        for i, noise_rate in enumerate(noise_rate_list):
            for T in args.T:
                logger.info('loss_type: %s \t noise_rate: %s \t T: %s', loss_type, noise_rate, T)
                run = load_run(dataset, T, noise_rate, loss_type)
                label = batch_label(T, args.n_samples)
                plot_mean_std(loss_axes[i], run['trainLossesMean'], run['trainLossesStd'], label)
                plot_mean_std(acc_axes[i], run['accTestMean'], run['accTestStd'], label)
            title = f'noise_rate = {noise_rate}'
            loss_axes[i].set_title(title)
            acc_axes[i].set_title(title)

        # Output goes next to the per-T result directories; ``dataset`` must
        # already exist because the results were just read from inside it.
        name = f'vary_b_{dataset}_{loss_type}'
        finalize_and_save(loss_fig, loss_axes, 'Loss', os.path.join(dataset, f'{name}_trainLosses.jpg'))
        finalize_and_save(acc_fig, acc_axes, 'Accuracy', os.path.join(dataset, f'{name}_accTest.jpg'))


if __name__ == '__main__':
    main()