# ========================================================================= #
# Filename:                                                                 #
#    runner_spar.py                                                         #
#                                                                           #
# Description:                                                              #
#    Convenience script to load YAML parameters and train a SPAR agent      #
#    (hyperparameters from the `sac_kwargs` section) on a Safety Gym env    #
# ========================================================================= #

# Requires safety_gym (and MuJoCo): install them, or activate a conda environment that already has them, before running.

import json
import os, sys, pwd, argparse
import ipdb as pdb
import numpy as np
import torch
from ruamel.yaml import YAML

#from ruamel.yaml import YAML
from datetime import date, datetime, timezone

l2r_path = os.path.abspath(os.path.join(''))
if l2r_path not in sys.path:
    sys.path.append(l2r_path)
print(l2r_path)
from baselines.rl.spar.spar_gym import SPARGym

import gym
import safety_gym
from configs.gym_cfg import gym_cfg 
from safety_gym.envs.engine import Engine
from gym.envs.registration import register

sys.path.append('/data/workspaces/jmf1/safety-starter-agents-fork/')
import safe_rl
from safe_rl.utils.run_utils import setup_logger_kwargs
from safe_rl.utils.mpi_tools import mpi_fork

# Login name of the account running this script, looked up in the password
# database by real user id (field 0 of the passwd entry is the login name).
current_user = pwd.getpwuid(os.getuid())[0]

def main(exp_name, cfg, args):
    """Validate the requested experiment, build the environment and train SPAR.

    Args:
        exp_name: Name for the experiment's logs. When falsy (``None`` or
            ``''``) a default of the form ``spar_gym_<Robot><Task>`` is used,
            e.g. ``spar_gym_CarGoal1``.
        cfg: Agent configuration forwarded unchanged to ``SPARGym`` (the
            ``sac_kwargs`` section of the YAML file when run as a script).
        args: Parsed command-line namespace. ``task``, ``robot``, ``algo``,
            ``cpu`` and ``seed`` are read here; the whole namespace is also
            forwarded to ``SPARGym``.

    Raises:
        ValueError: If ``args.task``, ``args.robot`` or ``args.algo`` is not
            one of the supported values.
    """

    # verify experiment
    robot_list = ['point', 'car', 'doggo']
    task_list = ['goal1']
    algo_list = ['ppo', 'spar']

    task = args.task.capitalize()
    robot = args.robot.capitalize()
    algo = args.algo.lower()
    # Raise explicitly rather than ``assert``: asserts are stripped under
    # ``python -O`` and would let an invalid experiment through.
    if task.lower() not in task_list:
        raise ValueError("Invalid task")
    if robot.lower() not in robot_list:
        raise ValueError("Invalid robot")
    if algo not in algo_list:
        raise ValueError("Invalid algo")
    # NOTE(review): ``algo`` is only validated here; SPARGym is trained
    # whatever its value is.

    # hyperparams: Doggo gets a 10x larger step budget and 2x larger epochs.
    if robot == 'Doggo':
        num_steps = 100_000_000
        steps_per_epoch = 60000
    else:
        num_steps = 10_000_000
        steps_per_epoch = 30000
    epochs = num_steps // steps_per_epoch
    save_freq = 1
    target_kl = 0.01

    # Fork for parallelizing across ``args.cpu`` processes.
    mpi_fork(args.cpu)

    # Prepare Logger. A caller-supplied name takes precedence; the default
    # is only used when no name was given.
    exp_name = exp_name or f'spar_gym_{robot}{task}'
    logger_kwargs = setup_logger_kwargs(exp_name, args.seed)

    # Env
    # NOTE(review): the environment is built from the fixed ``gym_cfg``;
    # --robot/--task only influence the names and step budget above.
    # Confirm gym_cfg matches the requested robot/task.
    env_name = 'Safexp-' + robot + task + '-v1'
    env = Engine(config=gym_cfg)

    # NOTE(review): the registered id is not used below; presumably looked up
    # via gym.make elsewhere. Confirm it is still needed.
    register(id=env_name,
             entry_point='safety_gym.envs.mujoco:Engine',
             kwargs={'config': gym_cfg})

    kwargs = dict(
        ac_kwargs=dict(hidden_sizes=(256, 256)),
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        save_freq=save_freq,
        target_kl=target_kl,
        # A cost limit (cost_lim=25) is not currently passed to SPARGym.
        seed=args.seed,
        logger_kwargs=logger_kwargs,
    )

    agent = SPARGym(env, cfg, args, **kwargs)
    agent.train()


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Parser')

    # The YAML config is mandatory: the script cannot run without it, so let
    # argparse report a missing value instead of failing later in open('').
    parser.add_argument("--yaml", type=str, required=True, help="yaml config file")
    parser.add_argument("--safety_margin", type=str, default='4.2', help="safety margin")

    parser.add_argument('--algo', type=str, default='ppo')
    parser.add_argument('--robot', type=str, default='Car')
    parser.add_argument('--task', type=str, default='Goal1')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--exp_name', type=str, default='')
    parser.add_argument('--cpu', type=int, default=1, help="Number of CPUs to use for training")

    opts = parser.parse_args()

    # Load the configuration file; the ``with`` block closes the handle.
    yaml = YAML()
    with open(opts.yaml) as cfg_file:
        params = yaml.load(cfg_file)

    # An empty --exp_name means "no explicit name"; main() then picks a default.
    exp_name = opts.exp_name or None

    # Only the ``sac_kwargs`` section is forwarded to the agent.
    params = params['sac_kwargs']
    main(exp_name, params, opts)

