import argparse
import os
import json

import numpy as np
import matplotlib.pyplot as plt

def plot_1stdv(x, y_models, ylabel, xlabel, store_dir):
    """Plot each model's across-run mean curve with a +/- 1 std band.

    Args:
        x: Shared x-axis values, one per step of each curve.
        y_models: Mapping from legend label to an array with runs stacked
            along axis 0. NaNs are ignored by the mean and std.
        ylabel: Y-axis label. It is also used as the output file name
            inside ``store_dir``. No extension is given, so matplotlib's
            default savefig format applies.
        xlabel: X-axis label.
        store_dir: Directory the figure is written to.
    """
    plt.figure(figsize=(10, 8))
    for m, curves in y_models.items():
        mean = np.nanmean(curves, axis=0)
        sigma = np.nanstd(curves, axis=0)
        line, = plt.plot(x, mean, label=m)
        # Give the band its line's colour explicitly. Otherwise the pairing
        # depends on the line and fill property cycles staying in lockstep.
        plt.fill_between(x, mean + sigma, mean - sigma,
                         color=line.get_color(), alpha=0.5)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if y_models:
        # Only add a legend when at least one labelled line exists.
        plt.legend(loc="upper left")
    plt.savefig(os.path.join(store_dir, ylabel))
    plt.close()

def plot_mean(x, y_models, ylabel, xlabel, store_dir):
    """Draw one pre-averaged curve per model and save the figure.

    Args:
        x: Shared x-axis values.
        y_models: Mapping from legend label to a 1-D curve aligned with ``x``.
        ylabel: Y-axis label, also used as the file name inside ``store_dir``.
        xlabel: X-axis label.
        store_dir: Directory the figure is written to.
    """
    out_path = os.path.join(store_dir, ylabel)
    plt.figure(figsize=(10, 8))
    for label, curve in y_models.items():
        plt.plot(x, curve, label=label)
    if y_models:
        # Decorate once after all lines exist. An empty mapping leaves a bare figure.
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.legend(loc="upper left")
    plt.savefig(out_path)
    plt.close()

def plot_1stdv_multi(x, metrics, ylabel, title, store_dir, clip_range=(-200, 200)):
    """Plot several metrics side by side, each as mean +/- 1 std per model.

    Args:
        x: Shared x-axis values.
        metrics: Mapping from panel title to a ``{model label: curves}``
            mapping, where ``curves`` has runs stacked along axis 0.
        ylabel: Output file name inside ``store_dir``. Despite its name, it
            is not drawn as an axis label.
        title: Figure super-title.
        store_dir: Directory the figure is written to.
        clip_range: ``(low, high)`` bounds that values are clipped to before
            averaging, which limits the pull of extreme values. ``None``
            disables clipping.
    """
    tick_size = 12
    title_size = 18
    # One panel per metric. Never fewer than two, which keeps the original
    # 1x2 layout; more metrics no longer overflow a fixed pair of axes.
    ncols = max(len(metrics), 2)
    fig, axes = plt.subplots(1, ncols, figsize=(4 * ncols, 4), squeeze=False)
    axes = axes[0]
    for ax, (metric, models) in zip(axes, metrics.items()):
        for m, curves in models.items():
            if clip_range is not None:
                curves = np.clip(curves, clip_range[0], clip_range[1])
            mean = np.nanmean(curves, axis=0)
            sigma = np.nanstd(curves, axis=0)
            line, = ax.plot(x, mean, label=m)
            # Tie the band colour to its line explicitly.
            ax.fill_between(x, mean + sigma, mean - sigma,
                            color=line.get_color(), alpha=0.5)
        ax.set_title(metric, fontdict={'fontsize': title_size, 'style': 'italic'})
        ax.tick_params(axis='x', labelsize=tick_size)
        ax.tick_params(axis='y', labelsize=tick_size)
    # The lines carry model labels. Show them once, on the first panel, like
    # the single-panel plots do.
    handles, _ = axes[0].get_legend_handles_labels()
    if handles:
        axes[0].legend(loc="upper left")
    fig.suptitle(title, fontsize=22)
    fig.tight_layout()
    fig.savefig(os.path.join(store_dir, ylabel))
    plt.close(fig)

def load_saved_args(args_path):
    """Load a run's saved command-line arguments into an ``argparse.Namespace``.

    ``args_path`` must contain a JSON object, whose keys become attributes.
    The namespace is built directly rather than via
    ``ArgumentParser().parse_args()``. That call would re-parse *this*
    script's ``sys.argv`` and exit with "unrecognized arguments" whenever a
    flag such as ``--env`` was passed.
    """
    with open(args_path, 'r') as f:
        return argparse.Namespace(**json.load(f))


def pad_to_length(values, length):
    """Return the first ``length`` entries of ``values`` as floats, NaN-padded.

    Padding lets a run that stopped early be stacked with complete runs.
    ``np.nanmean``/``np.nanstd`` skip the NaN entries.
    """
    values = np.asarray(values, dtype=float)[:length]
    padded = np.full((length,) + values.shape[1:], np.nan)
    padded[:len(values)] = values
    return padded


def collect_runs(run_dirs, crit, models_to_graph, excluded_seeds, cutoff):
    """Gather RMSE and log-likelihood curves of all runs for one criterion.

    A directory is used when all of the following hold:
    - ``crit`` occurs in its basename;
    - ``models_to_graph`` occurs in its path;
    - its last ``_``-separated token is not in ``excluded_seeds``.

    Each used run must contain ``commandline_args.txt``,
    ``results/rmse_array.npy`` and ``results/test_loss_array.npy``. The
    log-likelihood is the negated test loss.

    Returns:
        ``(rmses, likelihoods)``, each with runs stacked on axis 0 and
        truncated or NaN-padded to ``cutoff`` steps. Returns ``None`` when
        no directory matches.
    """
    rmses = []
    likelihoods = []
    for d in run_dirs:
        if crit not in os.path.basename(os.path.normpath(d)):
            continue
        if models_to_graph not in d or d.split('_')[-1] in excluded_seeds:
            continue
        # Raises if the run has no saved arguments. The values are not used
        # by the plots yet.
        load_saved_args(os.path.join(d, 'commandline_args.txt'))
        results_dir = os.path.join(d, 'results')
        rmse = np.load(os.path.join(results_dir, 'rmse_array.npy'))
        test_loss = np.load(os.path.join(results_dir, 'test_loss_array.npy'))
        print(d)
        print(rmse.shape)
        rmses.append(pad_to_length(rmse, cutoff))
        likelihoods.append(pad_to_length(-test_loss, cutoff))
    if not rmses:
        return None
    return np.stack(rmses), np.stack(likelihoods)


def pct_change_from_start(runs, invert=False):
    """Return the percent change of the run-averaged curve relative to step 0.

    ``runs`` has runs stacked on axis 0. NaNs are ignored when averaging, the
    same way the mean/std plots do. With ``invert=True`` the sign is flipped,
    so that a decrease (e.g. in RMSE) plots as a positive gain.
    """
    mean = np.nanmean(runs, axis=0)
    change = (mean - mean[0]) / np.abs(mean[0]) * 100
    return -change if invert else change


def main():
    """Plot learning curves of every acquisition criterion for one environment.

    Reads run directories from ``<base_dir>/<env>`` and saves all figures
    into that environment directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', default="Ant-v2_test_aquisition", type=str,
                        help='Environment [bimodal, hetero, WetChicken-v0, Pendulum-v0, Hopper-v2]')
    parser.add_argument('--base_dir', type=str,
                        default='/home/nwaftp23/scratch/uncertainty_estimation/mujoco',
                        help='Directory containing one sub-directory per environment')
    args = parser.parse_args()
    env_dir = os.path.join(args.base_dir, args.env)
    # Sorted so runs are loaded, printed and stacked in a reproducible order.
    run_dirs = [os.path.join(env_dir, d) for d in sorted(os.listdir(env_dir))
                if os.path.isdir(os.path.join(env_dir, d))]

    # When True, seeds 5-9 are excluded in addition to seed585.
    skip_seeds = False
    models_to_graph = 'nflows_ensemble_fixedmasks'
    title_multi = 'Nflows Ensemble'
    model_name = 'nflows_ensemble'
    #models_to_graph = 'nn_ensemble_fixedmasks'
    #title_multi = 'NN Ensemble'
    #model_name = 'pens'
    acq_crit = ['bhatt_exp', 'kl_exp', 'random', 'sample_bald', 'batchbald',
                'bmdal_badge', 'bmdal_bait']
    # Alternative criterion sets:
    #acq_crit = ['bhatt_exp', 'kl_exp', 'random', 'sample_bald', 'batchbald', 'bait']
    #acq_crit = ['bhatt_exp', 'kl_exp', 'random', 'badge']
    #acq_crit = ['bhatt_exp', 'kl_exp', 'mutual_info', 'random']
    #acq_crit = ['bhatt_exp', 'random']#, 'mutual_info']
    # Number of acquisition steps kept from each run.
    cutoff = 100

    excluded_seeds = {'seed585'}
    if skip_seeds:
        excluded_seeds |= {'seed5', 'seed6', 'seed7', 'seed8', 'seed9'}

    rmses_models = {}
    likelihoods_models = {}
    for crit in acq_crit:
        collected = collect_runs(run_dirs, crit, models_to_graph,
                                 excluded_seeds, cutoff)
        if collected is None:
            # np.stack would fail on an empty list; skip this criterion.
            print(f'no runs found for {crit}, skipping')
            continue
        rmses_models[crit], likelihoods_models[crit] = collected
    if not rmses_models:
        print(f'no runs matching {models_to_graph} in {env_dir}')
        return

    # X-axis of the absolute curves. It assumes 100 initial training points
    # plus 10 per acquisition batch.
    train_size = np.arange(cutoff, dtype=float) * 10 + 100
    acquisition_batch = np.arange(cutoff, dtype=float)
    plot_1stdv(train_size, rmses_models, 'rmse_' + model_name, 'training size', env_dir)
    plot_1stdv(train_size, likelihoods_models, 'log_likelihood_' + model_name,
               'training size', env_dir)
    likelihoods_pct_change_models = {
        m: pct_change_from_start(runs) for m, runs in likelihoods_models.items()}
    rmses_pct_change_models = {
        m: pct_change_from_start(runs, invert=True) for m, runs in rmses_models.items()}
    plot_mean(acquisition_batch, rmses_pct_change_models, 'rmse_pct_ch_' + model_name,
              'aquisition batch', env_dir)
    plot_mean(acquisition_batch, likelihoods_pct_change_models,
              'likelihoods_pct_change_' + model_name, 'aquisition batch', env_dir)
    metrics = {'Log Likelihood': likelihoods_models, 'RMSE': rmses_models}
    plot_1stdv_multi(train_size, metrics, models_to_graph, title_multi, env_dir)


if __name__ == '__main__':
    main()
