from ray.tune.registry import register_env
from env import FortAttack
from ray.rllib.models import ModelCatalog
from gym.spaces import Dict, Discrete, Tuple, MultiDiscrete
from ray import tune
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.policy.policy import PolicySpec
from concept_config import concept_config
import copy

from utils import possibly_gridsearch


def scenario_generator(args, CustomPolicy):
    """Build the env config, policy mapping and policy specs for a scenario.

    Only the ``"fort_attack"`` scenario is supported. As a side effect the
    environment factory is registered with Ray Tune under the name
    ``"custom_env"``, and one ``FortAttack`` instance is created locally so
    that its observation/action spaces can be read.

    Args:
        args: Namespace-like object. The attributes read here are
            ``scenario``, ``rectangle``, ``num_guards``, ``num_attackers``,
            ``num_steps``, ``num_shots``, ``random_starting_rot``,
            ``return_image``, ``attacker_can_fire``, ``use_hard_coded_paths``,
            ``use_pygame``, ``multi_discrete``, ``concepts``, ``embed_dim``,
            ``conceptdim``, ``include_concepts``, ``include_whitening``,
            ``t_whitening`` and ``use_reward``. ``args`` is also forwarded
            to ``concept_config``.
        CustomPolicy: Policy class used for both the adversary and the
            guard ("good") policy.

    Returns:
        A tuple ``(env_config, policy_mapping_fn, policies)`` where
        ``env_config`` is the keyword-argument dict for ``FortAttack``,
        ``policy_mapping_fn`` maps an agent id to a policy name, and
        ``policies`` maps each policy name to a
        ``(policy_cls, obs_space, action_space, config)`` tuple.

    Raises:
        ValueError: If ``args.scenario`` is not ``"fort_attack"``.
    """
    # Fail early with a clear message; otherwise the return statement would
    # reference names that were never bound.
    if args.scenario != "fort_attack":
        raise ValueError(
            f"Unsupported scenario {args.scenario!r}; expected 'fort_attack'."
        )

    # Fort Attack Environment
    def env_creator(mpe_args):
        return FortAttack(**mpe_args)

    # One waypoint list per hard-coded path, as [x, y] pairs.
    base_paths = [
        [[0.8, -0.5], [0.9, 0.5], [0.5, 0.8], [0, 1]],
        [[-0.8, -0.5], [-0.9, 0.5], [-0.5, 0.8], [0, 1]],
    ]
    if args.rectangle:
        # The rectangular layout stretches every y coordinate by 1.7.
        hard_coded_paths = [
            [[x, y * 1.7] for x, y in path] for path in base_paths
        ]
    else:
        hard_coded_paths = base_paths

    env_config = {
        "numGuards": args.num_guards,
        "numAttackers": args.num_attackers,
        "num_steps": args.num_steps,
        "num_shots": args.num_shots,
        "max_rot": 0.35,
        "random_starting_rot": args.random_starting_rot,
        "return_image": args.return_image,
        "attacker_can_fire": args.attacker_can_fire,
        "hard_coded_paths": hard_coded_paths,
        "use_hard_coded_paths": args.use_hard_coded_paths,
        "render_resolution": 96,
        "use_pygame": args.use_pygame,
        "name": "fort_attack",
        "rot_rew_param": 0.01,
        "dist_rew_param": 0.5,
        "discrete_actions": True,
        "multi_discrete": args.multi_discrete,
        "default_spawn_pos": "random",
        "rectangle": args.rectangle,
        "current_concepts": args.concepts,
    }
    # Instantiated only to read the per-agent observation/action spaces.
    env = env_creator(env_config)

    if args.return_image:
        model_name = "CnnRecurrentRewModel"  # "CnnRecurrentModel"
    else:
        model_name = "RNNRewModel"
    register_env("custom_env", env_creator)

    # Both policies use the spaces of agent index 0.
    # NOTE(review): under policy_mapping_fn below, index 0 is a guard
    # whenever num_guards > 0 -- confirm attackers share the same spaces.
    agent_obs_space = env.observation_space_dict[0]
    agent_action_space = env.action_space_dict[0]

    adversary_model = {
        "custom_model": model_name,
        "custom_model_config": {
            "num_agents": args.num_attackers,
            "num_opp_agents": args.num_guards,
            "input_size": 13,
            "embed_dim": possibly_gridsearch(args.embed_dim),
            "n_heads": 1,
            "fc1_layers": 2,  # tune.grid_search([2, 3]),
            "fc2_layers": 2,  # tune.grid_search([2, 3]),
            "fc3_layers": 2,  # tune.grid_search([1, 3]),
            "policy_layers": 1,
            "conceptdim": possibly_gridsearch(args.conceptdim),
            "bottleneck": 0,
            "dropout": 0.5,
            "is_guard": False,
            "include_concepts": args.include_concepts,
            "concept_configs": concept_config(args),
            "include_whitening": args.include_whitening,
            "affine_whitening": True,  # tune.grid_search([True, False]),
            "T_whitening": possibly_gridsearch(
                args.t_whitening
            ),  # tune.grid_search([2, 4]),
            "use_reward": args.use_reward,
        },
        "max_seq_len": 50,
    }

    # The guard model starts as an independent copy of the adversary model
    # with the team sizes swapped and the guard flag set.
    guard_model = copy.deepcopy(adversary_model)
    guard_model_config = guard_model["custom_model_config"]
    guard_model_config["num_agents"] = args.num_guards
    guard_model_config["num_opp_agents"] = args.num_attackers
    guard_model_config["is_guard"] = True

    # After the copy, the adversary model is pinned to fixed sizes, so only
    # the guard model keeps the (possibly grid-searched) values from args.
    # NOTE(review): this discards the arg-derived values for the adversary;
    # presumably intentional -- confirm.
    adversary_model_config = adversary_model["custom_model_config"]
    adversary_model_config["conceptdim"] = 32
    adversary_model_config["embed_dim"] = 128
    adversary_model_config["T_whitening"] = 2

    adversary_policy_config = {
        "obs_space_dict": agent_obs_space,
        "act_space_dict": agent_action_space,
        "model": adversary_model,
    }
    guard_policy_config = {
        "obs_space_dict": agent_obs_space,
        "act_space_dict": agent_action_space,
        "model": guard_model,
    }

    policies = {
        "adversary_policy": (
            CustomPolicy,
            agent_obs_space,
            agent_action_space,
            adversary_policy_config,
        ),
        "good_policy": (
            CustomPolicy,
            agent_obs_space,
            agent_action_space,
            guard_policy_config,
        ),
    }

    def policy_mapping_fn(agent_id, episode, **kwargs):
        # Agent ids below num_guards are routed to the guard policy; all
        # other ids go to the adversary policy.
        if agent_id < args.num_guards:
            return "good_policy"
        return "adversary_policy"

    return env_config, policy_mapping_fn, policies