from algs import *
import torch
import numpy as np
import time
import matplotlib.pyplot as plt

from config import envs_config_antmaze as conf
from config import dyn_config_antmaze as dyn_conf
from config.envs_config_antmaze import start_timer, set_seeds, config_log, start_print

from dynamics.utils import DynamicsBuffer, plot_smoothed, MC_se, setup_plot, ant_get_block

import argparse


def process_arguments(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with an integer ``seed`` attribute.
    """
    parser = argparse.ArgumentParser(description='Example Python Script')

    # The seed is mandatory; argparse exits with a usage error if it is missing.
    parser.add_argument('--seed', type=int, required=True, help='Input seed')

    return parser.parse_args(argv)


def _axis_bins(coord, width=0.01):
    """Return evenly spaced left bin edges covering ``[-|coord|, |coord|)``.

    ``np.arange(a, -a, width)`` is empty whenever ``a`` is positive, and the
    later bin lookup would then index an empty array. Ordering the endpoints
    gives exactly the same edges for a negative ``a`` and a usable grid
    otherwise.

    Args:
        coord: Reference coordinate whose magnitude sets the half-range.
        width: Bin width (default 0.01).

    Returns:
        1-D numpy array of bin edges (empty only when ``coord`` is 0).
    """
    span = abs(float(coord))
    return np.arange(-span, span, width)


def _mark_visited(visited, bins, value):
    """Set to 1 the entry of ``visited`` for the bin of ``bins`` holding ``value``.

    ``np.digitize(value, bins) - 1`` is -1 for values left of the first edge.
    Negative indexing would then flag the *last* bin, so the index is clamped
    to ``[0, len(bins) - 1]``. Out-of-range values are attributed to the
    nearest edge bin. An empty grid is a no-op.

    Args:
        visited: Mutable array with one 0/1 entry per bin (updated in place).
        bins: Increasing left bin edges.
        value: Scalar coordinate to locate.
    """
    if len(bins) == 0:
        return
    index = int(np.clip(np.digitize(value, bins) - 1, 0, len(bins) - 1))
    visited[index] = 1


def train():
    """Run each exploration agent on PointMaze UMaze and log its state coverage.

    The seed is read from ``--seed`` on the command line and mapped to
    ``(seed + 10) ** 2``. For each agent, a fresh environment is created and
    the position ranges are split into 0.01-wide bins. At every time step the
    fraction of bins visited so far is recorded. The curve is saved to
    ``logs/<agent>/PercStates_<seed>.npy`` with shape ``(1, max_ep_len)``.
    """
    import os  # local import so this block does not depend on the file header

    args = process_arguments()
    rnd_seed = (args.seed + 10) ** 2

    env_name = 'PointMaze_UMaze-v3'

    # Agents to compare. Those in no_dyn_agents get no dynamics model/buffer.
    my_agents = ['Random', 'PPO_MAX', 'PPO_GP', 'PPO_DKL']
    no_dyn_agents = ['Random', 'PPO', 'PPO_ICM']

    # setup_plot yields one colour per agent; the figure/axes are unused here.
    _fig, _axs, colors = setup_plot(my_agents)

    for my_agent, _color in zip(my_agents, colors):

        print('\nAgent: ', my_agent)
        print('\nSeed: ', rnd_seed)

        set_seeds(rnd_seed)

        # Fresh environment for every agent.
        conf_env = conf.ConfigRun(env_name, max_ep_len=4000, update_timestep=20)

        # Fraction of position bins visited, indexed by conf_env.time_step.
        percentage_states_visited = torch.zeros(conf_env.max_ep_len)

        agent = conf_env.create_agent(my_agent)
        if my_agent in no_dyn_agents:
            dyn_model, dyn_data = None, None
        else:
            dyn_model = dyn_conf.create_dyn_model(conf_env, my_agent, dyn_hidden=32, num_batches=2, dyn_layers=2,
                                                  ensemble_size=10, update_every=conf_env.update_every, num_ind_pts=50)
            # Buffer of transitions for active exploration.
            dyn_data = DynamicsBuffer()
            if 'PPO_MAX' in my_agent:
                dyn_model.setup_optimizer()

        done = False
        while conf_env.time_step < conf_env.max_ep_len:

            # Reset with a fixed seed so every run sees the same start/goal.
            curr_obs = conf_env.env.reset(seed=0)[0]

            state = curr_obs['observation']
            final_pos = curr_obs['desired_goal']
            curr_pos = curr_obs['achieved_goal']
            initial_pos = curr_pos

            print('\n\nGoal position: ', final_pos, '\n\n')

            # The x range is counted twice (upper and lower arm of the U shape);
            # the y range is counted once.
            x_up_bins = _axis_bins(curr_pos[0])
            y_bins = _axis_bins(final_pos[1])
            x_low_bins = _axis_bins(curr_pos[0])

            x_up_visited = np.zeros(len(x_up_bins))
            y_visited = np.zeros(len(y_bins))
            x_low_visited = np.zeros(len(x_low_bins))
            tot_bins = len(x_up_bins) + len(y_bins) + len(x_low_bins)

            for t in range(1, conf_env.max_ep_len + 1):

                _mark_visited(y_visited, y_bins, curr_pos[1])
                # y > 0 is treated as the upper arm of the maze.
                if curr_pos[1] > 0:
                    _mark_visited(x_up_visited, x_up_bins, curr_pos[0])
                else:
                    _mark_visited(x_low_visited, x_low_bins, curr_pos[0])

                n_visited = np.sum(x_up_visited) + np.sum(y_visited) + np.sum(x_low_visited)
                perc_sts = n_visited / tot_bins if tot_bins else 0.0
                percentage_states_visited[conf_env.time_step] = perc_sts

                if t % 100 == 0:
                    print('\nPercentage of states visited at time step {}: '.format(t), format(perc_sts, ".2%"))
                    print('Current position: ', curr_pos)
                    print('Initial position: ', initial_pos)

                if my_agent in ['PPO', 'PPO_ICM', 'PPO_VIME']:
                    state, curr_pos, done = conf_env.update_reactive(agent, dyn_model, state, expl_mode=my_agent)
                elif my_agent == 'Random':
                    action = conf_env.env.action_space.sample()
                    obs, reward, done, _, _ = conf_env.env.step(action)
                    # step() returns the same dict observation as reset(). Unpack
                    # it so the coverage metric sees the new position.
                    state = obs['observation']
                    curr_pos = obs['achieved_goal']
                    conf_env.current_ep_reward += reward
                    conf_env.time_step += 1
                else:
                    state, curr_pos, done = conf_env.update_active(agent, dyn_model, state, expl_mode=my_agent,
                                                                   dynam_data=dyn_data, update_every=50, t=t,
                                                                   warm_start=500, trajs=10, im_h=20)

                # Stop the episode once the agent reports done.
                if done:
                    break

            # Stop training once the agent reports done.
            if done:
                break

        # One row per run, so the saved array keeps shape (1, max_ep_len).
        states_tsr = torch.stack([percentage_states_visited])

        log_dir = os.path.join('logs', my_agent)
        os.makedirs(log_dir, exist_ok=True)
        np.save(os.path.join(log_dir, 'PercStates_' + str(rnd_seed) + '.npy'), states_tsr.numpy())

if __name__ == "__main__":
    # Script entry point: train() reads --seed from the command line and runs every agent.
    train()
