import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario


class Scenario(BaseScenario):
    """Navigation scenario: four agents and four landmarks in a 2-D arena.

    Agents are silent and collide with each other; landmarks are static and
    non-colliding. On every reset, agents and landmarks are placed uniformly at
    random in [-1, 1] per coordinate, and landmark ``i`` takes the colour of
    agent ``i`` (cycling if there are more landmarks than agents).
    """

    def make_world(self):
        """Create the world, its agents and landmarks, and randomise it once."""
        world = World()
        # set any world properties first
        world.dim_c = 2
        num_agents = 4
        num_landmarks = 4
        world.collaborative = False
        # add agents
        world.agents = [Agent() for _ in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.index = i
            agent.collide = True
            agent.silent = True
            agent.size = 0.13
            # NOTE(review): `tp` and `live` are custom fields not read anywhere
            # in this file; confirm their meaning with the consuming code.
            agent.tp = 0
            agent.live = 1
        # add landmarks
        world.landmarks = [Landmark() for _ in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.index = i
            landmark.collide = False
            landmark.movable = False
        # neighbourhood radius; `sight` mirrors it
        world.neigh_comm = 1.0
        world.sight = world.neigh_comm
        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolour all entities and scatter them uniformly in [-1, 1].

        Agents 0-2 get fixed red/purple/blue colours; any further agent is
        grey. Landmark ``i`` copies the colour of agent ``i % len(agents)``.
        """
        palette = ([0.8, 0.3, 0.3],     # red
                   [0.35, 0.35, 0.85],  # purple
                   [0.05, 0.45, 0.55])  # blue
        for i, agent in enumerate(world.agents):
            rgb = palette[i] if i < len(palette) else [0.25, 0.25, 0.25]
            agent.color = np.array(rgb)
        for i, landmark in enumerate(world.landmarks):
            # Copy so a landmark never aliases its agent's colour array, and
            # cycle through agents in case there are more landmarks than agents.
            landmark.color = world.agents[i % len(world.agents)].color.copy()
        # Random initial states. Agents are drawn before landmarks; keep this
        # order so seeded runs stay reproducible.
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)

    def benchmark_data(self, agent, world):
        """Return ``(rew, collisions, min_dists, occupied_landmarks)``.

        rew: value of :meth:`reward` for the current state.
        collisions: number of *other* agents overlapping ``agent``.
        min_dists: sum over landmarks of the distance to the nearest agent.
        occupied_landmarks: number of landmarks with at least one agent
            within the :meth:`is_occupy` threshold.
        """
        rew = self.reward(agent, world)
        collisions = sum(
            1 for other in world.agents
            if other is not agent and self.is_collision(agent, other)
        )
        min_dists = 0.0
        occupied_landmarks = 0
        for landmark in world.landmarks:
            min_dists += min(self._dist(a, landmark) for a in world.agents)
            if any(self.is_occupy(a, landmark) for a in world.agents):
                occupied_landmarks += 1
        return (rew, collisions, min_dists, occupied_landmarks)

    @staticmethod
    def _dist(entity1, entity2):
        """Euclidean distance between the positions of two entities."""
        return np.sqrt(np.sum(np.square(entity1.state.p_pos - entity2.state.p_pos)))

    def is_collision(self, agent1, agent2):
        """True when the two entities' discs (radius ``size``) overlap."""
        return bool(self._dist(agent1, agent2) < agent1.size + agent2.size)

    def is_occupy(self, agent, landmark, threshold=0.05):
        """True when ``agent`` is within ``threshold`` of ``landmark``."""
        return bool(self._dist(agent, landmark) <= threshold)

    def reward(self, agent, world):
        """Shared reward: the same value for every agent (``agent`` is unused).

        For each agent, subtract its distance to its nearest landmark, and
        subtract 1 for every other agent it overlaps. Every overlapping pair
        is therefore penalised twice, once from each side.
        """
        rew = 0
        for a in world.agents:
            rew -= min(self._dist(a, lm) for lm in world.landmarks)
            for other in world.agents:
                if other is not a and self.is_collision(a, other):
                    rew -= 1
        return rew

    def observation(self, agent, world):
        """Return the flat observation vector for ``agent``.

        Layout: own position, own velocity, every landmark's position relative
        to ``agent``, every other agent's position relative to ``agent``, then
        every other agent's velocity (not made relative).
        """
        entity_pos = [lm.state.p_pos - agent.state.p_pos for lm in world.landmarks]
        others = [other for other in world.agents if other is not agent]
        other_pos = [other.state.p_pos - agent.state.p_pos for other in others]
        other_vel = [other.state.p_vel for other in others]
        return np.concatenate(
            [agent.state.p_pos, agent.state.p_vel] + entity_pos + other_pos + other_vel
        )
