from metadrive.envs.safe_metadrive_env import SafemetadriveEnv

from safe_rl.pg.algos import cpo
from safe_rl.utils.mpi_tools import mpi_fork, terminate
from safe_rl.utils.run_utils import setup_logger_kwargs


class WrapperEnv(SafemetadriveEnv):
    """Environment adapter that unwraps the action before stepping.

    ``step`` receives an indexable ``actions`` object and forwards only its
    first element to the parent environment.
    NOTE(review): presumably the CPO runner hands over a batch of size one;
    confirm against the caller in ``safe_rl``.
    """

    def step(self, actions):
        """Step the parent environment with ``actions[0]``.

        Returns the four values produced by the parent ``step`` unchanged, as
        a 4-tuple.
        """
        first_action = actions[0]
        obs, reward, done, info = super().step(first_action)
        return obs, reward, done, info


if __name__ == '__main__':
    # Script entry point: runs CPO on the wrapped MetaDrive environment.
    #   --seed / -s      (required) seed, also used in the experiment name
    #   --num_cpus / -l  (default 0) values > 1 trigger mpi_fork
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_cpus', '-l', type=int, default=0)
    parser.add_argument('--seed', '-s', type=int, required=True)
    args = parser.parse_args()

    exp_name = "CPO_metadrive_{}".format(args.seed)
    config = {
        "num_cpus": args.num_cpus,
        "exp_name": exp_name,
        "env_config": dict(),
        "saferl_config": {
            "use_ipd": False,
            "use_ctnb": False,
            "use_ipd_soft": False,
            "cost_threshold": 2,
            "max_ep_len": 1000,
        },

        # local runner config
        "seed": args.seed,
        "tmp_file": "deleteme.json",
        "num_steps": 200_0000,
    }

    # Hyperparameters
    num_steps = config["num_steps"]
    steps_per_epoch = 4000
    # Must be a plain bool. It was previously written with a trailing comma,
    # which made it the tuple ``(False,)`` -- a non-empty tuple is truthy, so
    # the flag handed to mpi_fork was effectively "on".
    bind_to_core = False
    data_dir = "./"
    # Integer division: the number of whole epochs that fit in num_steps.
    epochs = num_steps // steps_per_epoch

    # Save on every epoch so that more checkpoints are kept.
    save_freq = 1
    target_kl = 0.01

    # Fork for parallelizing
    if config["num_cpus"] > 1:
        mpi_fork(config["num_cpus"], bind_to_core)

    # Prepare Logger
    logger_kwargs = setup_logger_kwargs(
        exp_name, config["seed"], data_dir=data_dir
    )

    # Factory handed to cpo; each call builds a fresh wrapped environment.
    env_fn = lambda: WrapperEnv(config["env_config"])

    cpo(
        env_fn=env_fn,
        ac_kwargs=dict(hidden_sizes=(256, 256)),
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        save_freq=save_freq,
        target_kl=target_kl,
        seed=config["seed"],
        logger_kwargs=logger_kwargs,
        tmp_file=config["tmp_file"],
        saferl_config=config["saferl_config"],
    )
    print(
        "We have successfully finished the algorithm: {}, env {}, exp {}."
        "Now terminate.".format("CPo", "metadrive", exp_name)
    )
    terminate()
    # NOTE(review): terminate() is expected to end the process, so the line
    # below should be unreachable -- confirm against safe_rl.utils.mpi_tools.
    print("This part should never run.")
