import gym
import env.gym_fortattack as gym_fortattack
from gym import spaces
import numpy as np

# from multiagent.multi_discrete import MultiDiscrete
from env.spaces import Box, MASpace, MAEnvSpec
from gym.spaces import MultiDiscrete
import time, os

import pygame

# Uncomment to also force SDL's "dummy" video driver:
# os.environ["SDL_VIDEODRIVER"] = "dummy"


# Select SDL's "dummy" audio driver for this process.
# NOTE(review): this runs after `import pygame`; presumably SDL only reads the
# variable when the audio subsystem is initialised -- confirm.
os.environ["SDL_AUDIODRIVER"] = "dummy"
# environment for all agents in the multiagent world
# currently code assumes that no agents will be created/destroyed at runtime!


def make_fortattack_env(
    numGuards,
    numAttackers,
    num_steps,
    num_shots,
    max_rot,
    random_starting_rot,
    benchmark=False,
    return_image=False,
    attacker_can_fire=True,
    render_resolution=700,
    hard_coded_paths=None,
    use_hard_coded_paths=False,
    use_pygame=True,
    rot_rew_param=0.01,
    dist_rew_param=0.01,
    discrete_actions=True,
    multi_discrete=False,
    default_spawn_pos=None,
    rectangle=False,
):
    """Create the "fortattack-v1" scenario and wrap it in a FortAttackGlobalEnv.

    The scenario is built through ``gym.make``; its ``world`` is then given a
    time-step budget and handed, together with the scenario's callbacks, to
    ``FortAttackGlobalEnv``.

    Args:
        numGuards: number of guard agents (forwarded to the scenario and env).
        numAttackers: number of attacker agents (forwarded likewise).
        num_steps: stored on the world as ``max_time_steps``.
        num_shots: bullets per agent; stored on the scenario, passed to
            ``scenario.add_num_bullets`` and to the env.
        max_rot, random_starting_rot, rot_rew_param, dist_rew_param:
            forwarded unchanged to the scenario constructor.
        return_image, attacker_can_fire, render_resolution, discrete_actions,
        multi_discrete, rectangle: forwarded to the env (``rectangle`` also to
            the scenario, ``attacker_can_fire`` also to ``add_num_bullets``).
        default_spawn_pos: forwarded to the scenario constructor only.
        benchmark, hard_coded_paths, use_hard_coded_paths, use_pygame:
            accepted but not used by this function.

    Returns:
        The constructed ``FortAttackGlobalEnv``.
    """
    scenario_options = dict(
        numGuards=numGuards,
        numAttackers=numAttackers,
        max_rot=max_rot,
        random_starting_rot=random_starting_rot,
        rot_rew_param=rot_rew_param,
        dist_rew_param=dist_rew_param,
        default_spawn_pos=default_spawn_pos,
        rectangle=rectangle,
    )
    scenario = gym.make("fortattack-v1", **scenario_options)
    scenario.num_shots = num_shots
    scenario.add_num_bullets(num_shots, attacker_can_fire)

    # The world object is owned by the scenario; the env only references it.
    world = scenario.world
    world.max_time_steps = num_steps

    # NOTE(review): default_spawn_pos is not passed on to the env, so the env
    # keeps its own default ("random") for reset() -- confirm this is intended.
    env_options = dict(
        world=world,
        num_attackers=numAttackers,
        num_guards=numGuards,
        num_shots=num_shots,
        reset_callback=scenario.reset_world,
        reward_callback=scenario.reward,
        observation_callback=scenario.observation,
        info_callback=scenario.info,
        concept_callback=scenario.concept,
        return_image=return_image,
        attacker_can_fire=attacker_can_fire,
        render_resolution=render_resolution,
        discrete_actions=discrete_actions,
        multi_discrete=multi_discrete,
        rectangle=rectangle,
    )
    return FortAttackGlobalEnv(**env_options)


# environment for all agents in the multiagent world
# currently code assumes that no agents will be created/destroyed at runtime!
class FortAttackGlobalEnv(gym.Env):
    metadata = {"render.modes": ["human", "rgb_array"]}

    def terminate(self):
        """No-op hook; this environment has nothing to shut down here."""
        return None

    def __init__(
        self,
        world,
        num_attackers,
        num_guards,
        num_shots,
        reset_callback=None,
        reward_callback=None,
        observation_callback=None,
        info_callback=None,
        concept_callback=None,
        done_callback=None,
        shared_viewer=True,
        return_image=False,
        attacker_can_fire=True,
        render_resolution=700,
        discrete_actions=True,
        multi_discrete=False,
        default_spawn_pos="random",
        rectangle=False,
    ):
        """Store the world and callbacks, then build per-agent spaces and render surfaces.

        Args:
            world: world object; must expose ``policy_agents``, ``dim_p`` and
                ``dim_c`` (``collaborative`` is optional).
            num_attackers, num_guards: stored on the instance.
            num_shots: assigned to every agent as ``max_bullets``.
            reset_callback: called by ``reset`` with the spawn position.
            reward_callback, observation_callback, info_callback,
            concept_callback, done_callback: stored for the ``_get_*`` helpers.
                ``observation_callback`` is called here (once per agent) when
                ``return_image`` is False, so it must not be None in that case.
            shared_viewer: one viewer slot if True, else one per agent.
            return_image: if True the observation space is a (3, 96, 96) Box.
            attacker_can_fire: if False, attackers get ``max_bullets = 0``.
            render_resolution: side length (pixels) of the per-agent render
                surface.
            discrete_actions: selects discrete vs. continuous (Box) actions.
            multi_discrete: with discrete actions, use a (move, rotate/shoot)
                pair instead of one flat discrete action.
            default_spawn_pos: value handed to ``reset_callback`` on reset.
            rectangle: use the tall rectangular surfaces instead of square.
        """
        self.num_attackers = num_attackers
        self.num_guards = num_guards

        self.render_resolution = render_resolution
        self.multi_discrete = multi_discrete
        self.ob_rms = None
        self.world = world
        self.spawn_pos = default_spawn_pos
        self.rectangle = rectangle
        # read by _set_action; never set to True inside this class
        self.using_hard_coded_paths = False
        self.agents = self.world.policy_agents
        # set required vectorized gym env property
        self.n = len(world.policy_agents)
        # scenario callbacks
        self.reset_callback = reset_callback
        self.reward_callback = reward_callback
        self.observation_callback = observation_callback
        self.info_callback = info_callback
        self.concept_callback = concept_callback
        self.done_callback = done_callback
        self.return_image = return_image
        # environment parameters
        if discrete_actions:
            self.discrete_action_space = True
            # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector
            self.discrete_action_input = True  # False
            # if true, even the action is continuous, action will be performed discretely
        else:
            self.discrete_action_space = False
            # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector
            self.discrete_action_input = False

        # if true, every agent has the same reward
        self.shared_reward = (
            world.collaborative if hasattr(world, "collaborative") else False
        )

        # configure spaces (one entry per policy agent, in agent order)
        self.action_space = []
        self.observation_space = []
        obs_shapes = []
        self.agent_num = len(self.agents)
        # 1-based running counters used to number agents within each team
        attacker_number = 1
        guard_number = 1
        for i, agent in enumerate(self.agents):
            agent.max_bullets = num_shots
            if agent.attacker:
                agent.num = attacker_number
                attacker_number += 1
            else:
                agent.num = guard_number
                guard_number += 1
            # attackers that may not fire get no ammunition at all
            if not attacker_can_fire and agent.attacker:
                agent.max_bullets = 0
            total_action_space = []
            # physical action space
            if self.discrete_action_space:
                if self.multi_discrete:
                    # two sub-actions: 3 movement choices, 4 rotate/shoot choices
                    move_action_space = spaces.Discrete(3)
                    rot_or_shoot_action_space = spaces.Discrete(4)
                    if agent.movable:
                        total_action_space.append(move_action_space)
                        total_action_space.append(rot_or_shoot_action_space)
                else:
                    # single flat action with dim_p * 2 + 2 choices
                    u_action_space = spaces.Discrete((world.dim_p) * 2 + 2)  ##
                    if agent.movable:
                        total_action_space.append(u_action_space)
            else:
                u_action_space = spaces.Box(
                    low=-agent.u_range,
                    high=+agent.u_range,
                    shape=(world.dim_p + 1,),  # +1 for ability to shoot
                    dtype=np.float32,
                )
                if agent.movable:
                    total_action_space.append(u_action_space)
            # communication action space
            if self.discrete_action_space:
                c_action_space = spaces.Discrete(world.dim_c)
            else:
                c_action_space = spaces.Box(
                    low=0.0, high=1.0, shape=(world.dim_c,), dtype=np.float32
                )

            if not agent.silent:
                total_action_space.append(c_action_space)
            # total action space
            if len(total_action_space) > 1:
                # all action spaces are discrete, so simplify to MultiDiscrete action space
                if all(
                    [
                        isinstance(act_space, spaces.Discrete)
                        for act_space in total_action_space
                    ]
                ):
                    # assert 0, "MultiDiscrete action space is not supported"
                    act_space = MultiDiscrete(
                        [act_space.n for act_space in total_action_space]
                    )
                else:
                    act_space = spaces.Tuple(total_action_space)
                self.action_space.append(act_space)
            else:
                # NOTE(review): raises IndexError for an agent that is both
                # immovable and silent (empty total_action_space)
                self.action_space.append(total_action_space[0])
            # observation space
            if self.return_image:
                # since we're using images, obs_dim no longer applies
                obs_shapes.append((3, 96, 96))
                self.observation_space.append(
                    spaces.Box(
                        low=-np.inf, high=+np.inf, shape=(3, 96, 96), dtype=np.float32
                    )
                )
            else:
                # probe the callback once to learn the observation length
                obs_dim = len(observation_callback(agent, self.world))
                self.observation_space.append(
                    spaces.Box(
                        low=-np.inf,
                        high=+np.inf,
                        shape=(self.n, obs_dim),
                        dtype=np.float32,
                    )
                )
            agent.action.c = np.zeros(self.world.dim_c)
        # simplified for non-comm game

        # self.action_spaces = MASpace(tuple(Box(low=-1., high=1., shape=(1,)) for _ in range(self.agent_num)))
        # self.observation_spaces = MASpace(tuple(Discrete(1) for _ in range(self.agent_num)))

        # action has 8 values:
        # nothing, +forcex, -forcex, +forcey, -forcey, +rot, -rot, shoot
        # self.action_spaces = MASpace(
        #     tuple(
        #         Box(low=0.0, high=1.0, shape=((world.dim_p) * 2 + 2,))
        #         for _ in range(self.agent_num)
        #     )
        # )  ##
        # self.observation_spaces = MASpace(
        #     tuple(
        #         Box(low=-np.inf, high=+np.inf, shape=obs_shape)
        #         for obs_shape in obs_shapes
        #     )
        # )
        # self.env_specs = MAEnvSpec(self.observation_spaces, self.action_spaces)

        self.action_range = [0.0, 1.0]
        # rendering
        self.shared_viewer = shared_viewer
        if self.shared_viewer:
            self.viewers = [None]
        else:
            self.viewers = [None] * self.n
        # `screen` is used by render_pygame when render_multiple is True,
        # `visualizing_screen` (fixed size) otherwise
        if self.rectangle:
            self.screen = pygame.Surface(
                (self.render_resolution, int(2.2246 * self.render_resolution))
            )
            self.visualizing_screen = pygame.Surface((700, 1557))
        else:
            self.screen = pygame.Surface(
                (self.render_resolution, self.render_resolution)
            )
            self.visualizing_screen = pygame.Surface((700, 700))
        # print(gym_fortattack.__file__)
        self.prevShot, self.shot = False, False  # used for rendering
        self._reset_render()
        self.hasnotimported = True

    def step(self, action_n):
        """Apply one action per policy agent, advance the world and collect results.

        Args:
            action_n: indexable of actions, one per entry of
                ``self.world.policy_agents`` (same order).

        Returns:
            Tuple ``(obs_n, reward_n, done_n, info_n, ground_truth_n, concepts_n)``.
            ``obs_n`` and ``ground_truth_n`` are numpy arrays; the others are
            lists with one entry per agent. When ``self.shared_reward`` is set,
            every entry of ``reward_n`` is the sum of all agents' rewards.
        """
        self.agents = self.world.policy_agents

        # 1) write every agent's action into its agent object
        for idx, agent in enumerate(self.agents):
            self._set_action(action_n[idx], agent, self.action_space[idx])

        # 2) the world reads the actions from the agents, so no arguments here
        self.world.step()

        # 3) gather per-agent results; the callbacks stay interleaved per agent
        observations, ground_truths = [], []
        rewards, dones, infos = [], [], []
        for agent in self.agents:
            observations.append(self._get_obs(agent))
            ground_truths.append(self._get_gt_obs(agent))
            rewards.append(self._get_reward(agent))
            dones.append(self._get_done_agent(agent))
            infos.append(self._get_info(agent))

        # concepts are collected only after all of the above
        concepts = [self._get_concepts(agent) for agent in self.agents]

        # distances are computed as before; the value is currently unused
        self._get_distances()

        # in the cooperative case every agent receives the total reward
        total_reward = np.sum(rewards)
        if self.shared_reward:
            rewards = [total_reward] * self.n

        self.world.time_step += 1
        return (
            np.array(observations),
            rewards,
            dones,
            infos,
            np.array(ground_truths),
            concepts,
        )
    
    def get_concepts(self):
        """Return the concept entry of every agent in ``self.agents``, in order."""
        return [self._get_concepts(agent) for agent in self.agents]
        

    def reset(self):
        """Reset the world and render state; return the initial observations.

        Returns:
            numpy array built from one observation per policy agent.
        """
        # the scenario's reset receives the stored spawn position
        self.reset_callback(self.spawn_pos)
        self._reset_render()
        # re-read the agent list: the reset may have replaced it
        self.agents = self.world.policy_agents
        return np.array([self._get_obs(agent) for agent in self.agents])

    # get info used for benchmarking
    def _get_info(self, agent):
        """Return the info callback's result for ``agent`` ({} without a callback)."""
        callback = self.info_callback
        return {} if callback is None else callback(agent, self.world)

    # get concepts
    def _get_concepts(self, agent):
        """Return the concept callback's result for ``agent``.

        Returns an empty dict when no concept callback was supplied. The guard
        checks ``concept_callback`` itself (it previously tested
        ``info_callback``, which crashed when only the concept callback was
        missing and suppressed concepts when only the info callback was).
        """
        if self.concept_callback is None:
            return {}
        return self.concept_callback(agent, self.world)

    # get observation for a particular agent
    def _get_obs(self, agent):
        """Return the observation callback's result for ``agent``.

        Without an observation callback an empty array (``np.zeros(0)``) is
        returned.
        """
        observe = self.observation_callback
        return np.zeros(0) if observe is None else observe(agent, self.world)

    def _get_gt_obs(self, agent):
        """Return the observation callback's result for ``agent`` with ``gt=True``.

        Without an observation callback an empty array (``np.zeros(0)``) is
        returned.
        """
        observe = self.observation_callback
        if observe is not None:
            return observe(agent, self.world, gt=True)
        return np.zeros(0)

    # get done for the whole environment
    # unused right now -- agents are allowed to go beyond the viewing screen
    # update:: switching get done to get done per a particular agent
    def _get_done(self):
        """Return True when the game is over, recording the outcome on the world.

        ``world.gameResult`` slots written here: index 2 when an alive attacker
        is within ``fortDim`` of the door, index 0 when no attackers are alive,
        index 1 when ``time_step`` equals ``max_time_steps - 1``.
        """
        world = self.world
        reach_radius = world.fortDim

        # attackers win: one of them is close enough to the door
        for attacker in world.alive_attackers:
            offset = attacker.state.p_pos - world.doorLoc
            if np.sqrt(np.sum(np.square(offset))) < reach_radius:
                world.gameResult[2] = 1
                return True

        # guards win: every attacker is dead
        if world.numAliveAttackers == 0:
            world.gameResult[0] = 1
            return True

        # guards win: the step budget is used up
        if world.time_step == world.max_time_steps - 1:
            world.gameResult[1] = 1
            return True

        return False

    def _get_done_agent(self, agent):
        """Return True when the episode is over for ``agent``.

        Same global end conditions as ``_get_done`` (with the same writes to
        ``world.gameResult``), plus: an agent that is no longer alive is done.
        The dead-agent check comes after the two attacker checks and before
        the time-limit check, and writes nothing to ``gameResult``.
        """
        world = self.world
        reach_radius = world.fortDim

        # attackers win: one of them is close enough to the door
        for attacker in world.alive_attackers:
            offset = attacker.state.p_pos - world.doorLoc
            if np.sqrt(np.sum(np.square(offset))) < reach_radius:
                world.gameResult[2] = 1
                return True

        # guards win: every attacker is dead
        if world.numAliveAttackers == 0:
            world.gameResult[0] = 1
            return True

        # this particular agent has been eliminated
        if not agent.alive:
            return True

        # guards win: the step budget is used up
        if world.time_step == world.max_time_steps - 1:
            world.gameResult[1] = 1
            return True

        return False

    def _get_distances(self):
        """Return each attacker's Euclidean distance to the door.

        Returns:
            A list with one entry per element of ``self.world.attackers``:
            the distance between the attacker's position and
            ``self.world.doorLoc`` if it is alive, otherwise ``-1``.
        """
        return [
            np.sqrt(np.sum(np.square(attacker.state.p_pos - self.world.doorLoc)))
            if attacker.alive
            else -1
            for attacker in self.world.attackers
        ]

    # get reward for a particular agent
    def _get_reward(self, agent):
        """Return the reward callback's result for ``agent`` (0.0 without a callback)."""
        callback = self.reward_callback
        return 0.0 if callback is None else callback(agent)

    # set env action for a particular agent
    def _set_action(self, action, agent, action_space, time=None):
        """Decode ``action`` and write it into ``agent.action``.

        Encodings handled for a movable agent:

        * discrete input, multi-discrete (and the agent is a guard, or
          hard-coded paths are off): ``action`` is a pair
          ``(move, rotate_or_shoot)``. ``move``: 1 = push along the agent's
          heading, 2 = push against it. ``rotate_or_shoot``: 1 = +max_rot,
          2 = -max_rot, 3 = shoot (only if bullets are left).
        * discrete input otherwise: ``action`` is one number;
          1/2 = +/-x, 3/4 = +/-y, 5/6 = +/-max_rot, 7 = shoot (only if
          bullets are left). Anything else does nothing.
        * vector input with a discrete action space: ``action`` is indexed
          as [_, +x, -x, +y, -y, rot, shoot].
        * continuous: ``action`` is copied into ``agent.action.u``.

        A shot decrements ``agent.state.bullets_left``. The first two
        components of ``u`` are finally scaled by ``agent.accel`` (5.0 when
        that is None).

        Args:
            action: the agent's action in one of the encodings above.
            agent: agent object whose ``action`` / ``state`` are updated.
            action_space: unused here.
            time: unused here.

        Raises:
            AssertionError: if part of the action was left unconsumed.
        """
        agent.action.u = np.zeros(self.world.dim_p)
        agent.action.c = np.zeros(self.world.dim_c)
        # wrap so the physical part and the communication part can be consumed in turn
        action = [action]

        if agent.movable:
            # physical action
            if self.discrete_action_input:
                agent.action.u = np.zeros(self.world.dim_p)

                if self.multi_discrete and (
                    not agent.attacker or not self.using_hard_coded_paths
                ):
                    moving_action = action[0][0]
                    shoot_rotation_action = action[0][1]
                    # translation is along (or against) the current heading
                    if moving_action == 1:
                        agent.action.u[0] = np.cos(agent.state.p_ang)
                        agent.action.u[1] = np.sin(agent.state.p_ang)
                    elif moving_action == 2:
                        agent.action.u[0] = -np.cos(agent.state.p_ang)
                        agent.action.u[1] = -np.sin(agent.state.p_ang)

                    agent.action.shoot = False
                    if shoot_rotation_action == 1:
                        agent.action.u[2] = +agent.max_rot
                    elif shoot_rotation_action == 2:
                        agent.action.u[2] = -agent.max_rot
                    elif shoot_rotation_action == 3:
                        agent.action.shoot = agent.state.bullets_left > 0
                else:
                    choice = action[0]
                    if choice == 1:
                        agent.action.u[0] = +1.0
                    elif choice == 2:
                        agent.action.u[0] = -1.0
                    elif choice == 3:
                        agent.action.u[1] = +1.0
                    elif choice == 4:
                        agent.action.u[1] = -1.0
                    elif choice == 5:
                        agent.action.u[2] = +agent.max_rot
                    elif choice == 6:
                        agent.action.u[2] = -agent.max_rot

                    agent.action.shoot = bool(
                        choice == 7 and agent.state.bullets_left > 0
                    )

                if agent.action.shoot:
                    agent.state.bullets_left += -1

            elif self.discrete_action_space:
                vec = action[0]
                # each component is in [0, 1], so each difference is in [-1, 1]
                agent.action.u[0] += vec[1] - vec[2]
                agent.action.u[1] += vec[3] - vec[4]

                # a value above 0.5 means "shoot"; require a bullet to be left
                # (strictly > 0, consistent with the other branches, so the
                # count can never drop below zero)
                agent.action.shoot = bool(
                    vec[6] > 0.5 and agent.state.bullets_left > 0
                )
                if agent.action.shoot:
                    agent.state.bullets_left += -1

                # simple rotation model: 0.5 is "no rotation"
                agent.action.u[2] = 2 * (vec[5] - 0.5) * agent.max_rot

            else:
                # copy: the scaling below must not modify the caller's array
                # NOTE(review): agent.action.shoot is not updated in this branch
                agent.action.u = np.array(action[0], dtype=float)

            # default sensitivity if the agent specifies no acceleration
            sensitivity = 5.0 if agent.accel is None else agent.accel
            agent.action.u[:2] *= sensitivity

            # drop the consumed physical action
            action = action[1:]

        if not agent.silent:
            # communication action
            if self.discrete_action_input:
                agent.action.c = np.zeros(self.world.dim_c)
                agent.action.c[action[0]] = 1.0
            else:
                agent.action.c = action[0]
            action = action[1:]

        # make sure we used all elements of action
        assert len(action) == 0

    # reset rendering assets
    def _reset_render(self):
        """Clear the cached render geometry and its transforms."""
        self.render_geoms = self.render_geoms_xform = None

    def render_pygame(
        self,
        attn_list=None,
        mode="rgbarray",
        close=False,
        agent_id=None,
        render_multiple=False,
    ):
        """Draw the world onto an off-screen pygame surface and return pixels.

        Args:
            attn_list, mode, close, agent_id: accepted for interface
                compatibility; not used by this method.
            render_multiple: if False, draw once on ``self.visualizing_screen``
                and return a single frame. If True, draw on ``self.screen`` and
                return one frame per world agent, where in frame ``i`` agent
                ``i`` is painted white.

        Returns:
            A uint8 array of shape ``(height, width, 3)`` when
            ``render_multiple`` is False, else ``(num_agents, height, width, 3)``.
            Rows are flipped vertically and only the first three bytes of each
            pixel are kept. ``height`` and ``width`` are taken from the surface
            that was drawn on, so square and rectangular layouts both work.
        """
        self.shot = False
        screen = self.screen if render_multiple else self.visualizing_screen
        screen.fill((0, 0, 0))
        half_res = self.render_resolution / 2 if render_multiple else 700 / 2
        # the surface's real size drives the reshape of the raw pixel buffer
        width, height = screen.get_size()

        # world -> pixel mapping; the rectangular layout shifts y further
        y_offset = 2.2246 if self.rectangle else 1.0

        def lentopix(xy):
            return (half_res * (xy[0] + 1.0), half_res * (xy[1] + y_offset))

        def corcolor(r, g=None, b=None, innt=1):
            # channels are emitted in reversed order (b, g, r)
            if g is None:
                return (255.0 * r[2], 255.0 * r[1], 255.0 * r[0])
            return (255.0 * b, 255.0 * g, 255.0 * r, innt)

        def draw_polygon_alpha(surface, color, points):
            # draw on a temporary per-pixel-alpha surface, then blit it
            lx, ly = zip(*points)
            min_x, min_y, max_x, max_y = min(lx), min(ly), max(lx), max(ly)
            target_rect = pygame.Rect(min_x, min_y, max_x - min_x, max_y - min_y)
            shape_surf = pygame.Surface(target_rect.size, pygame.SRCALPHA)
            pygame.draw.polygon(
                shape_surf, color, [(x - min_x, y - min_y) for x, y in points]
            )
            surface.blit(shape_surf, target_rect)

        def draw_circle_alpha(surface, color, center, radius):
            target_rect = pygame.Rect(center, (0, 0)).inflate((radius * 2, radius * 2))
            shape_surf = pygame.Surface(target_rect.size, pygame.SRCALPHA)
            pygame.draw.circle(shape_surf, color, (radius, radius), radius, width=2)
            surface.blit(shape_surf, target_rect)

        def draw_agent(agent, gun_shift, color):
            # gun (small circle offset along the heading), then body
            pygame.draw.circle(
                screen,
                color,
                lentopix(agent.state.p_pos + gun_shift),
                half_res * 0.5 * agent.size,
            )
            pygame.draw.circle(
                screen,
                color,
                lentopix(agent.state.p_pos),
                half_res * agent.size,
            )

        def grab_frame():
            # copy, because the view is only valid while the surface is unchanged
            return np.frombuffer(screen.get_view("1"), dtype="u1").copy()

        # fort rendering
        pygame.draw.circle(
            screen,
            (255, 255, 0),
            lentopix(self.world.doorLoc),
            half_res * self.world.fortDim,
        )

        # agent rendering
        shift_i = []
        for agent in self.world.agents:
            shift = (
                0.9
                * agent.size
                * np.array([np.cos(agent.state.p_ang), np.sin(agent.state.p_ang)])
            )
            shift_i.append(shift)
            draw_agent(agent, shift, corcolor(agent.color))
            # translucent black border around the body
            draw_circle_alpha(
                screen,
                (0, 0, 0, 200),
                lentopix(agent.state.p_pos),
                half_res * agent.size,
            )

            if agent.action.shoot:
                # translucent triangle showing the firing area
                v = self.world.get_tri_pts_arr(agent)[:2, :].transpose()
                v = [lentopix(vi) for vi in v]
                draw_polygon_alpha(
                    screen,
                    corcolor(
                        agent.color[0], agent.color[1], agent.color[2], 255.0 * 0.3
                    ),
                    v,
                )

        if not render_multiple:
            frame = grab_frame().reshape((height, width, 4))
            return frame[::-1, :, 0:3]

        frames = []
        for agent, shift in zip(self.world.agents, shift_i):
            # highlight this agent in white, capture, then restore its colour
            draw_agent(agent, shift, (255, 255, 255))
            frames.append(grab_frame())
            draw_agent(agent, shift, corcolor(agent.color))

        output = np.stack(frames, axis=0)
        # use the real surface size and frame count instead of assuming a
        # square (render_resolution x render_resolution) image and self.n frames
        output = output.reshape((len(frames), height, width, 4))
        return output[:, ::-1, :, 0:3]

    # render environment
    def render(
        self,
        attn_list=None,
        mode="human",
        close=False,
        agent_id=None,
        render_multiple=False,
    ):
        """Render through ``render_pygame`` and return its result.

        ``render_multiple`` is normalised to a plain ``True``/``False`` before
        being forwarded; all other arguments are passed through unchanged.
        """
        # attn_list = [[teamates_attn, opp_attn] for each team]
        per_agent_frames = bool(render_multiple)
        return self.render_pygame(attn_list, mode, close, agent_id, per_agent_frames)

    # create receptor field locations in local coordinate frame
    def _make_receptor_locations(self, agent):
        """Return receptor offsets (2-element arrays) in the local frame.

        The layout is fixed to "polar": 8 angles in [-pi, pi) times 3 radii
        between 0.1 and 1.0, followed by the origin. A "grid" layout (5 x 5
        points spanning [-1, 1] on both axes) exists but is not selected.
        ``agent`` is not used.
        """
        receptor_type = "polar"
        nearest, farthest = 0.05 * 2.0, 1.00
        offsets = []

        # circular receptive field
        if receptor_type == "polar":
            offsets.extend(
                radius * np.array([np.cos(theta), np.sin(theta)])
                for theta in np.linspace(-np.pi, +np.pi, 8, endpoint=False)
                for radius in np.linspace(nearest, farthest, 3)
            )
            # add origin
            offsets.append(np.array([0.0, 0.0]))

        # grid receptive field
        if receptor_type == "grid":
            offsets.extend(
                np.array([gx, gy])
                for gx in np.linspace(-farthest, +farthest, 5)
                for gy in np.linspace(-farthest, +farthest, 5)
            )

        return offsets


# vectorized wrapper for a batch of multi-agent environments
# assumes all environments have the same observation and action space
class BatchMultiAgentEnv(gym.Env):
    """Run several multi-agent environments as one flat batch of agents.

    Actions are split across the wrapped environments by each environment's
    agent count ``env.n``; observations, rewards and dones are concatenated in
    the same order. All wrapped environments are assumed to share the same
    observation and action spaces (those of the first one are reported).
    """

    metadata = {"runtime.vectorized": True, "render.modes": ["human", "rgb_array"]}

    def __init__(self, env_batch):
        """Store the sequence of environments to wrap."""
        self.env_batch = env_batch

    @property
    def n(self):
        """Total number of agents over all wrapped environments."""
        return np.sum([env.n for env in self.env_batch])

    @property
    def action_space(self):
        """Action space of the first wrapped environment."""
        return self.env_batch[0].action_space

    @property
    def observation_space(self):
        """Observation space of the first wrapped environment."""
        return self.env_batch[0].observation_space

    def step(self, action_n, time=None):
        """Step every wrapped environment with its slice of ``action_n``.

        Args:
            action_n: flat sequence of actions, ordered environment by
                environment.
            time: optional extra argument; forwarded to ``env.step`` only when
                given, so environments whose ``step`` takes just the actions
                (such as ``FortAttackGlobalEnv``) also work.

        Returns:
            ``(obs_n, reward_n, done_n, info_n)`` where the first three are
            flat lists over all agents and ``info_n`` is ``{"n": []}``.
        """
        obs_n = []
        reward_n = []
        done_n = []
        info_n = {"n": []}
        offset = 0
        for env in self.env_batch:
            actions = action_n[offset : offset + env.n]
            offset += env.n
            if time is None:
                result = env.step(actions)
            else:
                result = env.step(actions, time)
            # only the leading (obs, reward, done) are used; environments may
            # return extra trailing items (info, ground truth, concepts, ...)
            obs, reward, done = result[:3]
            obs_n.extend(obs)
            reward_n.extend(reward)
            done_n.extend(done)
        return obs_n, reward_n, done_n, info_n

    def reset(self):
        """Reset every wrapped environment; return the concatenated observations."""
        obs_n = []
        for env in self.env_batch:
            obs_n.extend(env.reset())
        return obs_n

    # render environment
    def render(self, mode="human", close=True):
        """Render every wrapped environment; return the concatenated results.

        ``mode`` and ``close`` are passed by keyword so they cannot land in a
        different leading parameter of the wrapped ``render`` (for example
        ``attn_list`` in ``FortAttackGlobalEnv.render``).
        """
        results_n = []
        for env in self.env_batch:
            results_n.extend(env.render(mode=mode, close=close))
        return results_n
