import json
import sys
import os
from onpolicy.debug import debug_print
import wandb
import socket
import setproctitle
import numpy as np
from pathlib import Path
import torch
from onpolicy.config import get_config
from onpolicy.envs.robomimic.robomimic_lowdim import RobomimicLowdimWrapper
from onpolicy.envs.env_wrappers import SubprocVecEnv, DummyVecEnv
from onpolicy.algorithms.diffusion_ac.datasets import Robomimic_Dataset
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils

"""Train script for D4RL."""

# Low-dimensional observation keys for single-arm tasks: only robot0's
# end-effector pose and gripper state plus the object state (no robot1 keys).
low_dim_keys = [
    'robot0_eef_pos',
    'robot0_eef_quat',
    'robot0_gripper_qpos',
    'object',
]

# Two-arm variant: same as above with the robot1 keys included. main() assigns
# this list to all_args.low_dim_keys when the scenario is 'transport'.
low_dim_keys1 = [
    'robot0_eef_pos',
    'robot0_eef_quat',
    'robot0_gripper_qpos',
    'robot1_eef_pos',
    'robot1_eef_quat',
    'robot1_gripper_qpos',
    'object',
]

def make_train_env(all_args):
    """Build the vectorized Robomimic environment(s) used for training.

    Args:
        all_args: parsed arguments. Reads ``robomimic_env_cfg_path``,
            ``reward_shaping``, ``render``, ``render_offscreen``,
            ``use_image_obs``, ``normalization_path``, ``seed``,
            ``n_rollout_threads`` and, when present, ``low_dim_keys``.

    Returns:
        A ``DummyVecEnv`` wrapping a single env when ``n_rollout_threads`` is
        1, otherwise a ``SubprocVecEnv`` with one env per rollout thread.
    """
    # main() stores the scenario-specific key list on all_args (the two-arm
    # list for 'transport'). Honour it so the env observations match, and fall
    # back to the module-level single-arm list when it is absent or empty.
    obs_keys = getattr(all_args, "low_dim_keys", None) or low_dim_keys

    def get_env_fn(rank):
        def init_env():
            # Register the low-dim observation keys with robomimic before the
            # environment is created (no rgb modality is used here).
            ObsUtils.initialize_obs_modality_mapping_from_dict({"low_dim": obs_keys})
            if all_args.render_offscreen or all_args.use_image_obs:
                os.environ["MUJOCO_GL"] = "egl"
            with open(all_args.robomimic_env_cfg_path, "r") as f:
                env_meta = json.load(f)
            env_meta["reward_shaping"] = all_args.reward_shaping
            env = EnvUtils.create_env_from_metadata(
                env_meta=env_meta,
                render=all_args.render,
                # only way to not show collision geometry is to enable render_offscreen, which uses a lot of RAM.
                render_offscreen=all_args.render_offscreen,
                use_image_obs=all_args.use_image_obs,
            )
            # Robosuite's hard reset causes excessive memory consumption.
            # Disabled to run more envs.
            # https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
            env.env.hard_reset = False
            env = RobomimicLowdimWrapper(
                env,
                all_args,
                low_dim_keys=obs_keys,
                normalization_path=all_args.normalization_path,
            )
            # Give every worker its own seed, offset by its rank.
            env.seed(all_args.seed + rank * 1000)
            return env
        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])


def make_eval_env(all_args):
    """Build the vectorized Robomimic environment(s) used for evaluation.

    Args:
        all_args: parsed arguments. Reads ``robomimic_env_cfg_path``,
            ``reward_shaping``, ``render``, ``render_offscreen``,
            ``use_image_obs``, ``normalization_path``, ``seed``,
            ``n_eval_rollout_threads`` and, when present, ``low_dim_keys``.

    Returns:
        A ``DummyVecEnv`` wrapping a single env when
        ``n_eval_rollout_threads`` is 1, otherwise a ``SubprocVecEnv`` with
        one env per evaluation thread.
    """
    # main() stores the scenario-specific key list on all_args (the two-arm
    # list for 'transport'). Honour it so the env observations match, and fall
    # back to the module-level single-arm list when it is absent or empty.
    obs_keys = getattr(all_args, "low_dim_keys", None) or low_dim_keys

    def get_env_fn(rank):
        def init_env():
            # Register the low-dim observation keys with robomimic before the
            # environment is created (no rgb modality is used here).
            ObsUtils.initialize_obs_modality_mapping_from_dict({"low_dim": obs_keys})
            if all_args.render_offscreen or all_args.use_image_obs:
                os.environ["MUJOCO_GL"] = "egl"
            with open(all_args.robomimic_env_cfg_path, "r") as f:
                env_meta = json.load(f)
            env_meta["reward_shaping"] = all_args.reward_shaping
            env = EnvUtils.create_env_from_metadata(
                env_meta=env_meta,
                render=all_args.render,
                # only way to not show collision geometry is to enable render_offscreen, which uses a lot of RAM.
                render_offscreen=all_args.render_offscreen,
                use_image_obs=all_args.use_image_obs,
            )
            # Robosuite's hard reset causes excessive memory consumption.
            # Disabled to run more envs.
            # https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
            env.env.hard_reset = False
            env = RobomimicLowdimWrapper(
                env,
                all_args,
                low_dim_keys=obs_keys,
                normalization_path=all_args.normalization_path,
            )
            # Give every worker its own seed, offset by its rank.
            env.seed(all_args.seed + rank * 1000)
            return env
        return init_env

    if all_args.n_eval_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)])


def _str2bool(value):
    """Convert a command-line token such as 'true', 'False', '1' or '0' to a bool.

    Raises:
        ValueError: if the token is not a recognised boolean spelling; argparse
            turns this into a normal "invalid value" usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, matching what bool('') used to produce.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise ValueError("expected a boolean value, got %r" % (value,))


def parse_args(args, parser):
    """Register the script-specific options on ``parser`` and parse ``args``.

    Args:
        args: list of command-line tokens to parse.
        parser: an ``argparse.ArgumentParser`` (normally the one returned by
            ``get_config()``) that the options below are added to.

    Returns:
        The parsed namespace. Unknown arguments are ignored because
        ``parse_known_args`` is used.
    """
    parser.add_argument('--scenario_name', type=str,
                        default='maze2d-umaze-v1', help="Which scenario to run on")
    parser.add_argument('--dataset_path', type=str, 
                        default="", help="path to the dataset")

    # parser.add_argument('--norm_obs', default=False, action='store_true')
    # parser.add_argument('--norm_reward', default=False, action='store_true')
    # parser.add_argument('--clip_reward', default=100000., type=float)
    # parser.add_argument('--clip_obs', default=100000., type=float)

    parser.add_argument('--n_timesteps', type=int, default=5)
    parser.add_argument('--beta_schedule', type=str, default='vp')
    parser.add_argument('--predict_epsilon', action='store_true', default=False)
    parser.add_argument('--t_dim', type=int, default=64)
    parser.add_argument('--value_dim', type=int, default=1)
    parser.add_argument('--num_agents', type=int, default=1)
    parser.add_argument('--rnum_agents', type=int, default=1)
    parser.add_argument('--use_latent_actions', action='store_true', default=False)
    parser.add_argument('--bc_loss_coef', type=float, default=0.)
    parser.add_argument('--ppo_loss_coef', type=float, default=1.)
    parser.add_argument('--act_with_small_std', action='store_true', default=False)
    parser.add_argument('--aug_latent_actions', action='store_true', default=False)
    # parser.add_argument('--recompute_adv', action='store_true', default=False)
    parser.add_argument('--sep_bc_phase', action='store_true', default=False, help="use a seperate behavior cloning phase")
    parser.add_argument('--bc_epoch', type=int, default=0, help='number of behavior cloning epochs if using seperate BC phase')

    parser.add_argument('--warmup_steps', default=0, type=int)
    parser.add_argument('--group_name',default="debug",type=str)
    parser.add_argument('--bc_mode_control',default=1, type=int)
    
    parser.add_argument('--robomimic_env_cfg_path', type=str, default="", help="path to the robomimic environment config")
    parser.add_argument('--render_offscreen', action='store_true', default=False)
    parser.add_argument('--use_image_obs', action='store_true', default=False)
    parser.add_argument('--reward_shaping', action='store_true', default=False)
    # type=bool would turn every non-empty string (including "False") into
    # True, so parse the token explicitly. A bare "--render" means True.
    parser.add_argument('--render', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--max_episode_length', type=int, default=512)
    
    all_args = parser.parse_known_args(args)[0]

    return all_args

def make_dummy_env(scenario_name):
    """Create and return a Gym environment for ``scenario_name`` via ``gym.make``.

    ``gym`` is imported inside the function, so it is only needed when this
    helper is actually called.
    """
    import gym
    dummy_env = gym.make(scenario_name)
    return dummy_env

def main(args):
    """Run a complete training job from command-line arguments.

    Parses the arguments, selects the torch device, prepares the results
    directory, optionally loads the train/val datasets, seeds torch and numpy,
    builds the training (and optional evaluation) environments, and hands
    everything to the runner.

    Args:
        args: list of command-line tokens (e.g. ``sys.argv[1:]``).

    Raises:
        RuntimeError: if CUDA is requested via ``all_args.cuda`` but is not
            available on this machine.
    """
    parser = get_config()
    all_args = parse_args(args, parser)

    # assert (all_args.algorithm_name in ["diff-mappo"]), ("train_d4rl_diff.py only supports algorithm_name in ['diff-mappo']")
    all_args.use_naive_recurrent_policy = False
    if all_args.scenario_name == 'transport':
        all_args.low_dim_keys = low_dim_keys1

    # cuda
    # Fail loudly when a GPU was requested but is missing (an explicit raise is
    # not stripped under `python -O`, unlike assert), while still allowing an
    # explicitly requested CPU run.
    if all_args.cuda:
        if not torch.cuda.is_available():
            raise RuntimeError("Cuda is not available. Please check your CUDA installation.")
        print("choose to use gpu...")
        device = torch.device("cuda:0")
        if all_args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        print("choose to use cpu...")
        device = torch.device("cpu")
    torch.set_num_threads(all_args.n_training_threads)

    # run dir: <parent of this file's directory>/results/<env>/<scenario>/<algo>/<experiment>
    run_dir = Path(os.path.split(os.path.dirname(os.path.abspath(__file__)))[
                   0] + "/results") / all_args.env_name / all_args.scenario_name / all_args.algorithm_name / all_args.experiment_name
    run_dir.mkdir(parents=True, exist_ok=True)

    setproctitle.setproctitle(str(all_args.algorithm_name) + "-" + \
        str(all_args.scenario_name) + "-" + str(all_args.experiment_name) + "@" + str(all_args.user_name))

    # The datasets are only built when cloning episodes are requested.
    if all_args.clone_episodes > 0:
        train_dataset = Robomimic_Dataset(all_args.scenario_name, 1, all_args.dataset_path, all_args.act_step, split="train", normalization_path=all_args.normalization_path)
        val_dataset = Robomimic_Dataset(all_args.scenario_name, 1, all_args.dataset_path, all_args.act_step, split="val", normalization_path=all_args.normalization_path)
    else:
        train_dataset = None
        val_dataset = None

    # fix gamma and gae_lambda
    # all_args.gamma = all_args.gamma ** (1 / all_args.n_timesteps)
    # all_args.gae_lambda = all_args.gae_lambda #** (1 / all_args.n_timesteps)

    # seed
    torch.manual_seed(all_args.seed)
    torch.cuda.manual_seed_all(all_args.seed)
    np.random.seed(all_args.seed)

    # env init
    envs = make_train_env(all_args)
    eval_envs = make_eval_env(all_args) if all_args.use_eval else None

    # wandb
    if all_args.use_wandb:
        run = wandb.init(config=all_args,
                         project=all_args.env_name,
                         entity=all_args.user_name,
                         notes=socket.gethostname(),
                         name=str(all_args.algorithm_name) + "_" +
                         str(all_args.experiment_name) +
                         "_seed" + str(all_args.seed),
                         group=all_args.scenario_name,
                         dir=str(run_dir),
                         job_type="training",
                         reinit=True)
    else:
        # Pick the next free "runN" sub-directory. Only names of the exact
        # form run<digits> are counted, so unrelated folders that merely start
        # with "run" no longer break the int() conversion.
        exst_run_nums = [
            int(folder.name[3:])
            for folder in run_dir.iterdir()
            if folder.name.startswith('run') and folder.name[3:].isdecimal()
        ]
        curr_run = 'run%i' % (max(exst_run_nums, default=0) + 1)
        run_dir = run_dir / curr_run
        run_dir.mkdir(parents=True, exist_ok=True)

    config = {
        "all_args": all_args,
        "envs": envs,
        "eval_envs": eval_envs,
        "num_agents": 1,
        "device": device,
        "run_dir": run_dir,
        "train_dataset": train_dataset, 
        "val_dataset": val_dataset, 
    }

    # run experiments
    if all_args.share_policy:
        from onpolicy.runner.shared.robomimic_runner import Robomimic_DiffRunner as Runner
    else:
        from onpolicy.runner.separated.robomimic_runner import Robomimic_DiffRunner as Runner

    print(device)

    try:
        runner = Runner(config)
        runner.run()
    finally:
        # post process: always close the environments, even if the runner
        # raised, so vectorized env workers are not left behind.
        envs.close()
        if all_args.use_eval and eval_envs is not envs:
            eval_envs.close()

    if all_args.use_wandb:
        run.finish()
    else:
        # str() + os.path.join works whether log_dir is a str or a Path.
        runner.writter.export_scalars_to_json(os.path.join(str(runner.log_dir), 'summary.json'))
        runner.writter.close()


if __name__ == "__main__":
    main(sys.argv[1:])