import gym
import env.gym_fortattack as gym_fortattack
from gym import spaces
import numpy as np

# from multiagent.multi_discrete import MultiDiscrete
from env.spaces import Box, MASpace, MAEnvSpec
from gym.spaces import MultiDiscrete
import time, os


def make_fortattack_env(
    numGuards,
    numAttackers,
    num_steps,
    num_shots,
    max_rot,
    random_starting_rot,
    return_image=False,
    rot_rew_param=0.01,
    dist_rew_param=0.01,
    default_spawn_pos=None,
):
    """Create the "fortattack-v1" scenario and wrap it in a FortAttackGlobalEnv.

    Args:
        numGuards: number of guards, forwarded to the scenario and the wrapper.
        numAttackers: number of attackers, forwarded to the scenario and the
            wrapper.
        num_steps: stored on the world as ``max_time_steps``.
        num_shots: ammunition setting; stored on the scenario, passed to
            ``scenario.add_num_bullets`` and forwarded to the wrapper.
        max_rot: forwarded to the scenario.
        random_starting_rot: forwarded to the scenario.
        return_image: forwarded to the wrapper.
        rot_rew_param: forwarded to the scenario.
        dist_rew_param: forwarded to the scenario.
        default_spawn_pos: forwarded to the scenario only. The wrapper is built
            without it and therefore keeps its own ``default_spawn_pos``
            default.

    Returns:
        The constructed ``FortAttackGlobalEnv``.
    """
    scenario_options = dict(
        numGuards=numGuards,
        numAttackers=numAttackers,
        max_rot=max_rot,
        random_starting_rot=random_starting_rot,
        rot_rew_param=rot_rew_param,
        dist_rew_param=dist_rew_param,
        default_spawn_pos=default_spawn_pos,
        rectangle=False,
    )
    scenario = gym.make("fortattack-v1", **scenario_options)

    # Ammunition is configured on the scenario once it exists.
    scenario.num_shots = num_shots
    scenario.add_num_bullets(num_shots, True)

    # The world belongs to the scenario; the episode length is stored on it.
    game_world = scenario.world
    game_world.max_time_steps = num_steps

    # Wire the scenario's callbacks into the multi-agent wrapper.
    return FortAttackGlobalEnv(
        world=game_world,
        num_attackers=numAttackers,
        num_guards=numGuards,
        num_shots=num_shots,
        reset_callback=scenario.reset_world,
        reward_callback=scenario.reward,
        observation_callback=scenario.observation,
        info_callback=scenario.info,
        concept_callback=scenario.concept,
        return_image=return_image,
        attacker_can_fire=True,
    )


class FortAttackGlobalEnv(gym.Env):
    """Multi-agent gym environment built around a FortAttack world.

    The environment delegates resetting, rewards, observations, info and
    concepts to callbacks supplied at construction time, and exposes one
    action space and one observation space per policy agent (as lists).
    """

    # Render modes advertised to gym.
    metadata = {"render.modes": ["human", "rgb_array"]}

    def terminate(self):
        """Do nothing; placeholder hook that always returns ``None``."""
        return None

    def __init__(
        self,
        world,
        num_attackers,
        num_guards,
        num_shots,
        reset_callback=None,
        reward_callback=None,
        observation_callback=None,
        info_callback=None,
        concept_callback=None,
        done_callback=None,
        shared_viewer=True,
        return_image=False,
        attacker_can_fire=True,
        default_spawn_pos="random",
    ):
        """Store the world and callbacks and build the per-agent spaces.

        Args:
            world: world object; must expose ``policy_agents`` and ``dim_c``.
            num_attackers: stored as ``self.num_attackers``.
            num_guards: stored as ``self.num_guards``.
            num_shots: assigned to every agent as ``max_bullets``.
            reset_callback: called by ``reset`` with the spawn position.
            reward_callback: called as ``reward_callback(agent)``.
            observation_callback: called as
                ``observation_callback(agent, world)``.
            info_callback: called as ``info_callback(agent, world)``.
            concept_callback: called as ``concept_callback(agent, world)``.
            done_callback: stored only.
            shared_viewer: stored only.
            return_image: stored only.
            attacker_can_fire: when false, attackers get ``max_bullets = 0``.
            default_spawn_pos: stored as ``self.spawn_pos`` and handed to
                ``reset_callback`` on every reset.

        Raises:
            ValueError: if an agent is neither movable nor able to communicate,
                so no action space can be built for it.
        """
        self.num_attackers = num_attackers
        self.num_guards = num_guards

        self.world = world
        self.spawn_pos = default_spawn_pos
        self.agents = self.world.policy_agents
        # set required vectorized gym env property
        self.n = len(world.policy_agents)
        # scenario callbacks
        self.reset_callback = reset_callback
        self.reward_callback = reward_callback
        self.observation_callback = observation_callback
        self.info_callback = info_callback
        self.concept_callback = concept_callback
        self.done_callback = done_callback
        self.return_image = return_image
        # environment parameters
        self.discrete_action_space = True
        # if true, action is a number 0...N, otherwise action is a one-hot
        # N-dimensional vector
        self.discrete_action_input = True

        # _set_action reads these two flags on every call. Give them defaults
        # so a freshly built environment can be stepped; values provided by a
        # subclass or class attribute are left untouched. The defaults match
        # the MultiDiscrete([3, 4]) space built below for movable agents.
        if not hasattr(self, "multi_discrete"):
            self.multi_discrete = True
        if not hasattr(self, "using_hard_coded_paths"):
            self.using_hard_coded_paths = False

        # if true, every agent has the same reward
        self.shared_reward = getattr(world, "collaborative", False)

        # configure spaces
        self.action_space = []
        self.observation_space = []
        self.agent_num = len(self.agents)
        attacker_number = 1
        guard_number = 1
        for agent in self.agents:
            agent.max_bullets = num_shots
            # Attackers and guards are numbered independently, starting at 1.
            if agent.attacker:
                agent.num = attacker_number
                attacker_number += 1
                if not attacker_can_fire:
                    agent.max_bullets = 0
            else:
                agent.num = guard_number
                guard_number += 1

            total_action_space = []
            # physical action space: one movement choice and one
            # rotate-or-shoot choice
            if agent.movable:
                total_action_space.append(spaces.Discrete(3))
                total_action_space.append(spaces.Discrete(4))

            # communication action space, built only for agents that use it
            if not agent.silent:
                if self.discrete_action_space:
                    c_action_space = spaces.Discrete(world.dim_c)
                else:
                    c_action_space = spaces.Box(
                        low=0.0, high=1.0, shape=(world.dim_c,), dtype=np.float32
                    )
                total_action_space.append(c_action_space)

            if not total_action_space:
                raise ValueError(
                    "agent is neither movable nor able to communicate; "
                    "it has no action space"
                )

            # total action space
            if len(total_action_space) > 1:
                # all action spaces are discrete, so simplify to a
                # MultiDiscrete action space
                if all(
                    isinstance(component, spaces.Discrete)
                    for component in total_action_space
                ):
                    act_space = MultiDiscrete(
                        [component.n for component in total_action_space]
                    )
                else:
                    act_space = spaces.Tuple(total_action_space)
            else:
                act_space = total_action_space[0]
            self.action_space.append(act_space)

            # _get_obs tolerates a missing observation callback.
            obs_dim = len(self._get_obs(agent))
            self.observation_space.append(
                spaces.Box(
                    low=-np.inf,
                    high=+np.inf,
                    shape=(self.n, obs_dim),
                    dtype=np.float32,
                )
            )
            agent.action.c = np.zeros(self.world.dim_c)

        self.action_range = [0.0, 1.0]
        # rendering
        self.shared_viewer = shared_viewer
        self.prevShot, self.shot = False, False  # used for rendering
        self.hasnotimported = True

    def step(self, action_n):
        """Apply one action per agent, advance the world and collect results.

        Sequence: every agent's action is written onto the agent object, the
        world is stepped without arguments, then observations, ground-truth
        observations, rewards, done flags and infos are gathered agent by
        agent, followed by the concepts for all agents.

        Args:
            action_n: indexable collection holding one action per policy
                agent, in the order of ``self.world.policy_agents``.

        Returns:
            A 6-tuple ``(obs_n, reward_n, done_n, info_n, ground_truth_n,
            concepts_n)``. ``obs_n`` and ``ground_truth_n`` are numpy arrays;
            the others are lists with one entry per agent. When
            ``self.shared_reward`` is true every entry of ``reward_n`` is the
            sum of all agents' rewards.
        """
        self.agents = self.world.policy_agents

        # Write the chosen action onto every agent object.
        for idx, agent in enumerate(self.agents):
            self._set_action(action_n[idx], agent, self.action_space[idx])

        # Actions already live on the agents, so the world steps without
        # arguments.
        self.world.step()

        # Record the per-agent results.
        obs_n = []
        ground_truth_n = []
        reward_n = []
        done_n = []
        info_n = []
        for agent in self.agents:
            obs_n.append(self._get_obs(agent))
            ground_truth_n.append(self._get_gt_obs(agent))
            reward_n.append(self._get_reward(agent))
            done_n.append(self._get_done_agent(agent))
            info_n.append(self._get_info(agent))

        # Concepts are collected only after every agent's results are in.
        concepts_n = [self._get_concepts(agent) for agent in self.agents]

        # all agents get total reward in cooperative case
        if self.shared_reward:
            reward_n = [np.sum(reward_n)] * self.n

        self.world.time_step += 1
        return (
            np.array(obs_n),
            reward_n,
            done_n,
            info_n,
            np.array(ground_truth_n),
            concepts_n,
        )

    def reset(self):
        """Reset the world and return the initial observations.

        The reset callback receives ``self.spawn_pos``. The agent list is
        re-read from the world afterwards.

        Returns:
            A numpy array with one observation per policy agent.
        """
        self.reset_callback(self.spawn_pos)
        self.agents = self.world.policy_agents
        return np.array([self._get_obs(member) for member in self.agents])

    # get info used for benchmarking
    def _get_info(self, agent):
        """Return ``info_callback(agent, world)``, or ``{}`` without a callback."""
        callback = self.info_callback
        return {} if callback is None else callback(agent, self.world)

    # get concepts
    def _get_concepts(self, agent):
        """Return ``concept_callback(agent, world)``, or ``{}`` without one.

        The guard checks the concept callback itself, so the result does not
        depend on whether an info callback happens to be configured.
        """
        if self.concept_callback is None:
            return {}
        return self.concept_callback(agent, self.world)

    # get observation for a particular agent
    def _get_obs(self, agent):
        """Return the agent's observation, or an empty array without a callback."""
        observe = self.observation_callback
        if observe is not None:
            return observe(agent, self.world)
        return np.zeros(0)

    def _get_gt_obs(self, agent):
        """Return the observation requested with ``gt=True``.

        Falls back to an empty array when no observation callback is set.
        """
        observe = self.observation_callback
        if observe is not None:
            return observe(agent, self.world, gt=True)
        return np.zeros(0)

    # get done for the whole environment
    # unused right now -- agents are allowed to go beyond the viewing screen
    # update:: switching get done to get done per a particular agent
    def _get_done(self):
        """Return whether the game as a whole is over, recording the outcome.

        The checks run in this order; the first that applies sets its slot of
        ``world.gameResult`` to 1 and returns ``True``:

        * slot 2 - a living attacker is closer to the door than ``fortDim``
        * slot 0 - no attacker is alive
        * slot 1 - the time step equals ``max_time_steps - 1``
        """
        world = self.world
        radius = world.fortDim

        # Attackers win as soon as one of them is inside the fort radius.
        for intruder in world.alive_attackers:
            offset = intruder.state.p_pos - world.doorLoc
            if np.sqrt(np.sum(np.square(offset))) < radius:
                world.gameResult[2] = 1
                return True

        # Guards win when every attacker is dead ...
        if world.numAliveAttackers == 0:
            world.gameResult[0] = 1
            return True

        # ... or when the time runs out.
        if world.time_step == world.max_time_steps - 1:
            world.gameResult[1] = 1
            return True

        return False

    def _get_done_agent(self, agent):
        """Return whether the episode is over from ``agent``'s point of view.

        The checks run in this order; the first that applies returns ``True``:

        * a living attacker is closer to the door than ``fortDim``
          (sets ``gameResult[2]``)
        * no attacker is alive (sets ``gameResult[0]``)
        * ``agent`` itself is not alive (``gameResult`` is left untouched)
        * the time step equals ``max_time_steps - 1`` (sets ``gameResult[1]``)
        """
        world = self.world
        radius = world.fortDim

        # Attackers win as soon as one of them is inside the fort radius.
        for intruder in world.alive_attackers:
            offset = intruder.state.p_pos - world.doorLoc
            if np.sqrt(np.sum(np.square(offset))) < radius:
                world.gameResult[2] = 1
                return True

        # Guards win when every attacker is dead.
        if world.numAliveAttackers == 0:
            world.gameResult[0] = 1
            return True

        # A dead agent is finished even though the game goes on.
        if not agent.alive:
            return True

        # Guards also win when the time runs out.
        if world.time_step == world.max_time_steps - 1:
            world.gameResult[1] = 1
            return True

        return False

    def _get_distances(self):
        """Return each attacker's Euclidean distance to the door.

        Returns:
            A list with one entry per attacker in ``self.world.attackers``:
            the distance between the attacker's position and
            ``self.world.doorLoc`` for living attackers, and ``-1`` for dead
            ones.
        """
        world = self.world
        return [
            np.sqrt(np.sum(np.square(attacker.state.p_pos - world.doorLoc)))
            if attacker.alive
            else -1  # sentinel: dead attackers have no meaningful distance
            for attacker in world.attackers
        ]

    # get reward for a particular agent
    def _get_reward(self, agent):
        """Return ``reward_callback(agent)``, or ``0.0`` without a callback."""
        callback = self.reward_callback
        return 0.0 if callback is None else callback(agent)

    # set env action for a particular agent
    def _set_action(self, action, agent, action_space, time=None):
        """Decode ``action`` and store the result on ``agent.action``.

        ``agent.action.u`` receives the physical action (components 0 and 1
        are scaled by the agent's ``accel``, or by 5.0 when ``accel`` is
        ``None``; component 2 is the rotation), ``agent.action.shoot`` the
        firing decision and ``agent.action.c`` the communication action. Each
        shot fired decrements ``agent.state.bullets_left``.

        Decoding of the physical action for movable agents:

        * discrete input, multi-discrete pair ``(move, rot_or_shoot)``:
          move 1/2 drives forwards/backwards along ``agent.state.p_ang``;
          rot_or_shoot 1/2 rotates by ``+/-max_rot`` and 3 shoots.
        * discrete input, single index: 1-4 push along +x, -x, +y, -y,
          5-6 rotate by ``+/-max_rot`` and 7 shoots.
        * continuous input with a discrete action space: a vector whose
          entries 1-4 drive x/y, entry 5 sets the rotation and entry 6 fires
          when above 0.5.
        * otherwise the action is copied into ``agent.action.u`` as is.

        Shooting always requires at least one bullet left.

        Args:
            action: the action chosen for ``agent``.
            agent: the agent whose ``action`` and ``state`` are updated.
            action_space: unused.
            time: unused.
        """
        agent.action.u = np.zeros(self.world.dim_p)
        agent.action.c = np.zeros(self.world.dim_c)

        # Splitting a MultiDiscrete action into sub-actions is disabled, so the
        # whole action is handled as one component.
        action = [action]

        if agent.movable:
            # physical action
            if self.discrete_action_input:
                if self.multi_discrete and (
                    not agent.attacker or not self.using_hard_coded_paths
                ):
                    moving_action = action[0][0]
                    shoot_rotation_action = action[0][1]
                    # Drive along the current heading; 0 means stand still.
                    if moving_action == 1:
                        agent.action.u[0] = np.cos(agent.state.p_ang)
                        agent.action.u[1] = np.sin(agent.state.p_ang)
                    elif moving_action == 2:
                        agent.action.u[0] = -np.cos(agent.state.p_ang)
                        agent.action.u[1] = -np.sin(agent.state.p_ang)

                    agent.action.shoot = False
                    if shoot_rotation_action == 1:
                        agent.action.u[2] = +agent.max_rot
                    elif shoot_rotation_action == 2:
                        agent.action.u[2] = -agent.max_rot
                    elif shoot_rotation_action == 3:
                        agent.action.shoot = agent.state.bullets_left > 0
                else:
                    choice = action[0]
                    if choice == 1:
                        agent.action.u[0] = +1.0
                    elif choice == 2:
                        agent.action.u[0] = -1.0
                    elif choice == 3:
                        agent.action.u[1] = +1.0
                    elif choice == 4:
                        agent.action.u[1] = -1.0
                    elif choice == 5:
                        agent.action.u[2] = +agent.max_rot
                    elif choice == 6:
                        agent.action.u[2] = -agent.max_rot

                    agent.action.shoot = bool(
                        choice == 7 and agent.state.bullets_left > 0
                    )
                if agent.action.shoot:
                    agent.state.bullets_left -= 1

            else:
                if self.discrete_action_space:
                    vector = action[0]
                    # each entry is 0 to 1, so each difference is -1 to 1
                    agent.action.u[0] += vector[1] - vector[2]
                    agent.action.u[1] += vector[3] - vector[4]

                    # A value above 0.5 means shoot. Like the discrete
                    # branches this needs a bullet in stock, otherwise the
                    # counter would drop below zero.
                    agent.action.shoot = bool(
                        vector[6] > 0.5 and agent.state.bullets_left > 0
                    )
                    if agent.action.shoot:
                        agent.state.bullets_left -= 1

                    # simple rotation model: 0..1 maps to -max_rot..+max_rot
                    agent.action.u[2] = 2 * (vector[5] - 0.5) * agent.max_rot
                else:
                    # Copy, so the in-place scaling below cannot modify the
                    # caller's action.
                    agent.action.u = np.array(action[0], dtype=float)

            sensitivity = 5.0 if agent.accel is None else agent.accel
            agent.action.u[:2] *= sensitivity

            # remove used actions
            action = action[1:]

        if not agent.silent:
            # communication action
            # NOTE(review): for an agent that is movable as well, `action` is
            # already empty at this point, so `action[0]` raises IndexError.
            if self.discrete_action_input:
                agent.action.c[action[0]] = 1.0
            else:
                agent.action.c = action[0]
            action = action[1:]

        # make sure we used all elements of action
        assert len(action) == 0, "action contains unused components"
