import math
import warnings

import magent
import numpy as np
from gym.spaces import Box, Discrete
from gym.utils import EzPickle

from pettingzoo import AECEnv
from pettingzoo.magent.render import Renderer
from pettingzoo.utils import agent_selector
from pettingzoo.utils.conversions import from_parallel_wrapper, parallel_wrapper_fn

from .magent_env import magent_parallel_env, make_env

default_map_size = 45
max_cycles_default = 1000
KILL_REWARD = 5
minimap_mode_default = False

# Default per-agent reward settings; callers may override any of them via the
# keyword arguments of parallel_env/raw_env.
# NOTE(XR): attack_penalty may need to move closer to zero (e.g. -0.01, as in
# DGN), since with the current value agents have become reluctant to attack.
default_reward_args = {
    "step_reward": -0.005,
    "dead_penalty": -0.1,
    "attack_penalty": -0.1,
    "attack_opponent_reward": 0.2,
}


def parallel_env(
    map_size=default_map_size,
    max_cycles=max_cycles_default,
    minimap_mode=minimap_mode_default,
    extra_features=False,
    **reward_args
):
    """Create the parallel-API battle environment.

    Any keyword passed in ``reward_args`` overrides the entry of the same name
    in ``default_reward_args``. The merged mapping is handed to
    ``_parallel_env`` together with the remaining settings.
    """
    merged_rewards = {**default_reward_args, **reward_args}
    return _parallel_env(
        map_size, minimap_mode, merged_rewards, max_cycles, extra_features
    )


def raw_env(
    map_size=default_map_size,
    max_cycles=max_cycles_default,
    minimap_mode=minimap_mode_default,
    extra_features=False,
    **reward_args
):
    """Create the AEC (turn-based) battle environment.

    Builds the parallel environment with the same arguments and converts it
    with ``from_parallel_wrapper``.
    """
    parallel = parallel_env(
        map_size=map_size,
        max_cycles=max_cycles,
        minimap_mode=minimap_mode,
        extra_features=extra_features,
        **reward_args
    )
    return from_parallel_wrapper(parallel)


# Public AEC environment constructor: raw_env passed through make_env (imported
# from .magent_env). Presumably make_env applies the standard PettingZoo
# wrappers -- confirm in magent_env.
env = make_env(raw_env)


def get_config(
    map_size,
    minimap_mode,
    step_reward,
    dead_penalty,
    attack_penalty,
    attack_opponent_reward,
):
    """Build the magent gridworld configuration for the battle scenario.

    Registers one agent type, "small", and two groups of that type on a
    ``map_size`` x ``map_size`` map. It also adds a shaping rule: an agent
    receives ``attack_opponent_reward`` for an attack event whose target is in
    the other group.

    Args:
        map_size: width and height of the map.
        minimap_mode: value stored under the ``minimap_mode`` config key.
        step_reward, dead_penalty, attack_penalty: per-agent-type reward
            settings stored in the "small" agent type options.
        attack_opponent_reward: value of the cross-group attack reward rule.

    Returns:
        The populated ``magent.gridworld.Config``.
    """
    gridworld = magent.gridworld
    config = gridworld.Config()

    # Global settings, applied as three separate ``set`` calls, in order.
    for setting in (
        {"map_width": map_size, "map_height": map_size},
        {"minimap_mode": minimap_mode},
        {"embedding_size": 10},
    ):
        config.set(setting)

    small_agent = config.register_agent_type(
        "small",
        dict(
            width=1,
            length=1,
            hp=10,
            speed=2,
            view_range=gridworld.CircleRange(3),
            attack_range=gridworld.CircleRange(1.5),
            damage=2,
            kill_reward=KILL_REWARD,
            step_recover=0.1,
            step_reward=step_reward,
            dead_penalty=dead_penalty,
            attack_penalty=attack_penalty,
        ),
    )

    first_group = config.add_group(small_agent)
    second_group = config.add_group(small_agent)

    any_first = gridworld.AgentSymbol(first_group, index="any")
    any_second = gridworld.AgentSymbol(second_group, index="any")

    # Reward shaping to encourage attack: each side is paid for attacking
    # the other side.
    for attacker, target in ((any_first, any_second), (any_second, any_first)):
        config.add_reward_rule(
            gridworld.Event(attacker, "attack", target),
            receiver=attacker,
            value=attack_opponent_reward,
        )

    return config


class _parallel_env(magent_parallel_env, EzPickle):
    """Parallel MAgent battle environment: two equally built teams on a square map.

    The constructor builds the magent ``GridWorld`` from ``get_config``. It then
    hands the world to ``magent_parallel_env`` together with a reward range
    derived from the reward settings. ``generate_map`` places the starting
    formations.
    """

    metadata = {"render.modes": ["human", "rgb_array"], "name": "battle_v3"}

    def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features):
        """Create the environment.

        Args:
            map_size: side length of the square map; must be at least 12.
            minimap_mode: forwarded to ``get_config`` and to the base class.
            reward_args: keyword arguments forwarded to ``get_config`` (its
                reward parameters); their values also feed the reward range.
            max_cycles: forwarded to the base class.
            extra_features: forwarded to the base class.

        Raises:
            AssertionError: if ``map_size`` is smaller than 12.
        """
        # EzPickle records the constructor arguments so the env can be pickled.
        EzPickle.__init__(
            self, map_size, minimap_mode, reward_args, max_cycles, extra_features
        )
        assert map_size >= 12, "size of map must be at least 12"
        env = magent.GridWorld(
            get_config(map_size, minimap_mode, **reward_args), map_size=map_size
        )
        self.leftID = 0
        self.rightID = 1
        # Reward range: the lower bound sums every negative reward component,
        # and the upper bound sums every positive one (kill reward included).
        reward_vals = np.array([KILL_REWARD] + list(reward_args.values()))
        reward_range = [
            np.minimum(reward_vals, 0).sum(),
            np.maximum(reward_vals, 0).sum(),
        ]
        names = ["red", "blue"]
        super().__init__(
            env,
            env.get_handles(),
            names,
            map_size,
            max_cycles,
            reward_range,
            minimap_mode,
            extra_features,
        )

    @staticmethod
    def _square_formation(x_start, side, width, height):
        """Return ``[x, y, 0]`` entries on a stride-2 lattice of a square.

        The square is ``side`` cells wide and starts at column ``x_start``.
        Vertically it spans rows ``(height - side) // 2`` onward. Cells on the
        outermost rows/columns of the ``width`` x ``height`` map (or outside
        it) are skipped. Positions are ordered column by column. The trailing
        0 is passed through to ``env.add_agents`` unchanged; its meaning is
        defined by magent (TODO confirm).
        """
        y_start = (height - side) // 2
        return [
            [x, y, 0]
            for x in range(x_start, x_start + side, 2)
            for y in range(y_start, y_start + side, 2)
            if 0 < x < width - 1 and 0 < y < height - 1
        ]

    def generate_map(self):
        """Generate a map consisting of two square formations of agents.

        The left formation occupies columns before ``width // 2 - 3``. The right
        formation starts at column ``width // 2 + 3``. Both are built from a
        square of side ``2 * int(sqrt(0.04 * map_size**2))``.
        ``leftID``/``rightID`` are swapped on every call, so the team placed on
        the left alternates between calls. The right formation is truncated so
        the right team never outnumbers the left team.
        """
        env, map_size, handles = self.env, self.map_size, self.handles
        width = height = map_size
        init_num = map_size * map_size * 0.04
        gap = 3
        side = int(math.sqrt(init_num)) * 2
        center_x = width // 2

        self.leftID, self.rightID = self.rightID, self.leftID

        # left
        left_pos = self._square_formation(center_x - gap - side, side, width, height)
        env.add_agents(handles[self.leftID], method="custom", pos=left_pos)

        # right -- capped at the left team's size
        right_pos = self._square_formation(center_x + gap, side, width, height)
        right_pos = right_pos[: len(left_pos)]
        env.add_agents(handles[self.rightID], method="custom", pos=right_pos)
