import numpy as np
import torch
import matplotlib.pyplot as plt
import os

from dynamics.utils import DynamicsBuffer, plot_smoothed_rewards, MC_se, setup_plot_rewards


# Experiment configuration: which environments and agents to summarise.
# B is embedded in the output file names as `Runs{B}` and is passed to MC_se.
B = 10
env_names = ['PointMaze_Medium_Diverse_G-v3']

# Agent variants; each name is also the sub-directory under logs/<env_name>/.
my_agents = ['PPO', 'PPO_ICM', 'PPO_VIME', 'PPO_MAX', 'PPO_GP', 'PPO_DKL']

# Display label for each agent, paired positionally with my_agents.
labels = {
    agent: display
    for agent, display in zip(
        my_agents,
        ['No Expl', 'ICM', 'VIME', 'MAX', 'BAE-GP', 'BALE'],
    )
}

# Value recorded as "steps to goal" for runs that never reach a reward of 1.0.
# NOTE(review): presumably the rollout length of the logged runs — confirm.
MAX_STEPS = 5000

# Setup plot and colors
fig, axs, colors = setup_plot_rewards(env_names, my_agents)

# All CSV/PDF outputs go under plots/; make sure it exists before writing.
os.makedirs('plots', exist_ok=True)


def _first_success_steps(reward_tsr, horizon=MAX_STEPS):
    """Return, per run (row), the index of the first step whose reward equals 1.0.

    Runs that never reach a reward of 1.0 are assigned ``horizon``. A run that
    succeeds at step 0 keeps index 0 (it is not confused with "never succeeded").

    Args:
        reward_tsr: 2-D tensor of rewards, one row per run, one column per step.
        horizon: value used for runs with no reward equal to 1.0.

    Returns:
        1-D integer tensor with one entry per run.
    """
    reached = reward_tsr == 1.0
    # argmax returns the first index of the maximal value, i.e. the first success;
    # it also returns 0 for an all-False row, so that case is masked explicitly.
    first = torch.argmax(reached.float(), dim=1)
    return first.masked_fill(~reached.any(dim=1), horizon)


# Summarise every environment: one CSV row per agent for steps-to-goal and for
# final-step reward.
for env_name in env_names:

    # Raw per-run reward arrays for this environment, keyed by agent.
    rewards = {agent: [] for agent in my_agents}

    steps_csv = 'plots/{}_Steps_Runs{}_FINAL.csv'.format(env_name, B)
    rewards_csv = 'plots/{}_Rewards_Runs{}_FINAL.csv'.format(env_name, B)

    # Open each summary once per environment in write mode so re-running the
    # script replaces previous results instead of appending duplicate rows.
    with open(steps_csv, 'w') as steps_f, open(rewards_csv, 'w') as rewards_f:
        for agent in my_agents:

            # Sorted for a deterministic run order (os.listdir order is arbitrary);
            # hidden files such as .DS_Store are not reward logs and are skipped.
            dir_path = os.path.join('logs', env_name, agent)
            files = sorted(f for f in os.listdir(dir_path) if not f.startswith('.'))
            if not files:
                raise FileNotFoundError('No reward logs found in {}'.format(dir_path))

            for file in files:
                rewards[agent].append(np.load(os.path.join(dir_path, file)))

            # Stack runs into a (runs, steps) tensor.
            # NOTE(review): squeeze(axis=1) assumes each saved array has a leading
            # singleton axis, e.g. shape (1, steps) — confirm against the logger.
            reward_tsr = torch.as_tensor(np.stack(rewards[agent], axis=0).squeeze(axis=1))

            # Mean and MC standard error of the number of steps to first reach the goal.
            steps = _first_success_steps(reward_tsr).float()
            steps_mean = steps.mean(dim=0).item()
            steps_se = MC_se(steps, B)

            # Mean and MC standard error of the reward at the last logged step.
            last_rewards = reward_tsr[:, -1]
            rewards_mean = last_rewards.mean(dim=0).item()
            rewards_se = MC_se(last_rewards, B)

            steps_f.write('{},{},{}\n'.format(agent, steps_mean, steps_se))
            rewards_f.write('{},{},{}\n'.format(agent, rewards_mean, rewards_se))

    # NOTE(review): nothing is drawn on the figure in this script, so this saves
    # only what setup_plot_rewards produced — confirm this is intended.
    plt.savefig('plots/{}_Steps_Runs{}_FINAL_BALE.pdf'.format(env_name, B), bbox_inches='tight')

