from algs import *
import torch
import numpy as np
import time
import matplotlib.pyplot as plt

from config import envs_config as conf
from config import dyn_config as dyn_conf
from config.envs_config import start_timer, set_seeds, config_log, start_print

from dynamics.utils import DynamicsBuffer, plot_smoothed, MC_se, setup_plot

import argparse


def process_arguments(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so existing
            callers that pass nothing keep reading the real command line.

    Returns:
        argparse.Namespace: namespace with a single integer attribute ``seed``.

    Raises:
        SystemExit: raised by argparse when ``--seed`` is missing or cannot be
            converted to ``int``.
    """
    parser = argparse.ArgumentParser(description='Example Python Script')

    # The seed is mandatory: every run must be reproducible from its arguments
    parser.add_argument('--seed', type=int, required=True, help='Input seed')

    return parser.parse_args(argv)


def train():
    """Train every configured agent for one seed and save its reward curve.

    The seed is read from the ``--seed`` command-line argument through
    ``process_arguments`` and mapped to ``(seed + 1) ** 2`` before being handed
    to ``set_seeds``. For each environment/agent pair a fresh
    ``conf.ConfigRun`` is built and the agent is stepped until either the
    accumulated reward becomes positive or ``conf_env.time_step`` reaches
    ``conf_env.max_ep_len``. The cumulative reward after every step is then
    written to ``logs/<env_name>/<agent>/Rewards_<seed>.npy`` (the directory
    is created if it does not exist).

    Returns:
        None. Results are written to disk and progress is printed to stdout.
    """
    # Local import: only needed to make sure the log directory exists
    import os

    args = process_arguments()

    # Seed actually used for the run, derived from the command-line seed
    rnd_seed = (args.seed + 1) ** 2

    # Configure environment settings and hyperparameters
    env_names = ['PointMaze_Medium_Diverse_G-v3']

    # Agents to train
    my_agents = ['PPO_MAX', 'PPO_GP', 'PPO_DKL']

    # Agents that get an external dynamics model and a dynamics buffer
    dyn_agents = ('PPO_VIME', 'PPO_MAX', 'PPO_GP', 'PPO_DKL')
    # Agents stepped with `update_reactive`; every other agent uses `update_active`
    reactive_agents = ('PPO', 'PPO_ICM', 'PPO_VIME')

    # Setup plot and colors (the plot handles are not drawn to in this function)
    fig, axs, colors = setup_plot(my_agents)

    # Environment loop
    for env_name in env_names:

        # Agent loop
        for my_agent, color in zip(my_agents, colors):

            print('\nAgent: ', my_agent)

            # One entry per run; this script performs a single seed per invocation
            reward_list, time_list = [], []

            print('\nSeed: ', rnd_seed)

            # Set seeds
            set_seeds(rnd_seed)

            # Restart environment
            conf_env = conf.ConfigRun(env_name, max_ep_len=5000, update_timestep=20)

            # Running total reward and its value after every step
            reward_tot = 0
            cum_reward = torch.zeros(conf_env.max_ep_len)

            # Create agent and, where applicable, dynamics model plus its data buffer
            agent = conf_env.create_agent(my_agent)
            if my_agent in dyn_agents:
                dyn_model = dyn_conf.create_dyn_model(conf_env, my_agent, dyn_hidden=32, num_batches=2, dyn_layers=2,
                                                      ensemble_size=10, update_every=conf_env.update_every,
                                                      num_ind_pts=50)
                dyn_data = DynamicsBuffer()
                # Only PPO_MAX has its optimizer set up explicitly here
                if my_agent == 'PPO_MAX':
                    dyn_model.setup_optimizer()
            else:
                # PPO has no dynamics model and PPO_ICM has dynamics model within the agent class directly
                dyn_model, dyn_data = None, None

            # Start timer
            start_time = time.time()

            # training loop
            while conf_env.time_step < conf_env.max_ep_len:

                # Get initial observations plus position goal, we set seed to set the same maze structure
                curr_obs = conf_env.env.reset(seed=0)[0]

                # Get state and current + final position
                state = curr_obs['observation']
                final_pos = curr_obs['desired_goal']
                curr_pos = curr_obs['achieved_goal']

                print('\n\nGoal position: ', final_pos, '\n\n')

                for t in range(1, conf_env.max_ep_len + 1):

                    # Print current progress every 100 steps
                    if t % 100 == 0:
                        print('\n\nTime step: ', t)
                        print('Current position: ', curr_pos)
                        print('Goal position: ', final_pos)
                        print('Reward per step: ', reward_tot)

                    # Episode update (the returned done flag is not used by this loop)
                    if my_agent in reactive_agents:
                        state, curr_pos, reward, _ = conf_env.update_reactive(agent, dyn_model, state,
                                                                              expl_mode=my_agent)

                    else:
                        state, curr_pos, reward, _ = conf_env.update_active(agent, dyn_model, state,
                                                                            expl_mode=my_agent,
                                                                            dynam_data=dyn_data, t=t, trajs=10,
                                                                            im_h=20)

                    # Update total reward and record it for this step
                    reward_tot += reward
                    cum_reward[t - 1] = reward_tot

                    # Stop stepping as soon as a positive total reward has been collected
                    if reward_tot > 0:
                        break

                # Positive total reward ends training for this agent
                if reward_tot > 0:
                    # Fill the rest of the cumulative reward vector with the last value
                    cum_reward[t:] = cum_reward[t - 1]

                    break

            # Print final total reward
            print('\n\nTotal reward: ', reward_tot)

            # Record elapsed wall-clock time (kept in memory only, not written to disk)
            time_list.append(time.time() - start_time)

            # Stack to shape (number of runs, max_ep_len); a single run here
            reward_list.append(cum_reward)
            reward_tsr = torch.stack(reward_list)

            # Save rewards as numpy array, creating the target directory first so
            # a finished run is not lost to a missing folder
            out_path = 'logs/' + env_name + '/' + my_agent + '/Rewards_' + str(rnd_seed) + '.npy'
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            np.save(out_path, reward_tsr.numpy())


# Run training only when this file is executed as a script, not when imported
if __name__ == '__main__':
    train()
