import gym
from gym import error, spaces, utils
from gym.utils import seeding
from env.gym_fortattack.core import World, Agent, EntityState
import numpy as np
import time
import math


def angle_between(p1, p2, fang):
    """Return the signed offset between a heading and the bearing to a point.

    The bearing of ``p2`` as seen from ``p1`` is measured counter-clockwise
    from the +x axis.  The result is ``fang`` minus that bearing, wrapped
    into ``[-pi, pi]``; it is zero when the heading points straight at
    ``p2``.

    Args:
        p1: Observer position, indexable as ``p1[0]`` (x) and ``p1[1]`` (y).
        p2: Target position, same layout as ``p1``.
        fang: Observer heading in radians.  It does not need to be
            normalised: headings that accumulated several full turns are
            reduced before wrapping.

    Returns:
        The heading error in radians.
    """
    # arctan2(dx, dy) is the angle from the +y axis; the next line converts
    # it to the usual counter-clockwise-from-+x bearing in [0, 360).
    ang1 = np.arctan2(p2[0] - p1[0], p2[1] - p1[1])
    ang1 = (360 - np.rad2deg(ang1) + 90) % 360
    # fmod is exact and leaves any |diff| < 360 untouched, so headings that
    # are already normalised give the same result as before, while larger
    # ones no longer escape the single +/-360 correction below.
    diff = np.fmod(np.rad2deg(fang) - ang1, 360.0)
    if diff > 180:
        diff -= 360
    if diff < -180:
        diff += 360
    return np.deg2rad(diff)


def distance_between(p1, p2):
    """Return the Euclidean distance between two 2-D points.

    Only the first two components (``[0]`` and ``[1]``) of each point are
    used.
    """
    dx = p2[0] - p1[0]
    dy = p2[1] - p1[1]
    return math.sqrt(dx ** 2 + dy ** 2)


class FortAttackEnvV1(gym.Env):
    """Fort-defence scenario with guard agents and attacker agents.

    The class owns a ``World`` populated with guards (listed first) and
    attackers, and provides the per-agent callbacks defined below:
    ``reward``, ``observation``, ``concept`` and ``info``, plus
    ``reset_world`` to start a new episode.
    """

    metadata = {"render.modes": ["human"]}

    def __init__(
        self,
        numGuards,
        numAttackers,
        max_rot,
        random_starting_rot,
        rot_rew_param,
        dist_rew_param,
        default_spawn_pos,
        rectangle,
    ):
        """Build the world, create the agents and reset to a start state.

        Args:
            numGuards: Number of guard agents; they occupy the first slots
                of ``world.agents``.
            numAttackers: Number of attacker agents; they follow the guards.
            max_rot: Rotation limit stored on every agent and on the
                environment.  ``guard_reward`` also uses it as the aiming
                tolerance inside which no heading penalty is applied.
            random_starting_rot: If truthy, ``reset_world`` perturbs every
                initial heading by a uniform offset in ``[-pi/4, pi/4]``.
            rot_rew_param: Weight of the guards' heading-error penalty.
            dist_rew_param: Weight of the guards' distance penalty.
            default_spawn_pos: Spawn mode used by ``reset_world`` until it is
                overridden; the modes handled there are ``"opposite_ends"``
                and ``"random"``.
            rectangle: Forwarded to ``World``; also selects the door location
                and the y-axis scaling used by ``observation``.
        """
        # environment will have guards(green) and attackers(red)
        # red bullets - can hurt green agents, vice versa
        # single hit - if hit once agent dies
        self.world = World(rectangle=rectangle)
        world = self.world

        world.fortDim = 0.15  # radius
        if rectangle:
            world.doorLoc = np.array([0.0, 1.5])
        else:
            world.doorLoc = np.array([0, 0.8])

        # initial number of guards, attackers and bullets
        world.numGuards = numGuards
        world.numAttackers = numAttackers
        world.numBullets = 0
        world.numAgents = world.numGuards + world.numAttackers

        self.rot_rew_param = rot_rew_param
        self.dist_rew_param = dist_rew_param
        self.spawn_pos = default_spawn_pos
        self.rectangle = rectangle

        world.numAliveGuards = world.numGuards
        world.numAliveAttackers = world.numAttackers
        world.numAliveAgents = world.numAgents
        # did any attacker succeed to reach the gate?
        # (the attribute name keeps its historical spelling because other
        # code reads it under this name)
        world.atttacker_reached = False
        self.num_shots = None

        # first we have the guards and then we have the attackers
        world.agents = [Agent() for _ in range(world.numAgents)]
        for i, agent in enumerate(world.agents):
            agent.name = "agent %d" % (i + 1)
            agent.collide = True
            agent.silent = True
            agent.attacker = i >= world.numGuards
            # agent.shootRad = 0.8 if i<self.world.numGuards else 0.6
            # Attackers are faster than guards.  max_speed is used in
            # integrate_state() inside core.py; agents are kept slow so that
            # a fast bullet does not appear to skip steps.
            if agent.attacker:
                agent.accel = 3.0
                agent.max_speed = 3.0
            else:
                agent.accel = 2.0
                agent.max_speed = 2.0
            agent.max_rot = max_rot  # 0.17 ## approx 10 degree
        self.max_rot = max_rot

        self.viewers = [None]
        self.render_geoms = None
        self.shared_viewer = True
        world.time_step = 0
        world.max_time_steps = None  # set inside malib/environments/fortattack
        world.vizDead = True  # whether to visualize the dead agents
        world.vizAttn = True  # whether to visualize attentions
        # [all attackers dead, max time steps, attacker reached fort]
        world.gameResult = np.array([0, 0, 0])
        self.random_starting_rot = random_starting_rot
        self.reset_world()

    def add_num_bullets(self, num_bullets, attacker_can_fire):
        """Set every agent's ammunition capacity and refill it.

        Args:
            num_bullets: Value written to both ``max_bullets`` and
                ``bullets_left`` of each armed agent.
            attacker_can_fire: When falsy, attackers get a capacity of zero
                instead of ``num_bullets``.
        """
        for agent in self.world.agents:
            disarmed = agent.attacker and not attacker_can_fire
            allowance = 0 if disarmed else num_bullets
            agent.state.max_bullets = allowance
            agent.state.bullets_left = allowance

    def reset_world(self, spawn_pos=None):
        """Reset counters, agents and landmarks to the start of an episode.

        Args:
            spawn_pos: Optional spawn mode.  When given it replaces the
                stored mode for this and all later resets.  Handled modes:

                * ``"opposite_ends"`` - guards are sampled with
                  ``y`` in ``[0.33 * yMax, yMax]`` and attackers with ``y``
                  in ``[yMin, 0.33 * yMin]``; ``x`` spans the full width.
                * ``"random"`` - every agent is sampled anywhere inside the
                  walls; attackers are re-sampled until they are more than
                  1.5 away from the door.

                Any other value leaves agent positions untouched.

        NOTE(review): ``agent.prevDist`` and ``world.atttacker_reached`` are
        not touched here - confirm they are reset elsewhere between episodes.
        """
        if spawn_pos is not None:
            self.spawn_pos = spawn_pos

        world = self.world
        world.time_step = 0
        world.bullets = []
        world.numAliveAttackers = world.numAttackers
        world.numAliveGuards = world.numGuards
        world.numAliveAgents = world.numAgents
        world.gameResult[:] = 0

        xMin, xMax, yMin, yMax = world.wall_pos
        # Attackers spawning in "random" mode must be farther than this from
        # the door.  The rejection loop below relies on the arena containing
        # such points.
        min_attacker_door_dist = 1.5

        for agent in world.agents:
            agent.alive = True
            # green for guards and red for attackers
            if agent.attacker:
                agent.color = np.array([1.0, 0.0, 0.0])
            else:
                agent.color = np.array([0.0, 1.0, 0.0])
            agent.state.p_vel = np.zeros(world.dim_p - 1)
            agent.state.c = np.zeros(world.dim_c)
            # attackers start facing +y, guards facing -y
            agent.state.p_ang = np.pi / 2 if agent.attacker else 3 * np.pi / 2
            if self.random_starting_rot:
                agent.state.p_ang += np.random.uniform(-np.pi / 4, np.pi / 4)
            agent.state.bullets_left = agent.state.max_bullets

            if self.spawn_pos == "opposite_ends":
                if agent.attacker:
                    y_low, y_high = yMin, 0.33 * yMin
                else:
                    y_low, y_high = 0.33 * yMax, yMax
                agent.state.p_pos = np.concatenate(
                    (
                        np.random.uniform(xMin, xMax, 1),
                        np.random.uniform(y_low, y_high, 1),
                    )
                )
            elif self.spawn_pos == "random":
                while True:
                    # Build the flat (x, y) position before measuring.  The
                    # previous code stacked the two length-1 samples into a
                    # (2, 1) array, which broadcast against the (2,) door
                    # location into a 2x2 matrix whose Frobenius norm is not
                    # the distance to the door.
                    candidate = np.concatenate(
                        (
                            np.random.uniform(xMin, xMax, 1),
                            np.random.uniform(yMin, yMax, 1),
                        )
                    )
                    if not agent.attacker:
                        break  # guards accept the first sample
                    door_dist = np.linalg.norm(candidate - world.doorLoc)
                    if door_dist > min_attacker_door_dist:
                        break
                agent.state.p_pos = candidate

            agent.numHit = 0  # overall in one episode
            agent.numWasHit = 0
            agent.hit = False  # in last time step
            agent.wasHit = False
            agent.attacker_pairing = -1  # -1 means "no attacker assigned yet"
            agent.attacker_pairing_time = 0

        # landmarks: uniform grey; non-boundary ones get a random position
        for landmark in world.landmarks:
            landmark.color = np.array([0.25, 0.25, 0.25])
            if not landmark.boundary:
                landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)
                landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent):
        """Return this step's reward for ``agent``.

        Agents that are neither alive nor flagged ``justDied`` receive 0.
        Otherwise the call is forwarded to ``attacker_reward`` or
        ``guard_reward`` depending on ``agent.attacker``.
        """
        if not (agent.alive or agent.justDied):
            return 0
        if agent.attacker:
            return self.attacker_reward(agent)
        return self.guard_reward(agent)

    def attacker_reward(self, agent):
        """Compute the shaped reward of an attacker for the current step.

        Terms, summed in this order:

        * progress: twice the reduction in distance to the door since the
          previous call (skipped while ``agent.prevDist`` is ``None``);
        * +10 when inside the fort radius (also sets
          ``world.atttacker_reached``);
        * -2 when the agent shot this step;
        * +3 when it hit someone, -3 when it was hit;
        * -10 when no attacker is left alive.

        Side effect: ``agent.prevDist`` is updated to the current distance.
        """
        world = self.world
        offset_to_door = agent.state.p_pos - world.doorLoc
        distToDoor = np.sqrt(np.sum(np.square(offset_to_door)))

        # reward for getting closer to the door
        progress = 0
        if agent.prevDist is not None:
            progress = 2 * (agent.prevDist - distToDoor)

        # very high reward for reaching the door
        arrival = 0
        if distToDoor < world.fortDim:
            arrival = 10
            world.atttacker_reached = True

        shot_cost = -2 if agent.action.shoot else 0  # using the laser
        hit_bonus = +3 if agent.hit else 0  # hitting a guard
        was_hit_cost = -3 if agent.wasHit else 0  # being hit by a guard
        # high negative reward if all attackers are dead
        wiped_out_cost = -10 if world.numAliveAttackers == 0 else 0

        total = (
            progress + arrival + shot_cost + hit_bonus + was_hit_cost + wiped_out_cost
        )
        agent.prevDist = distToDoor.copy()
        return total

    def guard_reward(self, agent):
        """Compute the shaped reward of a guard for the current step.

        Active terms, summed in this order:

        * -5.0 when the attacker closest to the door is inside the fort
          radius;
        * -0.1 when the guard shot this step;
        * +3.0 when it hit someone, -3.0 when it was hit;
        * a zero-valued placeholder when every attacker is dead;
        * heading penalty ``-|angle| * rot_rew_param`` towards the paired
          attacker, waived while ``|angle| < max_rot``;
        * range penalty ``-distance * dist_rew_param`` to the paired
          attacker, waived while the distance is below 1.0.

        Several earlier shaping terms (penalty for leaving the fort,
        exponential door-proximity penalty, per-step bonus, empty-magazine
        penalty, pairing-time penalty) are disabled and contribute nothing.

        Side effects: may (re)assign ``agent.attacker_pairing`` to the
        nearest living attacker, updates ``agent.attacker_pairing_time`` and
        stores the guard's distance to the door in ``agent.prevDist``.
        """
        world = self.world
        selfDistToDoor = np.sqrt(
            np.sum(np.square(agent.state.p_pos - world.doorLoc))
        )

        # high negative reward if an attacker reaches the fort
        breach_cost = 0
        if world.numAliveAttackers != 0:
            closest_to_door = np.min(
                [
                    np.sqrt(np.sum(np.square(foe.state.p_pos - world.doorLoc)))
                    for foe in world.alive_attackers
                ]
            )
            if closest_to_door < world.fortDim:
                breach_cost = -5.0

        shot_cost = -0.1 if agent.action.shoot else 0  # using the laser
        hit_bonus = 3.0 if agent.hit else 0  # hitting an attacker
        was_hit_cost = -3.0 if agent.wasHit else 0  # being hit by a laser
        # bonus for eliminating every attacker; currently valued at zero
        all_dead_bonus = 0.0 if world.numAliveAttackers == 0 else 0

        aim_cost = 0
        range_cost = 0
        living = world.alive_attackers
        if len(living) > 0:
            # Keep the current target while it is alive; otherwise pair the
            # guard with the nearest living attacker.
            if agent.attacker_pairing == -1 or not agent.attacker_pairing.alive:
                squared_ranges = np.array(
                    [
                        np.sum(np.square(foe.state.p_pos - agent.state.p_pos))
                        for foe in living
                    ]
                )
                agent.attacker_pairing = living[np.argmin(squared_ranges)]
                agent.attacker_pairing_time = 0
            else:
                agent.attacker_pairing_time += 1
            target = agent.attacker_pairing

            heading_error = abs(
                angle_between(
                    agent.state.p_pos, target.state.p_pos, agent.state.p_ang
                )
            )
            aim_cost = (
                0
                if heading_error < self.max_rot
                else -heading_error * self.rot_rew_param
            )

            target_range = np.sqrt(
                np.sum(np.square(target.state.p_pos - agent.state.p_pos))
            )
            range_cost = (
                0 if target_range < 1.0 else -target_range * self.dist_rew_param
            )

        total = (
            breach_cost
            + shot_cost
            + hit_bonus
            + was_hit_cost
            + all_dead_bonus
            + aim_cost
            + range_cost
        )
        agent.prevDist = selfDistToDoor.copy()
        return total

    def concept(self, agent, world):
        """Return the concept annotations of a guard; attackers get ``{}``.

        Keys of the returned dict (one list entry per ``world.attackers``
        element unless stated otherwise):

        * ``"can_shoot_ordinal"``: ``[0, 1]`` when the attacker is alive and
          ``world.laser_hit`` reports a hit for this guard's laser triangle,
          else ``[1, 0]``.
        * ``"agent_targeting_ordinal"``: one-hot list marking the paired
          attacker (``attacker_pairing.num - 1``); slot 0 is marked while
          the pairing is still the integer placeholder.
        * ``"attacker_strategy"``: always ``None`` here; filled in later.
        * ``"relative_orientation"``: ``angle_between`` guard heading and
          attacker, or ``-1`` for dead attackers.
        * ``"distance_between"``: distance guard to attacker, or ``-1`` for
          dead attackers.

        NOTE(review): ``-1`` is also a valid heading error in radians, so the
        dead-attacker placeholder in ``"relative_orientation"`` is ambiguous.
        """
        if agent.attacker:
            return {}

        own_pos = agent.state.p_pos
        laser_triangle = world.get_tri_pts_arr(agent)

        # can_shoot_ordinal
        can_shoot = [
            [0, 1]
            if (attacker.alive and world.laser_hit(laser_triangle, attacker))
            else [1, 0]
            for attacker in world.attackers
        ]

        # agent_targeting_ordinal
        targeting = [0] * len(world.attackers)
        pairing = agent.attacker_pairing
        target_slot = 0 if type(pairing) is int else pairing.num - 1
        targeting[target_slot] = 1

        # relative_orientation
        orientations = [
            angle_between(own_pos, attacker.state.p_pos, agent.state.p_ang)
            if attacker.alive
            else -1
            for attacker in world.attackers
        ]

        # relative_distance
        distances = [
            distance_between(own_pos, attacker.state.p_pos) if attacker.alive else -1
            for attacker in world.attackers
        ]

        return {
            "can_shoot_ordinal": can_shoot,
            "agent_targeting_ordinal": targeting,
            "attacker_strategy": None,  # will get filled in later
            "relative_orientation": orientations,
            "distance_between": distances,
        }

    def observation(self, agent, world, gt=False):
        """Return the observation vector of ``agent``.

        For a living agent the vector is the concatenation of
        ``[alive]``, ``p_pos``, ``[p_ang]``, ``p_vel``, the offsets of all
        non-boundary landmarks and ``[bullets_left]``.  A dead agent gets the
        fixed vector ``[0, 2, 2, 2*pi, max_speed, max_speed, 0]``.

        Args:
            agent: Agent being observed.
            world: World providing the landmarks and the laser geometry.
            gt: When truthy the raw values are returned; otherwise entries
                0-6 are rescaled in place as described below.

        Scaling (``gt`` falsy) assumes the layout
        ``[alive, x, y, heading, vx, vy, bullets_left]`` - i.e. a 2-D
        position and velocity and no non-boundary landmarks (TODO confirm):
        alive maps to +1 / dead to -1, positions are shifted and divided by
        the arena span, heading is mapped from ``[0, 2*pi]`` to ``[-1, 1]``,
        velocities are divided by ``max_speed`` and the bullet count is
        mapped from ``[0, max_bullets]`` to ``[-1, 1]``.
        """
        # positions of all landmarks in this agent's reference frame
        landmark_offsets = [
            entity.state.p_pos - agent.state.p_pos
            for entity in world.landmarks
            if not entity.boundary
        ]
        heading = [[agent.state.p_ang]]

        # NOTE(review): firing_coordinates is computed but is not part of the
        # returned vector; kept so the calls into ``world`` are unchanged.
        if agent.action.shoot:
            laser_points = world.get_tri_pts_arr(agent)
            firing_coordinates = laser_points[:2, :].transpose().reshape(-1)
        else:
            firing_coordinates = np.zeros(6) + 2.0
        firing_coordinates = 2.0 * (firing_coordinates + 1.0) / 3.0 - 1.0

        if agent.alive:
            pieces = [
                [agent.alive],
                agent.state.p_pos,
                *heading,
                agent.state.p_vel,
                *landmark_offsets,
                [agent.state.bullets_left],
            ]
            scaled_values = np.concatenate(pieces)
        else:
            scaled_values = np.array(
                [0, 2, 2, 2 * np.pi, agent.max_speed, agent.max_speed, 0]
            )

        if gt:
            return scaled_values

        # alive flag: 1 -> +1, 0 -> -1
        scaled_values[0] = 2.0 * (scaled_values[0] - 0.5)
        # x position (same mapping for both arena shapes)
        scaled_values[1] = 2.0 * (scaled_values[1] + 1) / 3.0 - 1.0
        # y position (the rectangular arena is taller)
        if self.rectangle:
            scaled_values[2] = 2.0 * (scaled_values[2] + 2.225) / 5.0 - 1.0
        else:
            scaled_values[2] = 2.0 * (scaled_values[2] + 1) / 3.0 - 1.0
        # orientation
        scaled_values[3] = (scaled_values[3] - np.pi) / np.pi
        # x and y velocity
        scaled_values[4] = scaled_values[4] / agent.max_speed
        scaled_values[5] = scaled_values[5] / agent.max_speed
        # bullets left; the max() guards against a zero capacity
        bullet_capacity = max(agent.state.max_bullets, 1)
        scaled_values[6] = 2 * ((scaled_values[6] / bullet_capacity) - 0.5)

        return scaled_values  # np.concatenate([scaled_values, firing_coordinates])

    def info(self, agent, world):
        """Return the info dict of ``agent``.

        The only entry, ``"attacker_pairing"``, is the zero-based index of
        the paired attacker (``attacker_pairing.num - 1``), or 0 while the
        pairing is still the integer placeholder.  ``world`` is not used.
        """
        pairing = agent.attacker_pairing
        if type(pairing) is int:
            paired_index = 0
        else:
            paired_index = pairing.num - 1
        return {"attacker_pairing": paired_index}
