#!/usr/bin/env python
import re
import sys
import os
import wandb
import socket
import setproctitle
import numpy as np
from pathlib import Path
import torch
from code_ptmc_mappo.config import get_config
from code_ptmc_mappo.envs.env_wrappers import PreySubprocVecEnv, ShareDummyVecEnv
import argparse

"""Train script for stag_hunt."""

def make_train_env(all_args):
    """Build the vectorised StagHunt training environment.

    When ``all_args.n_rollout_threads`` is exactly 1 a ``ShareDummyVecEnv``
    wrapping a single env factory is returned; for any other count a
    ``PreySubprocVecEnv`` is built from that many env factories. Every
    factory constructs ``StagHunt(all_args)``.
    """
    def env_factory(rank):
        # ``rank`` is accepted per worker but is not used to build the env.
        def build():
            # The StagHunt import is deferred until the env is actually built.
            from code_ptmc_mappo.envs.stag_hunt.staghunt_env import StagHunt
            return StagHunt(all_args)
        return build

    thread_count = all_args.n_rollout_threads
    if thread_count != 1:
        return PreySubprocVecEnv([env_factory(idx) for idx in range(thread_count)])
    return ShareDummyVecEnv([env_factory(0)])

def make_eval_env(all_args):
    """Build the vectorised StagHunt evaluation environment.

    Sized from ``all_args.n_eval_rollout_threads``: exactly one thread gives a
    ``ShareDummyVecEnv`` over a single env factory, any other count gives a
    ``PreySubprocVecEnv`` built from that many env factories. Each factory
    constructs ``StagHunt(all_args)``.
    """
    def factory_for(worker_rank):
        # ``worker_rank`` identifies the slot but does not affect the env.
        def create_env():
            # Import deferred until an env is requested.
            from code_ptmc_mappo.envs.stag_hunt.staghunt_env import StagHunt
            new_env = StagHunt(all_args)
            return new_env
        return create_env

    eval_threads = all_args.n_eval_rollout_threads
    if eval_threads == 1:
        single = [factory_for(0)]
        return ShareDummyVecEnv(single)
    factories = list(map(factory_for, range(eval_threads)))
    return PreySubprocVecEnv(factories)

def str2list(v):
    """Convert a value such as ``"[20,20]"`` or ``"2,1"`` to a list of ints.

    Intended as an ``argparse`` ``type=`` callable. A list argument is
    returned unchanged (the same object).

    Raises:
        argparse.ArgumentTypeError: if ``v`` is neither a list nor a string,
            or if any comma-separated item is not an integer (this includes
            an empty ``"[]"``). argparse reports this message as a usage error.
    """
    if isinstance(v, list):
        return v
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError("Input must be a comma-separated list of ints.")
    # Strip outer whitespace first so " [20, 20] " still loses its brackets;
    # int() itself tolerates whitespace around each item.
    body = v.strip().strip("[]")
    try:
        return [int(item) for item in body.split(",")]
    except ValueError as err:
        # Re-raise as ArgumentTypeError so argparse shows this message instead
        # of a generic "invalid str2list value".
        raise argparse.ArgumentTypeError(
            "Input must be a comma-separated list of ints, got %r." % (v,)
        ) from err

def parse_args(args, parser):
    """Register the StagHunt-specific options on ``parser`` and parse ``args``.

    Args:
        args: list of command-line argument strings (without the program name).
        parser: the ``argparse.ArgumentParser`` to extend (``main`` passes the
            one returned by ``get_config()``).

    Returns:
        The parsed ``argparse.Namespace``. Unrecognized arguments do not raise
        (``parse_known_args`` is used), but they are listed in a warning on
        stderr so that typos such as ``--n_agent`` do not go unnoticed.
    """
    parser.add_argument('--n_agents', type=int, default=7)
    parser.add_argument('--n_stags', type=int, default=12)
    parser.add_argument('--n_hare', type=int, default=0)
    # List-valued options accept "[a,b]" or "a,b"; quote bracketed values on
    # the command line, e.g. --world_shape "[20,20]".
    parser.add_argument("--agent_obs", type=str2list, default=[2, 2])
    parser.add_argument("--world_shape", type=str2list, default=[20, 20])
    parser.add_argument("--capture_conditions", type=str2list, default=[0, 1])
    parser.add_argument("--capture_action_conditions", type=str2list, default=[2, 1])
    parser.add_argument("--batch_size", type=int, default=None)
    all_args, unknown = parser.parse_known_args(args)
    if unknown:
        print("Warning: ignoring unrecognized arguments: %s" % " ".join(unknown),
              file=sys.stderr)
    return all_args

def _configure_algorithm(all_args):
    """Set the policy flags implied by ``all_args.algorithm_name`` in place.

    Raises:
        NotImplementedError: for any algorithm other than ``rmappo``,
            ``mappo``, ``ptmc`` or ``ippo``.
    """
    algorithm = all_args.algorithm_name
    if algorithm == "rmappo":
        all_args.use_recurrent_policy = True
        all_args.use_naive_recurrent_policy = False
    elif algorithm in ("mappo", "ptmc"):
        # NOTE(review): use_recurrent_policy is left at its config default here
        # (not forced to False) — confirm that is intended for mappo/ptmc.
        all_args.use_naive_recurrent_policy = False
    elif algorithm == "ippo":
        all_args.use_centralized_V = False
    else:
        raise NotImplementedError("Unsupported algorithm_name: %r" % (algorithm,))

def _select_device(all_args):
    """Return the torch device to train on and set torch's CPU thread count.

    ``cuda:0`` is used only when ``all_args.cuda`` is set and CUDA is
    available; with ``cuda_deterministic`` cuDNN benchmarking is turned off
    and deterministic mode is turned on.
    """
    if all_args.cuda and torch.cuda.is_available():
        print("choose to use gpu...")
        device = torch.device("cuda:0")
        if all_args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        print("choose to use cpu...")
        device = torch.device("cpu")
    # Needed on both paths, so it is done once after the branch.
    torch.set_num_threads(all_args.n_training_threads)
    return device

def _next_run_dir(base_dir):
    """Return ``base_dir / "run<N>"`` with N one more than the largest existing
    ``run<digits>`` entry in ``base_dir`` (``run1`` when there is none).

    Entries that merely start with ``run`` (e.g. ``runs.log``) are ignored
    rather than crashing the integer conversion.
    """
    run_numbers = []
    for entry in base_dir.iterdir():
        match = re.fullmatch(r"run(\d+)", entry.name)
        if match:
            run_numbers.append(int(match.group(1)))
    return base_dir / ("run%i" % (max(run_numbers, default=0) + 1))

def _set_process_title(all_args):
    """Set the OS process title.

    Format: ``<algo>-<env>-seed:<seed>---@<user>``, or for ``ptmc``
    ``<algo>-<env>-seed:<seed>--pretr_<tag>--@<user>``, where ``<tag>`` is the
    first ``models<digits>`` found in ``all_args.pretr_model_dir``
    (``models_unknown`` when absent or when the dir is unset).
    """
    title = (str(all_args.algorithm_name) + "-" + str(all_args.env_name)
             + "-seed:" + str(all_args.seed) + "-")
    if all_args.algorithm_name == "ptmc":
        # str(... or "") keeps re.search from raising on a None / Path value.
        match = re.search(r"(models\d+)", str(all_args.pretr_model_dir or ""))
        chosen_model = match.group(1) if match else "models_unknown"
        title += "-pretr_" + chosen_model
    title += "--@" + str(all_args.user_name)
    setproctitle.setproctitle(title)

def main(args):
    """Train a StagHunt runner configured from the command line.

    Args:
        args: command-line arguments without the program name
            (``sys.argv[1:]`` when run as a script).

    Results go under ``<parent of this script's directory>/results/<env_name>/
    <algorithm_name>/<experiment_name>``; without wandb, a fresh ``run<N>``
    subdirectory is created there for every invocation.
    """
    parser = get_config()
    all_args = parse_args(args, parser)
    _configure_algorithm(all_args)
    device = _select_device(all_args)

    script_root = Path(os.path.abspath(__file__)).parent.parent
    run_dir = (script_root / "results" / all_args.env_name
               / all_args.algorithm_name / all_args.experiment_name)
    run_dir.mkdir(parents=True, exist_ok=True)

    run = None
    if all_args.use_wandb:
        run = wandb.init(config=all_args,
                         project=all_args.env_name,
                         entity=all_args.user_name,
                         notes=socket.gethostname(),
                         name=str(all_args.algorithm_name) + "_" +
                              str(all_args.experiment_name) + "_" +
                              str(all_args.units) +
                              "_seed" + str(all_args.seed),
                         dir=str(run_dir),
                         job_type="training",
                         reinit=True)
        all_args = wandb.config  # for wandb sweep
    else:
        run_dir = _next_run_dir(run_dir)
        run_dir.mkdir(parents=True, exist_ok=True)

    _set_process_title(all_args)

    # seed
    torch.manual_seed(all_args.seed)
    torch.cuda.manual_seed_all(all_args.seed)
    np.random.seed(all_args.seed)

    # Import before creating envs so an import failure cannot leave env workers behind.
    from code_ptmc_mappo.runner.shared.stag_runner import StugHuntRunner as Runner

    envs = make_train_env(all_args)
    eval_envs = None
    try:
        eval_envs = make_eval_env(all_args) if all_args.use_eval else None
        runner = Runner({
            "all_args": all_args,
            "envs": envs,
            "eval_envs": eval_envs,
            "num_agents": all_args.n_agents,
            "device": device,
            "run_dir": run_dir,
        })
        runner.run()
    finally:
        # Close the vectorised envs even when training raises, so their
        # workers (possibly subprocesses) are released.
        envs.close()
        if eval_envs is not None and eval_envs is not envs:
            eval_envs.close()

    if all_args.use_wandb:
        run.finish()
    else:
        # ``writter`` (sic) is the attribute name the runner exposes;
        # os.path.join works whether log_dir is a str or a Path.
        summary_path = os.path.join(str(runner.log_dir), "summary.json")
        runner.writter.export_scalars_to_json(summary_path)
        runner.writter.close()

if __name__ == "__main__":
    # Drop the program name; main() expects only the arguments themselves.
    cli_args = sys.argv[1:]
    main(cli_args)
