import numpy as np
from core import World, Agent, Landmark
from scenario import BaseScenario


class Scenario(object):
    """Predator-prey ("tag") scenario for the particle-world ``core`` engine.

    Adversary agents share a reward for every adversary/good-agent collision;
    good agents are penalised for being caught and for leaving the unit box.
    Observations only reveal entities within each agent's ``view_radius``.
    """

    def make_world(self, args=None):
        """Build the world with its agents and landmarks, then reset it.

        Args:
            args: optional config object. ``num_agents``, ``num_landmarks`` and
                ``view_radius`` attributes on it, when present, override the
                defaults of 4, 2 and ``np.inf`` (unlimited vision).

        Returns:
            The newly created, reset ``World``.
        """
        world = World()
        world.dim_c = 2
        self.num_agents = getattr(args, "num_agents", 4)
        num_adversaries = self.num_agents // 2
        num_landmarks = getattr(args, "num_landmarks", 2)
        view_radius = getattr(args, "view_radius", np.inf)

        # Physical parameters follow MPE's simple_tag scenario: adversaries are
        # larger and slower than the good agents they chase.
        world.agents = [Agent() for _ in range(self.num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = "agent %d" % i
            agent.collide = True
            agent.silent = True
            agent.adversary = i < num_adversaries
            agent.size = 0.075 if agent.adversary else 0.05
            agent.accel = 3.0 if agent.adversary else 4.0
            agent.max_speed = 1.0 if agent.adversary else 1.3
            # Read by observation(); a negative radius hides every entity.
            agent.view_radius = view_radius

        world.landmarks = [Landmark() for _ in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = "landmark %d" % i
            landmark.collide = True
            landmark.movable = False
            landmark.size = 0.2
            landmark.boundary = False

        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolour every entity and scatter agents / landmarks at random.

        Agents are placed uniformly in [-1, 1]^dim_p and non-boundary landmarks
        in [-0.9, 0.9]^dim_p; their velocities (and agents' comm state) are
        zeroed. Boundary landmarks keep their positions.
        """
        for agent in world.agents:
            # green for good agents, red for adversaries
            agent.color = np.array([0.35, 0.85, 0.35]) if not agent.adversary else np.array([0.85, 0.35, 0.35])
        for landmark in world.landmarks:
            landmark.color = np.array([0.25, 0.25, 0.25])
        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            if not landmark.boundary:
                landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)
                landmark.state.p_vel = np.zeros(world.dim_p)

    @staticmethod
    def _distance(a, b):
        """Euclidean distance between the positions of entities ``a`` and ``b``."""
        return np.sqrt(np.sum(np.square(a.state.p_pos - b.state.p_pos)))

    def is_collision(self, agent1, agent2):
        """Return True when the two entities' discs overlap (distance < sum of sizes)."""
        return bool(self._distance(agent1, agent2) < agent1.size + agent2.size)

    def good_agents(self, world):
        """Return the non-adversary agents of ``world`` in list order."""
        return [agent for agent in world.agents if not agent.adversary]

    def adversaries(self, world):
        """Return the adversary agents of ``world`` in list order."""
        return [agent for agent in world.agents if agent.adversary]

    def reward(self, agent, world):
        """Dispatch to the adversary or good-agent reward depending on ``agent``."""
        main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)
        return main_reward

    @staticmethod
    def _bound(x):
        """Out-of-arena penalty for one absolute coordinate ``x``.

        Zero inside 0.9, linear ramp to 1.0 between 0.9 and 1.0, then
        exponential growth capped at 10.
        """
        if x < 0.9:
            return 0
        if x < 1.0:
            return (x - 0.9) * 10
        return min(np.exp(2 * x - 2), 10)

    def agent_reward(self, agent, world):
        """Reward for a good agent.

        -10 for every adversary currently colliding with it (only if the agent
        collides at all), minus a per-axis boundary penalty.
        """
        rew = 0
        shape = False  # optional shaping: reward distance from adversaries
        adversaries = self.adversaries(world)
        if shape:
            for adv in adversaries:
                rew += 0.1 * self._distance(agent, adv)
        if agent.collide:
            for adv in adversaries:
                if self.is_collision(adv, agent):
                    rew -= 10

        # Soft boundary penalty: without it good agents could simply flee the
        # arena and never be caught by the adversaries.
        for p in range(world.dim_p):
            rew -= self._bound(abs(agent.state.p_pos[p]))

        return rew

    def adversary_reward(self, agent, world):
        """Reward for an adversary.

        The reward is shared: +10 for every (good agent, adversary) pair that is
        colliding, regardless of which adversary is involved.
        """
        rew = 0
        shape = False  # optional shaping: penalise distance to nearest good agent
        agents = self.good_agents(world)
        adversaries = self.adversaries(world)
        if shape:
            for adv in adversaries:
                rew -= 0.1 * min(self._distance(a, adv) for a in agents)
        if agent.collide:
            for ag in agents:
                for adv in adversaries:
                    if self.is_collision(ag, adv):
                        rew += 10
        return rew

    @staticmethod
    def _nearest(dists, items, k):
        """Return the first ``k`` of ``items`` ordered by ascending ``dists``."""
        order = np.argsort(dists)
        return [items[i] for i in order[:k]]

    def observation(self, agent, world):
        """Partially observable observation vector for ``agent``.

        Layout: own velocity, own position, offsets to the 2 nearest landmarks,
        offsets to the 2 nearest other adversaries, then the nearest good
        agent's offset and velocity. Lists are shorter when fewer such entities
        exist. Entities outside ``agent.view_radius`` (all of them when the
        radius is negative) and boundary landmarks are reported as zero
        vectors, but are still ranked by their true distance.
        """
        zero = np.zeros(world.dim_p)

        def in_view(dist):
            return agent.view_radius >= 0 and dist <= agent.view_radius

        landmark_dist, landmark_pos = [], []
        for entity in world.landmarks:
            dist = self._distance(entity, agent)
            landmark_dist.append(dist)
            seen = not entity.boundary and in_view(dist)
            landmark_pos.append(entity.state.p_pos - agent.state.p_pos if seen else zero)

        adv_dist, adv_pos = [], []
        prey_dist, prey_pos, prey_vel = [], [], []
        for other in world.agents:
            if other is agent:
                continue
            dist = self._distance(other, agent)
            seen = in_view(dist)
            rel_pos = other.state.p_pos - agent.state.p_pos if seen else zero
            if other.adversary:
                adv_dist.append(dist)
                adv_pos.append(rel_pos)
            else:
                prey_dist.append(dist)
                prey_pos.append(rel_pos)
                prey_vel.append(other.state.p_vel if seen else zero)

        return np.concatenate(
            [agent.state.p_vel, agent.state.p_pos]
            + self._nearest(landmark_dist, landmark_pos, 2)
            + self._nearest(adv_dist, adv_pos, 2)
            + self._nearest(prey_dist, prey_pos, 1)
            + self._nearest(prey_dist, prey_vel, 1)
        )

    def full_observation(self, agent, world):
        """Fully observable observation vector for ``agent``.

        Own velocity and position, offsets to every non-boundary landmark and
        to every other agent (in world order), then the velocities of the other
        good (non-adversary) agents.
        """
        entity_pos = [
            entity.state.p_pos - agent.state.p_pos
            for entity in world.landmarks
            if not entity.boundary
        ]
        others = [other for other in world.agents if other is not agent]
        other_pos = [other.state.p_pos - agent.state.p_pos for other in others]
        other_vel = [other.state.p_vel for other in others if not other.adversary]
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel)
