from algs import *  # keep first: later imports take precedence over any clashing star-imported names
import argparse
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

from config import envs_config as conf
from config import dyn_config as dyn_conf
from config.envs_config import start_timer, set_seeds, config_log, start_print

from dynamics.utils import DynamicsBuffer, plot_smoothed, MC_se, setup_plot


def process_arguments(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Argument list to parse, without the program name. When ``None``
            (the default), ``argparse`` reads ``sys.argv[1:]``, which is what
            the original command-line usage does.

    Returns:
        argparse.Namespace with an integer ``seed`` attribute.

    Raises:
        SystemExit: if ``--seed`` is missing or is not an integer
            (``argparse``'s standard error handling).
    """
    parser = argparse.ArgumentParser(description='Train RL agents and save per-step rewards.')

    # --seed is mandatory: the script has no sensible default seed.
    parser.add_argument('--seed', type=int, required=True, help='Input seed')

    return parser.parse_args(argv)


# Agents that own a dynamics model built by dyn_config (PPO_ICM keeps its
# model inside the agent class, and plain PPO has none).
_DYN_MODEL_AGENTS = ('PPO_VIME', 'PPO_MAX', 'PPO_GP', 'PPO_DKL')

# Agents stepped with ConfigRun.update_reactive; every other agent uses update_active.
_REACTIVE_AGENTS = ('PPO', 'PPO_ICM', 'PPO_VIME')

# Per-coordinate |dist_to_goal| tolerance that ends an episode early and earns the bonus.
_STEP_GOAL_TOL = 1e-3
# Per-coordinate |dist_to_goal| tolerance checked after an episode to stop training.
# NOTE(review): this is looser than _STEP_GOAL_TOL, so an episode that runs to the
# step limit with 1e-3 <= |dist| < 1e-2 stops training as "Goal reached!" without
# having received the bonus. Confirm whether the two thresholds should be unified.
_EPISODE_GOAL_TOL = 1e-2
# Bonus added to the per-step reward on the step at which the goal is reached.
_GOAL_BONUS = 100.0


def _setup_dynamics(conf_env, my_agent):
    """Create the dynamics model and its data buffer for ``my_agent``.

    Returns:
        ``(dyn_model, dyn_data)``, or ``(None, None)`` for agents that are not
        in ``_DYN_MODEL_AGENTS``.
    """
    if my_agent not in _DYN_MODEL_AGENTS:
        return None, None

    dyn_model = dyn_conf.create_dyn_model(conf_env, my_agent, dyn_hidden=32, num_batches=2, dyn_layers=2,
                                          ensemble_size=5, update_every=conf_env.update_every, num_ind_pts=50)
    dyn_data = DynamicsBuffer()

    # Only PPO_MAX needs its optimizer set up explicitly here.
    if my_agent == 'PPO_MAX':
        dyn_model.setup_optimizer()

    return dyn_model, dyn_data


def _reset_state(conf_env, env_name, seed):
    """Reset the environment with ``seed`` and return ``(state, dist_to_goal)``.

    For Reacher-v4, ``dist_to_goal`` is ``state[8:10]`` of the raw observation.
    Observation entries 4, 5 and 10 are then removed from ``state``. For other
    environments, the state is returned untouched and ``dist_to_goal`` is ``None``.
    """
    state, _ = conf_env.env.reset(seed=seed)
    dist_to_goal = None
    if env_name == 'Reacher-v4':
        # Assumes obs[8:10] is the (x, y) fingertip-minus-target vector, per the
        # Gymnasium Reacher-v4 observation layout. TODO confirm.
        dist_to_goal = state[8:10]
        state = np.delete(state, [4, 5, 10])
    return state, dist_to_goal


def train():
    """Train each configured agent on each environment and save per-step rewards.

    The run seed is ``(--seed + 13) ** 2``. For every (environment, agent) pair,
    the function resets the seeds, creates the agent (and a dynamics model where
    applicable), and runs episodes while ``conf_env.time_step <
    conf_env.max_ep_len``. Training stops early once an episode ends within
    ``_EPISODE_GOAL_TOL`` of the goal.

    The per-step reward vector (summed over episodes, plus ``_GOAL_BONUS`` on
    the step the goal is reached) is stacked and saved to
    ``logs/<env_name>/<agent>/Rewards_<seed>.npy``. The directory is created if
    it does not already exist.
    """
    args = process_arguments()
    rnd_seed = (args.seed + 13) ** 2

    env_names = ['Reacher-v4']
    my_agents = ['PPO']

    # Only the colors are used; zip() below pairs at most len(colors) agents.
    _fig, _axs, colors = setup_plot(my_agents)

    for env_name in env_names:
        for my_agent, _color in zip(my_agents, colors):
            print('\nAgent: ', my_agent)

            # Per-seed reward vectors and wall-clock durations.
            # NOTE(review): time_list is filled but never saved.
            reward_list, time_list = [], []

            print('\nSeed: ', rnd_seed)
            set_seeds(rnd_seed)

            # Fresh environment configuration for this agent.
            conf_env = conf.ConfigRun(env_name, update_timestep=20)

            # Running total reward, and reward accumulated per within-episode step.
            reward_tot = 0
            reward_per_step = torch.zeros(conf_env.max_ep_len)

            agent = conf_env.create_agent(my_agent)
            dyn_model, dyn_data = _setup_dynamics(conf_env, my_agent)

            start_time = time.time()

            # NOTE(review): termination relies on update_reactive/update_active
            # advancing conf_env.time_step. The `done` flag they return is ignored.
            while conf_env.time_step < conf_env.max_ep_len:
                state, dist_to_goal = _reset_state(conf_env, env_name, rnd_seed)

                for t in range(1, conf_env.max_ep_len + 1):
                    if my_agent in _REACTIVE_AGENTS:
                        state, dist_to_goal, reward, done = conf_env.update_reactive(
                            agent, dyn_model, state, expl_mode=my_agent)
                    else:
                        state, dist_to_goal, reward, done = conf_env.update_active(
                            agent, dyn_model, state, expl_mode=my_agent,
                            dynam_data=dyn_data, t=t, trajs=10, im_h=100)

                    reward_tot += reward
                    reward_per_step[t - 1] += reward

                    print('\nDistance to goal at step ', t, ': ', dist_to_goal)

                    # End the episode (with a bonus) once both coordinates are within _STEP_GOAL_TOL.
                    if np.all(np.abs(dist_to_goal) < _STEP_GOAL_TOL):
                        reward_per_step[t - 1] += _GOAL_BONUS
                        break

                # Stop training once the episode ended within _EPISODE_GOAL_TOL of the goal.
                if np.all(np.abs(dist_to_goal) < _EPISODE_GOAL_TOL):
                    print('\n\nGoal reached!\n\n')
                    # Pad the remaining steps with the reward of the last step taken.
                    reward_per_step[t:] = reward_per_step[t - 1]
                    break

            print('\n\nTotal reward: ', reward_tot)
            print('Max reward: ', reward_per_step.max().item())

            time_list.append(time.time() - start_time)
            reward_list.append(reward_per_step)
            reward_tsr = torch.stack(reward_list)

            # Create the log directory if needed so the results are not lost to FileNotFoundError.
            log_dir = os.path.join('logs', env_name, my_agent)
            os.makedirs(log_dir, exist_ok=True)
            np.save(os.path.join(log_dir, 'Rewards_' + str(rnd_seed) + '.npy'), reward_tsr.numpy())


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    train()
