import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
import os
import random
# Module-level constant, presumably a default sight radius.
# NOTE(review): not referenced anywhere in this file; agents read world.sight
# (assigned in Scenario.make_world) instead -- confirm before removing.
SIGHT = 0.5


class Scenario(BaseScenario):
    """Landmark-coverage scenario in which every agent owns one landmark.

    Agent ``i`` is paired with (and coloured like) landmark ``i``, so
    ``n_landmarks`` must be at least ``num_agents``.  Each agent is rewarded
    with the negative distance to its landmark and observes the other agents
    that lie within its sight radius.
    """

    def __init__(self):
        self.n_landmarks = 8
        self.n_forests = 0  # not referenced elsewhere in this class
        self.num_agents = 8
        self.alpha = 0  # not referenced elsewhere in this class
        self.ratio = 1
        self.size = self.ratio
        # Neighbour cap (self included) applied by neighb() via world.max_neighbor.
        self.max_neighbor = 3
        # Palette cycled over landmarks: landmark i gets colour i % len(palette).
        self.color_lst = [
            np.array([0.35, 0.35, 0.85]),  # purple
            np.array([0.1, 0.45, 0.1]),  # green
            np.array([0.05, 0.45, 0.55]),  # blue
        ]

    def make_world(self):
        """Create the World with its agents and landmarks, reset it and return it."""
        world = World()
        # set any world properties first
        world.size = self.ratio
        world.neigh_comm = 1.0  # compared against *squared* distance in neighb()
        world.sight = 1.0
        world.max_neighbor = self.max_neighbor
        world.dim_c = 2
        world.num_agents = self.num_agents
        world.num_landmarks = self.n_landmarks

        # add agents
        world.agents = [Agent() for _ in range(world.num_agents)]
        id_one_hot = np.eye(world.num_agents)
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.id = i
            agent.id_hot = id_one_hot[i]
            agent.collide = True
            agent.silent = True
            # Only `tp` differs between agent 0 (0) and the rest (9); the other
            # physical parameters are currently the same for every agent.
            agent.tp = 0 if i == 0 else 9
            agent.size = 0.1
            agent.accel = 5
            agent.max_speed = 4
            agent.live = 1
            agent.sight = world.sight

        # add landmarks (static and non-colliding)
        world.landmarks = [Landmark() for _ in range(world.num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.sight = 0
            landmark.name = 'landmarks %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.05
            landmark.live = 1
            landmark.boundary = False

        # make initial conditions
        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Recolour entities, reassign goals and draw new random initial positions.

        Agents spawn uniformly inside one quadrant of [-1, 1]^2; the quadrants are
        shuffled each reset and agent ``i`` uses slot ``(i + 1) % 4`` of that
        shuffle.  Landmarks are placed on the 0.1-spaced grid
        ``np.arange(-0.8, 0.8, 0.1)`` with pairwise-distinct x values and
        pairwise-distinct y values.  Raises ValueError if there are more
        landmarks than grid values.
        """
        palette = self.color_lst
        for i, landmark in enumerate(world.landmarks):
            landmark.color = palette[i % len(palette)]
        # Agent i inherits the colour of, and is assigned to, landmark i.
        for i, agent in enumerate(world.agents):
            agent.color = world.landmarks[i].color
            agent.goal_a = world.landmarks[i]

        # Quadrants as ((x_low, x_high), (y_low, y_high)).
        quadrants = np.array([
            [[-1, 0], [0, 1]],
            [[0, 1], [0, 1]],
            [[-1, 0], [-1, 0]],
            [[0, 1], [-1, 0]],
        ])
        quadrant_order = np.random.choice(4, 4, replace=False)
        for i, agent in enumerate(world.agents):
            (x_low, x_high), (y_low, y_high) = quadrants[quadrant_order[(i + 1) % 4]]
            # x is drawn before y to keep the RNG consumption order stable.
            x = np.random.uniform(x_low, x_high)
            y = np.random.uniform(y_low, y_high)
            agent.state.p_pos = np.array([x, y])
            agent.state.p_vel = np.zeros(world.dim_p)
            # Neighbour mask starts with only the agent itself.
            agent.neib = np.eye(self.num_agents)[i].astype(int)

        # Sample as many coordinates as there are landmarks (previously a
        # hard-coded 8, which broke for any other n_landmarks).
        grid = np.arange(-0.8, 0.8, 0.1)
        n_landmarks = len(world.landmarks)
        x_coord = np.random.choice(grid, n_landmarks, replace=False)
        y_coord = np.random.choice(grid, n_landmarks, replace=False)
        for landmark, x, y in zip(world.landmarks, x_coord, y_coord):
            landmark.state.p_pos = np.array([x, y])

    def is_collision(self, agent1, agent2):
        """Return True when the two entities' discs overlap or touch."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        return bool(dist <= agent1.size + agent2.size)

    def good_agents(self, world):
        """Return all agents that are not adversaries.

        Agents without an ``adversary`` attribute (make_world never sets one)
        are treated as non-adversarial instead of raising AttributeError.
        """
        return [agent for agent in world.agents
                if not getattr(agent, 'adversary', False)]

    def adversaries(self, world):
        """Return all agents flagged as adversaries (none unless ``adversary`` is set)."""
        return [agent for agent in world.agents
                if getattr(agent, 'adversary', False)]

    def done(self, agent, world):
        """Episodes never terminate from this scenario's side; always returns 0."""
        return 0

    def benchmark_data(self, agent, world):
        """Return ``(collisions, total_sq_goal_distance)`` over *all* agents.

        ``collisions`` counts each colliding unordered pair once (hence the
        division by 2 of the ordered-pair count); the second value sums every
        agent's squared distance to its goal landmark.  ``agent`` is unused.
        """
        coll = 0
        dist_to_goal = 0
        for a in world.agents:
            for b in world.agents:
                if b is a:
                    continue
                if self.is_collision(a, b):
                    coll += 1
            dist_to_goal += np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))
        return (coll / 2, dist_to_goal)

    def neighb(self, agent, world):
        """Rebuild and return ``agent.neib``, the agent's 0/1 neighbour mask.

        The mask always marks the agent itself; then up to
        ``world.max_neighbor - 1`` other agents whose *squared* distance is
        within ``world.neigh_comm`` are marked, nearest first (ties keep id
        order).  The mask is rebuilt in place on every call, so agents that have
        left range are dropped and the cap holds across repeated calls.
        """
        agent.neib[:] = 0
        agent.neib[agent.id] = 1
        candidates = sorted(
            ((other.id, np.sum(np.square(agent.state.p_pos - other.state.p_pos)))
             for other in world.agents if other is not agent),
            key=lambda item: item[1],
        )
        num_neighbor = 0
        for other_id, sq_dist in candidates:
            if num_neighbor >= world.max_neighbor - 1:
                break
            if sq_dist <= world.neigh_comm:
                agent.neib[other_id] = 1
                num_neighbor += 1
        return agent.neib

    @staticmethod
    def _collision_penalty(agent_id):
        """Per-collision penalty tier by agent id; ids above 6 are not penalised."""
        if agent_id <= 2:
            return 50
        if agent_id <= 4:
            return 100
        if agent_id <= 6:
            return 150
        return 0

    def reward(self, agent, world):
        """Return minus the Euclidean distance to the goal landmark, minus collision penalties.

        NOTE(review): the collision loop iterates over *landmarks*, and every
        landmark is created with ``collide = False`` in make_world, so the
        penalty never fires as configured.  Kept unchanged to preserve the
        reward signal -- confirm whether world.agents was intended.
        """
        rew = 0
        dist = np.sqrt(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos)))
        rew -= dist

        # An overly large collision penalty would undermine the agent's drive
        # to occupy its landmark.
        if agent.collide:
            penalty = self._collision_penalty(agent.id)
            for other in world.landmarks:
                if other.collide and self.is_collision(agent, other):
                    rew -= penalty
        return rew

    def observation(self, agent, world):
        """Return the agent's flat observation vector.

        Layout: own velocity, own position, offset to goal landmark, then for
        every *other* agent ordered nearest-first: all one-hot ids, then all
        relative positions, then all velocities.  Agents whose squared distance
        exceeds ``agent.sight`` are zero-padded so the length stays fixed.
        """
        ranked = sorted(
            ((other, np.sum(np.square(agent.state.p_pos - other.state.p_pos)))
             for other in world.agents if other is not agent),
            key=lambda item: item[1],
        )
        other_id, other_pos, other_vel = [], [], []
        for other, sq_dist in ranked:
            if sq_dist <= agent.sight:
                other_id.append(other.id_hot)
                other_pos.append(other.state.p_pos - agent.state.p_pos)
                other_vel.append(other.state.p_vel)
            else:
                other_id.append(np.zeros(self.num_agents))
                other_pos.append([0, 0])
                other_vel.append([0, 0])

        return np.concatenate(
            [agent.state.p_vel,
             agent.state.p_pos,
             agent.goal_a.state.p_pos - agent.state.p_pos]
            + other_id + other_pos + other_vel
        )