#!/usr/bin/env python
import sys
import os
# import wandb
import socket
import setproctitle
import numpy as np
import random
from pathlib import Path
import torch
from onpolicy.config import get_config

from onpolicy.envs.starcraft2.smac_maps import get_map_params
from onpolicy.envs.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv

from smacv2_map_config import distribution_config_dict
from smacv2.env import StarCraft2Env
from smacv2.env.starcraft2.wrapper import StarCraftCapabilityEnvWrapper
from gym.spaces import Discrete

class My_smacv2_wrapper(StarCraftCapabilityEnvWrapper):
    """SMACv2 capability wrapper that adds per-agent spaces and tuple-style reset/step.

    On construction it publishes three per-agent lists (``action_space``,
    ``observation_space``, ``share_observation_space``).  ``reset`` and ``step``
    return observations, the global state and the per-agent available-action
    masks together.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the parent wrapper, then build the per-agent space lists."""
        super().__init__(**kwargs)
        agent_ids = range(self.n_agents)
        # One entry per agent: a Discrete action space plus the sizes reported by
        # get_obs_size() / get_state_size().
        self.action_space = [Discrete(self.n_actions) for _ in agent_ids]
        self.observation_space = [self.get_obs_size() for _ in agent_ids]
        self.share_observation_space = [self.get_state_size() for _ in agent_ids]

    def reset(self):
        """Reset the underlying environment.

        Returns:
            Tuple ``(obs, share_obs, available_actions)`` where ``available_actions``
            holds one ``get_avail_agent_actions`` result per agent.
        """
        super().reset()
        local_obs = self.get_obs()
        global_state = self.get_state()
        avail = [self.get_avail_agent_actions(agent) for agent in range(self.n_agents)]
        return local_obs, global_state, avail

    def step(self, actions):
        """Apply ``actions`` to the wrapped environment.

        Returns:
            Tuple ``(local_obs, share_obs, reward, dones, infos, available_actions)``.
            ``dones`` is a boolean array with one entry per agent, all equal to the
            environment's ``terminated`` flag; ``infos`` repeats the same info dict
            (extended with a ``'won'`` key) once per agent.
        """
        reward, terminated, info = self.env.step(actions)

        # Broadcast the single episode-level termination flag to every agent.
        dones = np.ones(self.n_agents, dtype=bool)
        dones *= terminated

        info['won'] = self.env.win_counted
        # The very same dict object is handed to every agent.
        infos = [info for _ in range(self.n_agents)]

        avail = [self.get_avail_agent_actions(agent) for agent in range(self.n_agents)]

        local_obs = self.get_obs()
        global_state = self.get_state()
        return local_obs, global_state, reward, dones, infos, avail




"""Train script for SMAC."""

def make_train_env(all_args):
    """Build the vectorised SMACv2 environment(s) used for training rollouts.

    Args:
        all_args: Parsed arguments. Reads ``env_name``, ``map_name``,
            ``v2_n_units``, ``v2_n_enemies``, ``seed``, ``multi_rollout`` and
            ``n_rollout_threads``.

    Returns:
        A ``ShareDummyVecEnv`` holding a single environment when
        ``multi_rollout`` is set or ``n_rollout_threads == 1``; otherwise a
        ``ShareSubprocVecEnv`` with ``n_rollout_threads`` environments.

    Raises:
        NotImplementedError: Raised by the environment factory (when the
            environment is actually constructed) if ``env_name`` is not
            ``"smacv2"``.
    """
    def get_env_fn(rank):
        def init_env():
            if all_args.env_name != "smacv2":
                print("Can not support the " + all_args.env_name + " environment.")
                raise NotImplementedError

            # Work on a shallow copy so the unit-count overrides written here do
            # not leak into the shared module-level distribution_config_dict.
            capability_config = dict(distribution_config_dict[all_args.map_name])
            capability_config['n_units'] = all_args.v2_n_units
            capability_config['n_enemies'] = all_args.v2_n_enemies
            return My_smacv2_wrapper(
                capability_config=capability_config,
                map_name=all_args.map_name,
                debug=True,
                conic_fov=False,
                obs_own_pos=True,
                use_unit_ranges=True,
                min_attack_range=2,
                # Give each rollout worker its own seed.
                seed=all_args.seed + rank * 1000,
            )

        return init_env

    if all_args.multi_rollout or all_args.n_rollout_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])


def make_eval_env(all_args):
    """Build the vectorised SMACv2 environment(s) used for evaluation.

    Args:
        all_args: Parsed arguments. Reads ``env_name``, ``map_name``,
            ``v2_n_units``, ``v2_n_enemies``, ``seed``, ``multi_rollout`` and
            ``n_eval_rollout_threads``.

    Returns:
        A ``ShareDummyVecEnv`` holding a single environment when
        ``multi_rollout`` is set or ``n_eval_rollout_threads == 1``; otherwise
        a ``ShareSubprocVecEnv`` with ``n_eval_rollout_threads`` environments.

    Raises:
        NotImplementedError: Raised by the environment factory (when the
            environment is actually constructed) if ``env_name`` is not
            ``"smacv2"``.
    """
    def get_env_fn(rank):
        def init_env():
            if all_args.env_name != "smacv2":
                print("Can not support the " + all_args.env_name + " environment.")
                raise NotImplementedError

            # Work on a shallow copy so the unit-count overrides written here do
            # not leak into the shared module-level distribution_config_dict.
            capability_config = dict(distribution_config_dict[all_args.map_name])
            capability_config['n_units'] = all_args.v2_n_units
            capability_config['n_enemies'] = all_args.v2_n_enemies
            # NOTE(review): `rank` is not used, so every evaluation environment
            # is created with the same seed -- confirm this is intended.
            return My_smacv2_wrapper(
                capability_config=capability_config,
                map_name=all_args.map_name,
                debug=True,
                conic_fov=False,
                obs_own_pos=True,
                use_unit_ranges=True,
                min_attack_range=2,
                seed=all_args.seed,
            )

        return init_env

    if all_args.multi_rollout or all_args.n_eval_rollout_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    # Size the pool with the *evaluation* thread count; the branch above is
    # already selected on n_eval_rollout_threads, not n_rollout_threads.
    return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)])


def parse_args(args, parser):
    """Register the SMACv2-specific options on ``parser`` and parse ``args``.

    Args:
        args: Sequence of command-line tokens (without the program name).
        parser: An ``argparse.ArgumentParser`` to extend; it is mutated in place.

    Returns:
        The namespace of recognised options. Tokens the parser does not
        recognise are ignored because ``parse_known_args`` is used.
    """
    parser.add_argument('--map_name', type=str, default='10gen_terran',
                        help="Which smac map to run on")
    for count_option in ('--v2_n_units', '--v2_n_enemies'):
        parser.add_argument(count_option, type=int, default=5)

    # (flag, argparse action, default) -- kept in the original registration order.
    switch_table = (
        ("--add_move_state", 'store_true', False),
        ("--add_local_obs", 'store_true', False),
        ("--add_distance_state", 'store_true', False),
        ("--add_enemy_action_state", 'store_true', False),
        ("--add_agent_id", 'store_true', False),
        ("--add_visible_state", 'store_true', False),
        ("--add_xy_state", 'store_true', False),
        ("--use_state_agent", 'store_true', False),
        # Passing this flag turns the option OFF (default is True).
        ("--use_mustalive", 'store_false', True),
        ("--add_center_xy", 'store_true', False),
    )
    for flag, action, default in switch_table:
        parser.add_argument(flag, action=action, default=default)

    known_args, _ignored = parser.parse_known_args(args)
    return known_args


def main(args):
    """Parse the command line, build the SMACv2 environments and run training.

    Args:
        args: Command-line tokens without the program name
            (e.g. ``sys.argv[1:]``).

    Raises:
        NotImplementedError: If ``algorithm_name`` is neither ``"rmappo"`` nor
            ``"mappo"``.
        AssertionError: If the recurrent-policy flags do not match the chosen
            algorithm.
    """
    parser = get_config()
    all_args = parse_args(args, parser)
    all_args.env_name = 'smacv2'

    if all_args.algorithm_name == "rmappo":
        assert (all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy), ("check recurrent policy!")
    elif all_args.algorithm_name == "mappo":
        assert (all_args.use_recurrent_policy == False and all_args.use_naive_recurrent_policy == False), (
            "check recurrent policy!")
    else:
        raise NotImplementedError

    # cuda
    if all_args.cuda and torch.cuda.is_available():
        print("choose to use gpu...")
        device = torch.device("cuda:0")
        if all_args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        print("choose to use cpu...")
        device = torch.device("cpu")
    # Applies to both the GPU and the CPU path.
    torch.set_num_threads(all_args.n_training_threads)

    # Results live in "<parent of this file's directory>/results/...".
    run_dir = Path(os.path.split(os.path.dirname(os.path.abspath(__file__)))[
                       0] + "/results")

    v2_map_name = f'{all_args.map_name}_{all_args.v2_n_units}_vs_{all_args.v2_n_enemies}'
    print(f'v2_map_name = {v2_map_name} v2_n_units = {all_args.v2_n_units} v2_n_enemies = {all_args.v2_n_enemies}')

    name_list = [all_args.env_name, v2_map_name, all_args.algorithm_name, all_args.group_name,
                 all_args.experiment_name]
    for name in name_list:
        run_dir /= name
    # exist_ok avoids a FileExistsError when several runs start at the same time.
    run_dir.mkdir(parents=True, exist_ok=True)

    if all_args.use_wandb:
        # Imported lazily: the module-level import is disabled, and runs without
        # --use_wandb should not require the package.
        import wandb

        run = wandb.init(config=all_args,
                         project=all_args.env_name,
                         entity=all_args.user_name,
                         notes=socket.gethostname(),
                         name=str(all_args.algorithm_name) + "_" +
                              str(all_args.experiment_name) +
                              "_seed" + str(all_args.seed),
                         group=all_args.map_name,
                         dir=str(run_dir),
                         job_type="training",
                         reinit=True)
    else:
        # Pick the next free "runN" sub-directory. Entries whose suffix after
        # "run" is not a plain number are skipped instead of crashing int().
        exst_run_nums = []
        for folder in run_dir.iterdir():
            folder_name = str(folder.name)
            suffix = folder_name[len('run'):]
            if folder_name.startswith('run') and suffix.isdecimal():
                exst_run_nums.append(int(suffix))
        if exst_run_nums:
            curr_run = 'run%i' % (max(exst_run_nums) + 1)
        else:
            curr_run = 'run1'
        run_dir = run_dir / curr_run
        run_dir.mkdir(parents=True, exist_ok=True)

    setproctitle.setproctitle(
        str(all_args.algorithm_name) + "-" + str(all_args.env_name) + "-" + str(all_args.experiment_name) + "@" + str(
            all_args.user_name))

    # seed
    torch.manual_seed(all_args.seed)
    torch.cuda.manual_seed_all(all_args.seed)
    np.random.seed(all_args.seed)
    random.seed(all_args.seed)

    # env
    envs = make_train_env(all_args)
    print(f'args.map_name = {all_args.map_name}')
    eval_envs = make_eval_env(all_args) if all_args.use_eval else None

    # The number of controlled agents comes from the command line, not the map.
    num_agents = all_args.v2_n_units
    all_args.n_agents = num_agents
    env_info = envs.get_env_info()

    print('env_info = {}'.format(env_info))
    # Copy the environment-reported dimensions onto the argument namespace.
    all_args.n_actions = env_info["n_actions"]
    all_args.obs_shape = env_info["obs_shape"]
    all_args.state_shape = env_info["state_shape"]
    all_args.num_actions = env_info["n_actions"]
    all_args.episode_length = env_info['episode_limit']
    all_args.num_agents = num_agents
    config = {
        "all_args": all_args,
        "envs": envs,
        "eval_envs": eval_envs,
        "num_agents": num_agents,
        "device": device,
        "run_dir": run_dir
    }

    # run experiments
    if all_args.share_policy:
        from onpolicy.runner.shared.smac_runner_new import SMACRunner as Runner
    else:
        from onpolicy.runner.separated.smac_runner_new import SMACRunner as Runner

    runner = Runner(config)
    runner.run()

    # post process
    envs.close()
    if all_args.use_eval and eval_envs is not envs:
        eval_envs.close()

    if all_args.use_wandb:
        run.finish()
    else:
        runner.writter.export_scalars_to_json(str(runner.log_dir + '/summary.json'))
        runner.writter.close()


if __name__ == "__main__":
    # Strip the program name; main() expects only the option tokens.
    cli_tokens = sys.argv[1:]
    main(cli_tokens)
