import argparse
import os
import sys

import torch

from gym.envs.registration import register

from divmorph.algos.ppo.ppo import PPO
from divmorph.algos.ppo.svd_ppo import PPO_SVD
from divmorph.config import cfg
from divmorph.config import dump_cfg
from divmorph.utils import file as fu
from divmorph.utils import sample as su
from divmorph.utils import sweep as swu


def set_cfg_options():
    """Fill in the config values that are derived from user-supplied options.

    The steps run in a fixed order: ``maybe_infer_walkers`` may populate
    ``cfg.ENV.WALKERS``, which ``calculate_max_limbs_joints`` then iterates,
    so the walker inference must happen first.
    """
    derivation_steps = (
        calculate_max_iters,
        maybe_infer_walkers,
        calculate_max_limbs_joints,
    )
    for step in derivation_steps:
        step()


def calculate_max_limbs_joints():
    """Set ``cfg.MODEL.MAX_JOINTS`` and ``cfg.MODEL.MAX_LIMBS`` from walker metadata.

    For each walker in ``cfg.ENV.WALKERS`` the JSON file
    ``<cfg.ENV.WALKER_DIR>/metadata/<walker>.json`` is loaded and its
    ``"dof"`` and ``"num_limbs"`` entries are collected. The maxima over all
    walkers (with the padding applied below) are written back into the config
    and printed.

    Raises:
        NotImplementedError: if ``cfg.ENV_NAME`` is not ``"Unimal-v0"``.
        ValueError: if ``cfg.ENV.WALKERS`` is empty (no metadata to take a
            maximum over).
    """
    if cfg.ENV_NAME != "Unimal-v0":
        raise NotImplementedError(
            f"calculate_max_limbs_joints does not support ENV_NAME={cfg.ENV_NAME!r}"
        )
    if not cfg.ENV.WALKERS:
        raise ValueError(
            "cfg.ENV.WALKERS is empty; cannot compute MAX_JOINTS / MAX_LIMBS"
        )

    num_joints, num_limbs = [], []
    for agent in cfg.ENV.WALKERS:
        metadata_path = os.path.join(
            cfg.ENV.WALKER_DIR, "metadata", "{}.json".format(agent)
        )
        metadata = fu.load_json(metadata_path)
        num_joints.append(metadata["dof"])
        # NOTE(review): limbs get +1 here and the maximum gets another +1
        # below; presumably padding / an extra slot — confirm this is intended.
        num_limbs.append(metadata["num_limbs"] + 1)

    cfg.MODEL.MAX_JOINTS = max(num_joints) + 1
    cfg.MODEL.MAX_LIMBS = max(num_limbs) + 1

    print(f"MAX_JOINTS: {cfg.MODEL.MAX_JOINTS}, MAX_LIMB: {cfg.MODEL.MAX_LIMBS}")


def calculate_max_iters():
    """Convert the PPO state-action-pair budgets into iteration counts.

    Writes ``cfg.PPO.MAX_ITERS`` and ``cfg.PPO.EARLY_EXIT_MAX_ITERS``. Each
    budget is converted with ``int`` and then floor-divided by
    ``cfg.PPO.TIMESTEPS`` and then by ``cfg.PPO.NUM_ENVS``.
    """
    ppo_cfg = cfg.PPO

    def pairs_to_iters(num_pairs):
        # Two successive floor divisions, kept in the original order
        # (presumably one iteration = TIMESTEPS steps x NUM_ENVS envs).
        return int(num_pairs) // ppo_cfg.TIMESTEPS // ppo_cfg.NUM_ENVS

    ppo_cfg.MAX_ITERS = pairs_to_iters(ppo_cfg.MAX_STATE_ACTION_PAIRS)
    ppo_cfg.EARLY_EXIT_MAX_ITERS = pairs_to_iters(
        ppo_cfg.EARLY_EXIT_STATE_ACTION_PAIRS
    )


def maybe_infer_walkers():
    """Populate ``cfg.ENV.WALKERS`` from the walker XML directory if it is empty.

    Only applies to ``"Unimal-v0"`` and ``"Modular-v0"``. Walkers already
    listed in the config are left untouched. Otherwise every ``*.xml`` file in
    ``<cfg.ENV.WALKER_DIR>/xml`` contributes its base name (extension
    stripped).

    The result is sorted because ``os.listdir`` returns entries in arbitrary,
    filesystem-dependent order. Without sorting, the walker list (and anything
    indexed by it) could differ between runs or machines.
    """
    if cfg.ENV_NAME not in ("Unimal-v0", "Modular-v0"):
        return

    if cfg.ENV.WALKERS:
        return

    xml_dir = os.path.join(cfg.ENV.WALKER_DIR, "xml")
    # Only real .xml files: stray entries such as ".DS_Store" must not become
    # walkers. splitext keeps dotted names intact ("a.b.xml" -> "a.b"),
    # matching the "<agent>.xml" / "<agent>.json" paths built elsewhere.
    cfg.ENV.WALKERS = sorted(
        os.path.splitext(xml_file)[0]
        for xml_file in os.listdir(xml_dir)
        if xml_file.endswith(".xml")
    )


def register_modular_envs():
    """Register one gym environment ``"<agent>-v0"`` per walker in ``cfg.ENV.WALKERS``.

    Each environment uses the entry point ``modular.<agent>:make_env``, is
    given the absolute path of ``<cfg.ENV.WALKER_DIR>/xml/<agent>.xml`` as its
    ``xml`` kwarg, and is capped at 1000 steps per episode.

    Registration is best effort: a walker whose registration raises is
    skipped. Presumably the failure is an id that is already registered —
    confirm against the gym version in use.
    """
    xml_dir = os.path.join(cfg.ENV.WALKER_DIR, 'xml')
    for agent in cfg.ENV.WALKERS:
        params = {"xml": os.path.abspath(os.path.join(xml_dir, agent + '.xml'))}
        try:
            register(
                id=f"{agent}-v0",
                max_episode_steps=1000,
                entry_point=f"modular.{agent}:make_env",
                kwargs=params,
            )
        except Exception:
            # Deliberate best-effort skip. Unlike the previous bare `except:`,
            # this no longer swallows KeyboardInterrupt / SystemExit.
            continue
        

def parse_args():
    """Parse the training script's command-line arguments.

    Returns:
        argparse.Namespace with ``cfg_file`` (path given via ``--cfg``),
        ``no_context_in_state`` (bool flag), and ``opts`` (all remaining
        arguments, which ``main`` forwards to ``cfg.merge_from_list``).

    Exits with status 1 after printing help when invoked with no arguments.
    """
    parser = argparse.ArgumentParser(description="Train a RL agent")
    parser.add_argument("--cfg", dest="cfg_file", help="Config file", required=True, type=str)
    parser.add_argument("--no_context_in_state", action="store_true")
    # REMAINDER: everything after the recognised flags is collected verbatim.
    parser.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="See morphology/core/config.py for all options",
    )

    invoked_without_args = len(sys.argv) == 1
    if invoked_without_args:
        parser.print_help()
        sys.exit(1)

    return parser.parse_args()


def ppo_train():
    """Seed, configure torch, then train a PPO agent and save the final model.

    Uses ``PPO_SVD`` when ``cfg.MODEL.TRANSFORMER.SVD`` is truthy, otherwise
    plain ``PPO``. The cuDNN flags are only applied when CUDA is available,
    and torch is limited to a single intra-op thread.
    """
    su.set_seed(cfg.RNG_SEED)
    if torch.cuda.is_available():
        cudnn = torch.backends.cudnn
        cudnn.benchmark = cfg.CUDNN.BENCHMARK
        cudnn.deterministic = cfg.CUDNN.DETERMINISTIC

    torch.set_num_threads(1)

    use_svd = cfg.MODEL.TRANSFORMER.SVD
    print('Use PPO_SVD' if use_svd else 'Use PPO')
    trainer = PPO_SVD() if use_svd else PPO()

    trainer.train()
    # NOTE(review): -1 is passed to save_model after training; presumably it
    # tags the final checkpoint — confirm against PPO.save_model.
    trainer.save_model(-1)


def main():
    """Build the global config from the command line and launch PPO training.

    Steps, in order:
    1. Merge the ``--cfg`` file, then the positional KEY VALUE overrides.
    2. With ``--no_context_in_state``, override
       ``MODEL.PROPRIOCEPTIVE_OBS_TYPES`` with a fixed list.
    3. Derive computed options, create ``cfg.OUT_DIR``, dump the config, and
       train.
    """
    args = parse_args()

    cfg.merge_from_file(args.cfg_file)
    cfg.merge_from_list(args.opts)

    if args.no_context_in_state:
        # Fixed observation-type list used when the flag is set (presumably
        # the proprioceptive types without any context features).
        proprio_types = [
            "body_xpos",
            "body_xvelp",
            "body_xvelr",
            "body_xquat",
            "qpos",
            "qvel",
        ]
        cfg.merge_from_list(["MODEL.PROPRIOCEPTIVE_OBS_TYPES", proprio_types])

    set_cfg_options()
    os.makedirs(cfg.OUT_DIR, exist_ok=True)

    dump_cfg()
    ppo_train()


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
