import os
import argparse
import time
from shutil import copyfile
import matplotlib
matplotlib.use('Agg')

# AGENT
from agent.agent_sac import AgentSAC
from agent.agent_sac_reptile import AgentSACReptile

# ENV
from panda_gym.vec_env.vec_env import VecEnvPush, VecEnvPushTool, VecEnvLift, VecEnvHammer, VecEnvSweep
from util.vec_env import make_vec_envs
from util.yaml import load_config


def _lookup_or_raise(name, registry, what):
    """Return ``registry[name]``, raising NotImplementedError that lists valid keys if absent."""
    try:
        return registry[name]
    except KeyError:
        raise NotImplementedError(
            "Unsupported {} '{}'. Supported values: {}".format(
                what, name, ', '.join(sorted(registry)))) from None


def main(config_file, config_dict):
    """Build the vectorized environment and the agent, then train or evaluate.

    Args:
        config_file: Path to the config file; it is copied into
            ``CONFIG_TRAINING.OUT_FOLDER`` as ``config.yaml``.
        config_dict: Loaded config with ``'environment'``, ``'training'`` and
            ``'arch'`` sections whose fields are read/written as attributes
            (presumably the objects returned by ``util.yaml.load_config`` —
            verify there). ``CONFIG_ARCH`` and ``CONFIG_ENV`` are mutated in
            place (image size, action/language dims).

    Raises:
        NotImplementedError: If ``ENV_NAME`` or ``AGENT_NAME`` is not supported.
    """
    # Config
    CONFIG_ENV = config_dict['environment']
    CONFIG_TRAINING = config_dict['training']
    CONFIG_ARCH = config_dict['arch']
    os.makedirs(CONFIG_TRAINING.OUT_FOLDER, exist_ok=True)
    # Keep a copy of the config next to the outputs. When the given file *is*
    # that copy (e.g. re-running from a saved run's config), copying would
    # raise shutil.SameFileError, so skip it.
    saved_config = os.path.join(CONFIG_TRAINING.OUT_FOLDER, 'config.yaml')
    if not (os.path.exists(saved_config)
            and os.path.samefile(config_file, saved_config)):
        copyfile(config_file, saved_config)

    # Reuse some config: the architecture takes its input/output sizes from
    # the environment section.
    CONFIG_ARCH.IMG_H = CONFIG_ENV.CAMERA['img_h']
    CONFIG_ARCH.IMG_W = CONFIG_ENV.CAMERA['img_w']
    CONFIG_ARCH.ACTION_DIM = CONFIG_ENV.ACTION_DIM
    if not CONFIG_ENV.USE_LANG:
        # Language input disabled: force its dimension to zero so the
        # environment and architecture sections agree.
        CONFIG_ENV.LANG_DIM = 0
    CONFIG_ARCH.LANG_DIM = CONFIG_ENV.LANG_DIM

    # Environment
    print("\n== Environment Information ==")
    vec_env_type = _lookup_or_raise(CONFIG_ENV.ENV_NAME, {
        'Push-v0': VecEnvPush,
        'PushTool-v0': VecEnvPushTool,
        'Lift-v0': VecEnvLift,
        'Hammer-v0': VecEnvHammer,
        'Sweep-v0': VecEnvSweep,
    }, 'ENV_NAME')
    venv = make_vec_envs(
        env_name=CONFIG_ENV.ENV_NAME,
        seed=CONFIG_TRAINING.SEED,
        num_processes=CONFIG_TRAINING.NUM_CPUS,
        cpu_offset=CONFIG_TRAINING.CPU_OFFSET,
        device=CONFIG_TRAINING.DEVICE,
        config_env=CONFIG_ENV,
        vec_env_type=vec_env_type,
        renders=CONFIG_ENV.RENDER,
        camera_params=CONFIG_ENV.CAMERA,
    )

    # Agent
    print("\n== Agent Information ==")
    agent_class = _lookup_or_raise(CONFIG_TRAINING.AGENT_NAME, {
        'AgentSAC': AgentSAC,
        'AgentSACReptile': AgentSACReptile,
    }, 'AGENT_NAME')
    agent = agent_class(venv, CONFIG_TRAINING, CONFIG_ARCH, CONFIG_ENV)
    print('\nTotal parameters in policy: {}'.format(
        sum(p.numel() for p in agent.learner.parameters()
            if p.requires_grad)))
    print("We want to use: {}, and Agent uses: {}".format(
        CONFIG_TRAINING.DEVICE, agent.learner.device))

    # Learn. perf_counter is monotonic, so the reported duration is not
    # skewed by wall-clock adjustments during a long run.
    start_time = time.perf_counter()
    if CONFIG_TRAINING.EVAL:
        print("\n== Evaluating ==")
        agent.evaluate()
    else:
        print("\n== Learning ==")
        agent.learn()
    print('\nTime used: {:.1f}'.format(time.perf_counter() - start_time))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required: both load_config and main need a real path; without it the
    # script would fail later with an unhelpful TypeError on None.
    parser.add_argument("-cf",
                        "--config_file",
                        help="config file path",
                        type=str,
                        required=True)
    args = parser.parse_args()
    config_dict = load_config(args.config_file)
    main(args.config_file, config_dict)
