import numpy as np
import torch
import matplotlib.pyplot as plt
import os

from dynamics.utils import DynamicsBuffer, plot_smoothed, MC_se, setup_plot, ant_get_block


# Setup agent and directory for logging results
B = 15  # passed to MC_se and embedded in the output file names ("Runs{B}")
env_name = 'PointMaze_UMaze-v3'
my_agents = ['PPO', 'PPO_ICM', 'PPO_VIME', 'PPO_MAX', 'PPO_GP', 'PPO_DKL']
# Legend label for each agent, in the same order as my_agents.
# The LaTeX label is a raw string: '\e' is not a valid Python escape sequence and
# triggers a DeprecationWarning/SyntaxWarning inside a normal string literal.
labels = dict(zip(my_agents,
                  ['π-Entropy', r'$\ell^2$ Error', 'VIME', 'BAE DeepEns', 'BAE GP', 'BAE DK']))

# Setup plot and colors
# Per-agent list of loaded run arrays, keyed by agent name (starts empty).
percs = {agent: [] for agent in my_agents}

# Build the figure, its axes and one color per agent, then give it a title.
fig, axs, colors = setup_plot(my_agents)
title_style = dict(fontsize=16, fontweight='bold')
fig.suptitle('U-Maze', **title_style)

# For loop over all agents
# Agents drawn on the first panel (axs[0], the "reactive" group); all others go on axs[1].
reactive_agents = {'PPO', 'PPO_ICM', 'PPO_VIME'}

# Summary CSV with one "agent,mean,se" line per agent.
csv_path = os.path.join('plots', '{}_StatesExpl_Runs{}_FINAL.csv'.format(env_name, B))
csv_rows = []

# For loop over all agents
for agent, color in zip(my_agents, colors):

    # Load every run file for this agent. os.listdir returns entries in arbitrary
    # order, so sort the names; otherwise "exclude the last file" below would drop
    # a different run depending on the filesystem.
    dir_path = os.path.join('logs', agent)
    for file in sorted(os.listdir(dir_path)):
        percs[agent].append(np.load(os.path.join(dir_path, file)))

    # Exclude the last file (last in sorted order).
    percs[agent] = percs[agent][:-1]

    # Stack the runs into a tensor of shape (n_runs, n_steps).
    # NOTE(review): squeeze(1) assumes each file holds an array of shape (1, n_steps) — confirm.
    states_tsr = torch.tensor(np.array(percs[agent])).squeeze(1)

    # Plot the percentage of states visited: reactive agents on one panel,
    # the remaining (active) agents on the other.
    ax = axs[0] if agent in reactive_agents else axs[1]
    plot_smoothed(ax, states_tsr, labels[agent], 'States visited (%)', color=color)

    # Mean and MC standard error of the final percentage of states visited across runs.
    # Convert to plain Python floats so the CSV gets numbers, not "tensor(...)" text.
    final_states = states_tsr[:, -1]
    states_mean = final_states.mean(dim=0).item()
    states_se = float(MC_se(final_states, B))  # accepts a float, NumPy scalar or 0-d tensor
    csv_rows.append('{},{},{}\n'.format(agent, states_mean, states_se))

# Write the summary once, replacing any previous file, so re-running the script
# does not append duplicate rows.
os.makedirs('plots', exist_ok=True)
with open(csv_path, 'w', encoding='utf-8') as f:
    f.writelines(csv_rows)


# Save the figure created by setup_plot. Use the explicit figure handle rather than
# pyplot's implicit "current figure", which would be a different figure if any
# helper had opened another one in the meantime.
fig.savefig(os.path.join('plots', '{}_StatesExpl_Runs{}_FINAL_BALE.pdf'.format(env_name, B)),
            bbox_inches='tight')

