import enum
import argparse
import pickle

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from utils import torch_predict, Direction, markers, colors, load_data, load_maze_img
from plot import plot_labels, aggregated_plot, plot_loss_curves, plot_acc_curves

# Increase matplotlib font sizes for readability
plt.rcParams.update({
    'font.size': 14,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 13,
})

def aggregated_plot_final(last_plots, args, ax, title="", metric="test_acc", postfix="", linestyle="solid"):
    """Draw final-performance error bars for the BC and IDM-based models.

    For each split in ``args.train_splits`` the mean of
    ``last_plots[split][model][metric]`` is plotted, with its standard
    deviation as the error bar. Two series are drawn on ``ax``: the
    ``"policy"`` entry (labelled "BC") and the ``"idm"`` entry (labelled
    "IDM-based").

    Args:
        last_plots: Nested mapping indexed as
            ``last_plots[split][model][metric]`` where ``model`` is
            ``"policy"`` or ``"idm"``; the leaves are passed to
            ``np.mean`` / ``np.std``.
        args: Object exposing ``train_splits`` (used as x values and as keys
            into ``last_plots``) and ``add_goal``.
        ax: Axes-like object providing ``errorbar``, ``set_title`` and
            ``set_ylim``.
        title: Title set on ``ax``.
        metric: Metric key to aggregate. Only ``"test_acc"`` gets fixed
            y-limits; any other metric keeps the axes' automatic limits.
        postfix: Text appended to both legend labels.
        linestyle: Line style for both series; overridden when
            ``args.add_goal`` is truthy.

    Returns:
        The ``ax`` that was passed in.
    """
    if args.add_goal:
        # Runs with add_goal set (presumably goal-conditioned -- confirm) are
        # tagged in the legend and drawn dashed to tell them apart.
        postfix += "-GC"
        linestyle = "--"

    # (key in last_plots[split], legend label, colour) for each drawn series.
    series = (
        ("policy", "BC", "tab:blue"),
        ("idm", "IDM-based", "tab:orange"),
    )
    for model_key, label, color in series:
        per_split = [last_plots[s][model_key][metric] for s in args.train_splits]
        ax.errorbar(
            args.train_splits,
            [np.mean(values) for values in per_split],
            [np.std(values) for values in per_split],
            label=label + postfix,
            capsize=3,
            marker=".",
            color=color,
            linestyle=linestyle,
        )

    ax.set_title(title)
    if metric == "test_acc":
        # Accuracy is shown on a common scale across subplots.
        ax.set_ylim(0.2, 1.05)
    return ax

if __name__ == "__main__":
    # Maze complexity results, no conditioning: print the mean and standard
    # deviation of the stored 'avg_rew' values for the BC policy and the
    # IDM policy of every (expert, labelled-sample count) run.
    exp_root = 'results_stoch_final'
    metric = "avg_reward"

    # Same visiting order as two nested loops: expert outermost, then count.
    run_configs = [
        (expert_kind, label_count)
        for expert_kind in ["deterministic", "stochastic"]
        for label_count in [1000, 100]
    ]

    for expert, num_lab_samples in run_configs:
        output_folder = f"{exp_root}/MLP_{expert}_20env_1000n_{num_lab_samples}l_5seeds"

        # The saved experiment arguments are loaded alongside the results,
        # although nothing in this block reads them afterwards.
        with open(f"{output_folder}/args.pkl", "rb") as args_file:
            exp_args = pickle.load(args_file)
        with open(f"{output_folder}/last_plots.pkl", "rb") as plots_file:
            last_plots = pickle.load(plots_file)

        # num_lab_samples/10 is printed as a percentage (presumably relative
        # to the 1000 samples named in the folder -- confirm).
        print(f"{expert}, {num_lab_samples/10}% of labels")

        bc_rewards = last_plots['policy']['avg_rew']
        print(f"BC policy: {np.mean(bc_rewards)} +- {np.std(bc_rewards)}")

        idm_rewards = last_plots['idm_policy']['avg_rew']
        print(f"IDM policy: {np.mean(idm_rewards)} +- {np.std(idm_rewards)}")
