from algs import *
import torch
import numpy as np
import time
import matplotlib.pyplot as plt

from config import envs_config_antmaze as conf
from config import dyn_config_antmaze as dyn_conf
from config.envs_config_antmaze import start_timer, set_seeds, config_log, start_print

from dynamics.utils import DynamicsBuffer, plot_smoothed, MC_se, setup_plot, ant_get_block

# set seeds
# One run is executed per entry of this list; the expression currently yields the single seed 9.
random_seed = [(x + 3) ** 2 for x in range(1)]  # set random seed if required (0 = no random seed)
run_num = random_seed  # logging stamp
run_num_pretrained = 0  # change this to prevent overwriting weights in same env_name folder

# Configure environment settings and hyperparameters
env_name = 'PointMaze_UMaze-v3'

# Setup agent and directory for logging results
my_agents = ['PPO_DKL']

# Legend label of each agent. The mapping is spelled out explicitly: pairing the label list with
# `my_agents` by position gives the wrong label as soon as `my_agents` is a subset of the agents
# (e.g. ['PPO_DKL'] alone would be labelled 'No Expl').
# NOTE(review): the pairing follows the agent ordering used in the rest of this script
# (PPO, PPO_ICM, PPO_VIME, PPO_MAX, PPO_GP, PPO_DKL) - confirm it matches the intended legend.
labels = {
    'PPO': 'No Expl',
    'PPO_ICM': 'ICM',
    'PPO_VIME': 'VIME',
    'PPO_MAX': 'MAX',
    'PPO_GP': 'BAE-GP',
    'PPO_DKL': 'BALE',
}

# Setup plot and colors
fig, axs, colors = setup_plot(my_agents)
fig.suptitle('Point Maze', fontsize=16, fontweight='bold')

# Agent loop
# For every agent and every seed: run the maze, log the fraction of position bins visited at each
# time step, then plot the coverage curve and append summary statistics to CSV files.

# Agents that get an external dynamics model and a DynamicsBuffer
dynamics_agents = ['PPO_VIME', 'PPO_MAX', 'PPO_GP', 'PPO_DKL']
# Agents stepped with `update_reactive`; every other agent is stepped with `update_active`
reactive_agents = ['PPO', 'PPO_ICM', 'PPO_VIME']

for my_agent, color in zip(my_agents, colors):

    print('\nAgent: ', my_agent)

    # list to keep track of visited states and rewards per seed
    states_perc_list, reward_list, time_list = [], [], []

    for rnd_seed in random_seed:

        print('\nSeed: ', rnd_seed)

        # Set seeds
        set_seeds(rnd_seed)

        # Restart environment
        conf_env = conf.ConfigRun(env_name, max_ep_len=500, update_timestep=20)

        # Configure logging of percentage of states visited
        percentage_states_visited = torch.zeros(conf_env.max_ep_len)

        # NOTE(review): nothing in this script ever writes to `reward_tot` or `reward_per_step`,
        # so the reward printouts and the reward CSV below always report zeros.
        reward_tot = 0
        reward_per_step = torch.zeros(conf_env.max_ep_len)

        # Create agent and dynamics model
        agent = conf_env.create_agent(my_agent)
        # Models I want here are: PPO_GP, PPO_DKL, PPO_BALE, PPO_MAX
        if my_agent in dynamics_agents:
            dyn_model = dyn_conf.create_dyn_model(conf_env, my_agent, dyn_hidden=32, num_batches=5, dyn_layers=2,
                                                  ensemble_size=10, update_every=conf_env.update_every, num_ind_pts=50)
            # Create buffer for active exploration RL
            dyn_data = DynamicsBuffer()
            if my_agent == 'PPO_MAX':
                dyn_model.setup_optimizer()
        else:
            # PPO has no dynamics model and PPO_ICM has dynamics model within the agent class directly
            dyn_model, dyn_data = None, None

        # Start timer
        start_time = time.time()

        done = False
        # Index and value of the most recent coverage entry written (None until the first write)
        last_logged_step, perc_sts = None, None

        # training loop
        while conf_env.time_step < conf_env.max_ep_len:

            # Get initial observations plus position goal, we set seed to set the same maze structure
            curr_obs = conf_env.env.reset(seed=0)[0]

            # Get state and current + final position
            state = curr_obs['observation']
            final_pos = curr_obs['desired_goal']
            curr_pos = curr_obs['achieved_goal']
            initial_pos = curr_pos

            print('\n\nGoal position: ', final_pos, '\n\n')

            # Divide the x and y range respectively into bins with specified width
            # The x range is double considering the U shape of the maze
            # NOTE(review): the ranges are non-empty only if the start x and the goal y are
            # negative - presumably guaranteed by reset(seed=0); confirm.
            x_up_bins = np.arange(curr_pos[0], -curr_pos[0], 0.01)
            y_bins = np.arange(final_pos[1], -final_pos[1], 0.01)
            x_low_bins = np.arange(curr_pos[0], -curr_pos[0], 0.01)

            # Define the visited bins vectors made of zeros
            x_up_visited = np.zeros(len(x_up_bins))
            y_visited = np.zeros(len(y_bins))
            x_low_visited = np.zeros(len(x_low_bins))

            tot_bins = len(x_up_visited) + len(y_visited) + len(x_low_visited)

            for t in range(1, conf_env.max_ep_len + 1):

                # Find the indices of the bin that the current position falls into.
                # np.digitize returns 0 for a value below the first edge; clamp the resulting -1
                # to 0 so it counts towards the first bin instead of wrapping to the last one.

                # Firstly update y_visited_bins easier
                y_index = max(np.digitize(curr_pos[1], y_bins) - 1, 0)
                y_visited[y_index] = 1

                # Then to update x_visited_bins we need to check if we are in the upper or lower part of the maze
                if curr_pos[1] > 0:
                    x_up_index = max(np.digitize(curr_pos[0], x_up_bins) - 1, 0)
                    x_up_visited[x_up_index] = 1
                else:
                    x_low_index = max(np.digitize(curr_pos[0], x_low_bins) - 1, 0)
                    x_low_visited[x_low_index] = 1

                # Finally update the percentage of states visited
                perc_sts = (np.sum(x_up_visited) + np.sum(y_visited) + np.sum(x_low_visited)) / tot_bins
                last_logged_step = conf_env.time_step
                percentage_states_visited[last_logged_step] = perc_sts

                # Print current progress in visited states
                if t % 100 == 0:
                    print('\nPercentage of states visited at time step {}: '.format(t), format(perc_sts, ".2%"))
                    print('Current position: ', curr_pos)
                    print('Initial position: ', initial_pos)

                # Episode update
                if my_agent in reactive_agents:
                    state, curr_pos, done = conf_env.update_reactive(agent, dyn_model, state, expl_mode=my_agent)

                else:
                    state, curr_pos, done = conf_env.update_active(agent, dyn_model, state, expl_mode=my_agent,
                                                                   dynam_data=dyn_data, update_every=10, t=t,
                                                                   warm_start=50, trajs=10, im_h=100)

                # Stop if agent has reached the goal
                if done:
                    break

            # Break if agent has reached the goal
            if done:
                break

        # When the run stops early the tail of the log would stay at zero and the "final" value
        # read from index -1 below would be wrong, so carry the last logged coverage forward.
        if last_logged_step is not None:
            percentage_states_visited[last_logged_step:] = float(perc_sts)

        # Print final total reward and max reward
        print('\n\nTotal reward: ', reward_tot)
        print('Max reward: ', reward_per_step.max().item())

        # Print percentage of states visited as percentage
        print('\nFinal percentage of states visited: ', format(percentage_states_visited[-1].item(), ".2%"), '\n\n')

        # Save time
        time_list.append(time.time() - start_time)

        # Save percentage of states visited
        states_perc_list.append(percentage_states_visited)
        reward_list.append(reward_per_step)

    # Stack the per-seed logs: one row per seed, one column per time step
    states_tsr = torch.stack(states_perc_list)
    reward_tsr = torch.stack(reward_list)

    # Plot percentage of states visited, but one plot for reactive and one for active
    ax = axs[0] if my_agent in reactive_agents else axs[1]
    plot_smoothed(ax, states_tsr, labels[my_agent], 'States visited (%)', color=color)

    # Compute mean and MC standard error of the percentage of states visited and write to file
    states_mean = states_tsr[:, -1].mean(dim=0)
    states_se = MC_se(states_tsr[:, -1], len(random_seed))

    # Files are opened in append mode, so repeated runs add rows instead of overwriting
    with open('plots/{}_StatesExpl_Runs{}_FINAL.csv'.format(env_name, len(random_seed)), 'a') as f:
        f.write('{},{},{}\n'.format(my_agent, states_mean, states_se))

    # Compute mean and MC standard error of the rewards at the last time step and write to file
    reward_mean = reward_tsr[:, -1].mean(dim=0)
    reward_se = MC_se(reward_tsr[:, -1], len(random_seed))

    with open('plots/{}_FinalReward_Runs{}_FINAL_BALE.csv'.format(env_name, len(random_seed)), 'a') as f:
        f.write('{},{},{}\n'.format(my_agent, reward_mean, reward_se))

    # Compute mean and MC standard error of time list and write to file
    time_tsr = torch.tensor(time_list)
    time_mean = time_tsr.mean(dim=0)
    time_se = MC_se(time_tsr, len(random_seed))

    with open('plots/{}_Time_Runs{}_FINAL_BALE.csv'.format(env_name, len(random_seed)), 'a') as f:
        f.write('{},{},{}\n'.format(my_agent, time_mean, time_se))


# Write the current matplotlib figure to a PDF under plots/, named after the environment and the number of seeds
figure_path = 'plots/{}_StatesExpl_Runs{}_FINAL_BALE.pdf'.format(env_name, len(random_seed))
plt.savefig(figure_path, bbox_inches='tight')

