#!/usr/bin/env python
import sys
import os
import yaml
import wandb
import socket
import setproctitle
import numpy as np
from argparse import Namespace
from pathlib import Path

import torch

from hsp.config import get_config

from hsp.envs.overcooked.Overcooked_Env import Overcooked, OvercookedEnv
from hsp.envs.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv, ChooseSubprocVecEnv, ChooseDummyVecEnv

def make_train_env(all_args, run_dir):
    """Build the vectorized training environments.

    One env factory is created per rollout thread; the env for worker ``rank``
    is seeded with ``all_args.seed + rank * 1000`` so workers get distinct seeds.

    Args:
        all_args: parsed CLI namespace; reads ``env_name``, ``seed`` and
            ``n_rollout_threads``.
        run_dir: run directory forwarded to the ``Overcooked`` constructor.

    Returns:
        ``ShareDummyVecEnv`` when ``n_rollout_threads == 1``, otherwise
        ``ShareSubprocVecEnv`` wrapping one factory per thread.

    Raises:
        NotImplementedError: when an env factory is invoked and ``env_name``
            is not ``"Overcooked"``.
    """
    def get_env_fn(rank):
        # Outer function binds ``rank`` now; the env itself is built only when
        # the returned factory is called by the vec-env wrapper.
        def init_env():
            if all_args.env_name != "Overcooked":
                msg = "Can not support the " + all_args.env_name + " environment."
                print(msg)
                raise NotImplementedError(msg)
            env = Overcooked(all_args, run_dir)
            env.seed(all_args.seed + rank * 1000)
            return env
        return init_env

    if all_args.n_rollout_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])


def make_eval_env(all_args, run_dir):
    """Build the vectorized evaluation environments.

    Mirrors ``make_train_env``: a single eval thread uses the in-process
    ``ChooseDummyVecEnv`` and several threads use ``ChooseSubprocVecEnv``.
    The original had these two choices inverted, so many eval envs were
    stepped serially while a single env paid for a worker process. The env
    for worker ``rank`` is seeded with ``all_args.seed * 50000 + rank * 10000``.
    That is a different formula from the training envs, presumably to keep
    eval seeds apart from training seeds.

    Args:
        all_args: parsed CLI namespace; reads ``env_name``, ``seed`` and
            ``n_eval_rollout_threads``.
        run_dir: run directory forwarded to the ``Overcooked`` constructor.

    Returns:
        ``ChooseDummyVecEnv`` when ``n_eval_rollout_threads == 1``, otherwise
        ``ChooseSubprocVecEnv`` wrapping one factory per thread.

    Raises:
        NotImplementedError: when an env factory is invoked and ``env_name``
            is not ``"Overcooked"``.
    """
    def get_env_fn(rank):
        # Outer function binds ``rank`` now; the env itself is built only when
        # the returned factory is called by the vec-env wrapper.
        def init_env():
            if all_args.env_name != "Overcooked":
                msg = "Can not support the " + all_args.env_name + " environment."
                print(msg)
                raise NotImplementedError(msg)
            env = Overcooked(all_args, run_dir)
            env.seed(all_args.seed * 50000 + rank * 10000)
            return env
        return init_env

    if all_args.n_eval_rollout_threads == 1:
        return ChooseDummyVecEnv([get_env_fn(0)])
    return ChooseSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)])

def parse_args(args, parser):
    """Register the Overcooked / traj-specific options on ``parser`` and parse ``args``.

    Parsing uses ``parse_known_args``, so tokens this parser does not know
    are ignored rather than raising an error.

    Args:
        args: list of CLI tokens (typically ``sys.argv[1:]``).
        parser: an ``argparse.ArgumentParser`` (``main`` passes the one from
            ``get_config()``).

    Returns:
        The parsed ``argparse.Namespace`` (unknown tokens are discarded).
    """
    # environment / reward shaping
    parser.add_argument("--layout_name", type=str, default='cramped_room', help="Name of Submap, 40+ in choice. See /src/data/layouts/.")
    parser.add_argument('--num_agents', type=int, default=1, help="number of players")
    parser.add_argument("--initial_reward_shaping_factor", type=float, default=1.0, help="Shaping factor of potential dense reward.")
    parser.add_argument("--reward_shaping_factor", type=float, default=1.0, help="Shaping factor of potential dense reward.")
    # argparse only applies ``type`` to *string* defaults, so a default of
    # ``2.5e6`` would leave the float 2500000.0 in the namespace. Use an int.
    # NOTE(review): help text is copied from --reward_shaping_factor; presumably
    # this is the horizon of the shaping schedule — confirm and reword.
    parser.add_argument("--reward_shaping_horizon", type=int, default=int(2.5e6), help="Shaping factor of potential dense reward.")
    parser.add_argument("--use_phi", default=False, action='store_true', help="While existing other agent like planning or human model, use an index to fix the main RL-policy agent.")
    parser.add_argument("--use_hsp", default=False, action='store_true')
    parser.add_argument("--random_index", default=False, action='store_true')
    parser.add_argument("--use_agent_policy_id", default=False, action='store_true', help="Add policy id into share obs, default False")
    parser.add_argument("--w0", type=str, default='1,1,1,1', help="Weight vector of dense reward 0 in overcooked env.")
    parser.add_argument("--w1", type=str, default='1,1,1,1', help="Weight vector of dense reward 1 in overcooked env.")

    # population definition file
    parser.add_argument("--population_yaml_path", type=str, help="Path to yaml file that stores the population info.")

    # traj
    parser.add_argument("--traj_entropy_alpha", type=float, default=0.1, help="Weight for population entropy reward.")
    # NOTE(review): help text duplicates --traj_entropy_alpha; confirm what traj_gamma weights.
    parser.add_argument("--traj_gamma", type=float, default=0.5, help="Weight for population entropy reward.")
    parser.add_argument("--traj_stage", type=int, default=1, help="Stages of Traj training. 1 for Maximum-Entropy PBT. 2 for FCP-like training.")
    parser.add_argument("--traj_use_prioritized_sampling", default=False, action='store_true', help="Use prioritized sampling in Traj stage 2.")
    parser.add_argument("--traj_prioritized_alpha", type=float, default=3.0, help="Alpha used in softing prioritized sampling probability.")

    # population size and the name of the agent being trained
    parser.add_argument("--population_size", type=int, default=5, help="Population size involved in training.")
    parser.add_argument("--traj_agent_name", type=str, required=True, help="Name of final policy.")

    # train and eval batching
    parser.add_argument("--train_env_batch", type=int, default=1, help="Number of parallel threads a policy holds")
    parser.add_argument("--eval_env_batch", type=int, default=1, help="Number of parallel threads a policy holds")

    # fixed policy actions inside env threads
    parser.add_argument("--use_policy_in_env", default=False, action="store_true", help="Use loaded policy to move in env threads.")
    parser.add_argument("--predict_other_shaped_info", default=False, action='store_true', help="Predict other agent's shaped info within a short horizon, default False")
    parser.add_argument("--predict_shaped_info_horizon", default=50, type=int, help="Horizon for shaped info target, default 50")
    parser.add_argument("--predict_shaped_info_event_count", default=10, type=int, help="Event count for shaped info target, default 10")

    all_args = parser.parse_known_args(args)[0]

    return all_args

def _select_device(all_args):
    """Return the torch device requested by the CLI flags and configure torch threading / cuDNN."""
    if all_args.cuda and torch.cuda.is_available():
        print("choose to use gpu...")
        device = torch.device("cuda:0")
        if all_args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        print("choose to use cpu...")
        device = torch.device("cpu")
    # Both device branches use the same intra-op thread count.
    torch.set_num_threads(all_args.n_training_threads)
    return device


def _next_run_dir(run_dir):
    """Return ``run_dir / "runN"`` where N is one past the largest existing ``run<digits>`` entry."""
    run_nums = []
    for entry in run_dir.iterdir():
        suffix = entry.name[len('run'):]
        # Skip entries like "runs" or "run_old" that would make int() raise.
        if entry.name.startswith('run') and suffix.isdecimal():
            run_nums.append(int(suffix))
    return run_dir / ('run%i' % (max(run_nums, default=0) + 1))


def _build_override_policy_config(all_args, runner, population_config):
    """Build the per-policy config overrides handed to ``runner.policy.load_population``."""
    agent_name = all_args.traj_agent_name
    override_policy_config = {
        agent_name: (Namespace(use_agent_policy_id=all_args.use_agent_policy_id), *runner.policy_config[1:]),
    }
    for policy_name in population_config:
        if policy_name != agent_name:
            override_policy_config[policy_name] = (None, None, runner.policy_config[2], None)  # only override share_obs_space
    return override_policy_config


def main(args):
    """Train the traj agent against a population described by a YAML file.

    Steps:
      1. Parse the CLI arguments and pick the device.
      2. Create the run directory. Without wandb, a fresh ``runN``
         sub-directory is created as well.
      3. Seed torch/numpy and build the train/eval envs.
      4. Load the population into the runner and call ``runner.train_traj()``.
      5. Close the envs, then finish wandb or export the tensorboard scalars.

    Args:
        args: list of CLI tokens (typically ``sys.argv[1:]``).

    Raises:
        ValueError: if ``algorithm_name`` is not ``"traj"`` or
            ``--population_yaml_path`` was not supplied.
    """
    parser = get_config()
    all_args = parse_args(args, parser)

    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if all_args.algorithm_name != "traj":
        raise ValueError("algorithm_name must be 'traj', got %r" % (all_args.algorithm_name,))
    # Fail before spawning env workers; a missing path previously only crashed
    # at ``open(None)`` after the envs and the runner had been built.
    if all_args.population_yaml_path is None:
        raise ValueError("--population_yaml_path is required for traj training.")

    device = _select_device(all_args)

    # run dir: <parent of this script's directory>/results/<env>/<layout>/<algo>/<experiment>
    run_dir = (Path(os.path.abspath(__file__)).parent.parent / "results"
               / all_args.env_name / all_args.layout_name
               / all_args.algorithm_name / all_args.experiment_name)
    run_dir.mkdir(parents=True, exist_ok=True)

    # wandb
    if all_args.use_wandb:
        run = wandb.init(config=all_args,
                         project=all_args.env_name,
                         entity=all_args.wandb_name,
                         notes=socket.gethostname(),
                         name=str(all_args.algorithm_name) + "_" +
                         str(all_args.experiment_name) +
                         "_seed" + str(all_args.seed),
                         group=all_args.layout_name,
                         dir=str(run_dir),
                         job_type="training",
                         reinit=True,
                         tags=all_args.wandb_tags)
    else:
        # Without wandb, each invocation gets its own numbered sub-directory.
        run_dir = _next_run_dir(run_dir)
        run_dir.mkdir(parents=True, exist_ok=True)

    setproctitle.setproctitle(str(all_args.algorithm_name) + "-" +
                              str(all_args.env_name) + "-" + str(all_args.experiment_name) + "@" + str(all_args.user_name))

    # seed
    torch.manual_seed(all_args.seed)
    torch.cuda.manual_seed_all(all_args.seed)
    np.random.seed(all_args.seed)

    # env init
    envs = make_train_env(all_args, run_dir)
    eval_envs = make_eval_env(all_args, run_dir) if all_args.use_eval else None
    try:
        config = {
            "all_args": all_args,
            "envs": envs,
            "eval_envs": eval_envs,
            "num_agents": all_args.num_agents,
            "device": device,
            "run_dir": run_dir
        }

        from hsp.runner.shared.overcooked_runner import OvercookedRunner as Runner

        runner = Runner(config)

        # load population
        print("population_yaml_path: ", all_args.population_yaml_path)
        # ``yaml.load`` without a Loader raises TypeError on PyYAML >= 6.
        # safe_load assumes the population file uses only standard YAML tags
        # (confirm if custom tags are ever needed). The ``with`` closes the handle.
        with open(all_args.population_yaml_path) as population_file:
            population_config = yaml.safe_load(population_file)

        # override policy config
        override_policy_config = _build_override_policy_config(all_args, runner, population_config)

        runner.policy.load_population(all_args.population_yaml_path, evaluation=False, override_policy_config=override_policy_config)
        runner.trainer.init_population()

        runner.train_traj()
    finally:
        # Close env workers even if setup or training raises, so they are not leaked.
        envs.close()
        if all_args.use_eval and eval_envs is not envs:
            eval_envs.close()

    # post process
    if all_args.use_wandb:
        run.finish()
    else:
        runner.writter.export_scalars_to_json(str(runner.log_dir + '/summary.json'))
        runner.writter.close()


if __name__ == "__main__":
    # Strip the program name and forward the remaining CLI tokens to main().
    cli_tokens = sys.argv[1:]
    main(cli_tokens)