import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
import os
import random
# NOTE(review): SIGHT is not referenced by the Scenario class in this file;
# agents take their sight radius from ``world.sight`` (set in make_world).
# Confirm no external module imports this constant before removing it.
SIGHT = 0.5


class Scenario(BaseScenario):
    """Typed landmark-occupation scenario.

    With the default settings every agent is an adversary (``n_good == 0``).
    Agents are split into three types by index (0-1 -> type 0, 2-3 -> type 1,
    the rest -> type 2), and food landmark ``i`` inherits the type of agent
    ``i``.  An agent earns reward for staying on a live landmark over
    consecutive steps (double when the types match); once its consecutive
    occupation counter reaches 3, both the agent and the landmark(s) it
    occupies are deactivated.
    """

    # Base RGB colour per type index; any other type falls back to green.
    _TYPE_COLORS = (
        (0.45, 0.45, 0.95),  # type 0: purple
        (0.05, 0.45, 0.55),  # type 1: blue
        (0.1, 0.45, 0.1),    # type 2 (and fallback): green
    )

    def __init__(self):
        self.n_good = 0
        self.n_adv = 6
        self.n_landmarks = 0
        self.n_food = 6
        self.n_forests = 0
        self.num_agents = self.n_adv + self.n_good
        self.alpha = 0
        self.good_neigh_dist = 100.0
        self.adv_neigh_dist = 100.0
        self.ratio = 1
        self.size = self.ratio
        # Upper bound on the number of set entries in an agent's neighbour
        # mask, counting the agent itself (see ``neighb``).
        self.max_neighbor = 3
        self.max_good_neighbor = self.n_good
        self.max_adv_neighbor = self.n_adv

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _sq_dist(a, b):
        """Squared Euclidean distance between the positions of entities a and b."""
        return np.sum(np.square(a.state.p_pos - b.state.p_pos))

    @classmethod
    def _dist(cls, a, b):
        """Euclidean distance between the positions of entities a and b."""
        return np.sqrt(cls._sq_dist(a, b))

    @staticmethod
    def _agent_type(index):
        """Type of the agent at ``index``: 0 for 0-1, 1 for 2-3, 2 otherwise."""
        if index < 2:
            return 0
        if index < 4:
            return 1
        return 2

    @classmethod
    def _type_color(cls, tp):
        """Return a *fresh* RGB array for type ``tp`` (callers mutate it in place)."""
        rgb = cls._TYPE_COLORS[tp] if tp in (0, 1) else cls._TYPE_COLORS[2]
        return np.array(rgb)

    # -------------------------------------------------------------- world setup

    def make_world(self):
        """Create agents and food landmarks, assign types, then reset the world.

        Food landmark ``i`` copies the type of agent ``i``, so ``n_food`` must
        not exceed the number of agents.
        """
        world = World()
        world.size = self.ratio
        world.neigh_comm = 0.3
        world.sight = 1.0
        world.good_neigh_dist = self.good_neigh_dist
        world.adv_neigh_dist = self.adv_neigh_dist
        world.max_good_neighbor = self.max_good_neighbor
        world.max_adv_neighbor = self.max_adv_neighbor
        world.dim_c = 2
        num_good_agents = self.n_good
        num_adversaries = self.n_adv
        world.num_good_agents = num_good_agents
        world.num_adversaries = num_adversaries
        num_agents = num_adversaries + num_good_agents
        num_food = self.n_food

        # Agents.
        world.agents = [Agent() for _ in range(num_agents)]
        world.goods = []
        world.advs = []
        agent_ids = np.eye(num_agents)
        food_ids = np.eye(num_food)

        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.id = i
            agent.id_hot = agent_ids[i]
            agent.collide = True
            agent.silent = True
            agent.adversary = i < num_adversaries
            (world.advs if agent.adversary else world.goods).append(agent)
            agent.tp = self._agent_type(i)
            agent.tp_hot = np.eye(3)[agent.tp]
            # Good agents and adversaries currently share identical physics.
            agent.size = 0.08
            agent.accel = 3
            agent.max_speed = 3
            agent.live = 1
            agent.sight = world.sight

        # Food landmarks.
        world.food = [Landmark() for _ in range(num_food)]
        for i, landmark in enumerate(world.food):
            landmark.sight = 0
            landmark.name = 'food %d' % i
            landmark.id_hot = food_ids[i]
            # Each food item takes the type of the agent with the same index.
            landmark.tp = world.agents[i].tp
            landmark.tp_hot = world.agents[i].tp_hot
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.03
            landmark.boundary = False

        world.landmarks = world.food

        self.reset_world(world)
        return world

    def reset_world(self, world):
        """Restore colours, liveness and counters, then randomise positions.

        Agents are placed uniformly in [-0.8, 0.8]^dim_p.  Landmarks get x and
        y coordinates drawn without replacement from ``np.arange(-0.8, 0.8, 0.2)``
        (8 values), so at most 8 landmarks are supported; ``np.random.choice``
        raises ``ValueError`` beyond that.
        """
        for i, agent in enumerate(world.agents):
            agent.color = self._type_color(agent.tp)
            if i == 0:
                # Lighten agent 0's colour to distinguish it from its type-mate.
                agent.color += np.array([0.35, 0.35, 0.35])
            agent.live = 1
            agent.last_step = False
            agent.occ_steps = 0
            agent.ismatch = False
        for landmark in world.landmarks:
            # Colour each landmark by its own type so it matches its agents.
            landmark.color = self._type_color(landmark.tp)
            landmark.live = 1
            landmark.last_step = False
            landmark.occ_steps = 0

        # Random initial states.
        for i, agent in enumerate(world.agents):
            agent.state.p_pos = np.random.uniform(-1 * 0.8, +1 * 0.8, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            # Neighbour mask starts with only the agent itself set.
            agent.neib = np.eye(self.num_agents)[i].astype(int)
        num_lm = len(world.landmarks)
        grid = np.arange(-0.8, 0.8, 0.2)
        x_coord = np.random.choice(grid, num_lm, replace=False)
        y_coord = np.random.choice(grid, num_lm, replace=False)
        for landmark, x, y in zip(world.landmarks, x_coord, y_coord):
            landmark.state.p_pos = np.array([x, y])
            landmark.state.p_vel = np.zeros(world.dim_p)

    # ---------------------------------------------------------------- geometry

    def is_inrange(self, agent1, tar):
        """True when ``tar`` is within distance 0.1 of ``agent1``."""
        return bool(self._dist(agent1, tar) <= 0.1)

    def is_occupy(self, agent, lm):
        """True when landmark ``lm`` lies entirely inside ``agent``'s disc.

        The centre distance must not exceed ``agent.size - lm.size``.
        """
        return bool(self._dist(agent, lm) <= agent.size - lm.size)

    def is_collision(self, agent1, agent2):
        """True when the two entities' discs overlap or touch."""
        return bool(self._dist(agent1, agent2) <= agent1.size + agent2.size)

    def done(self, agent, world):
        """An agent is finished once it has been deactivated (``live == 0``)."""
        return not agent.live

    def benchmark_data(self, agent, world):
        """True if ``agent`` occupies any landmark of its own type."""
        return any(
            self.is_occupy(agent, l) and agent.tp == l.tp for l in world.landmarks
        )

    def neighb(self, agent, world):
        """Recompute and return ``agent.neib``, the agent's neighbour mask.

        The mask is reset in place to contain only the agent itself.  Then up
        to ``max_neighbor - 1`` live agents are added, nearest first, whose
        *squared* distance is at most ``world.neigh_comm``.  The mask is
        rebuilt on every call, so neighbours no longer accumulate across
        steps and the cap is enforced.
        """
        agent.neib[:] = 0
        agent.neib[agent.id] = 1
        candidates = sorted(
            (self._sq_dist(agent, a), a.id)
            for a in world.agents
            if a is not agent and a.live
        )
        for sq_dist, other_id in candidates[:max(self.max_neighbor - 1, 0)]:
            if sq_dist > world.neigh_comm:
                break  # sorted ascending: nothing further is in range
            agent.neib[other_id] = 1
        return agent.neib

    # ------------------------------------------------------------------ reward

    def reward(self, agent, world):
        """Shaped reward for ``agent``; also advances its occupation counters.

        Reward terms:
        * -0.2 * distance to the nearest live landmark;
        * -0.2 for every other agent it collides with;
        * while occupying live landmark(s) on consecutive steps, ``occ_steps``
          per occupied landmark (``2 * occ_steps`` when the types match).

        When ``occ_steps`` reaches 3, the agent and every live landmark it
        occupies are deactivated, and those landmarks are greyed out.
        """
        rew = 0
        live_landmarks = [l for l in world.landmarks if l.live]
        if live_landmarks:
            rew -= min(self._dist(agent, l) for l in live_landmarks) * 0.2

        # NOTE(review): collisions are counted against inactive agents too.
        for other in world.agents:
            if other is not agent and self.is_collision(agent, other):
                rew -= 0.2

        occupied = [l for l in live_landmarks if self.is_occupy(agent, l)]
        if occupied:
            # Occupation is only rewarded from the second consecutive step on.
            if agent.last_step:
                agent.occ_steps += 1
                for l in occupied:
                    if agent.tp == l.tp:
                        rew += agent.occ_steps * 2
                    else:
                        rew += agent.occ_steps
            agent.last_step = True
        else:
            agent.last_step = False
            agent.occ_steps = 0

        if agent.occ_steps == 3:
            for l in occupied:
                agent.live = 0
                l.live = 0
                l.color = np.array([0.75, 0.75, 0.75])

        return rew

    # ------------------------------------------------------------- observation

    def observation(self, agent, world):
        """Flat observation vector for ``agent``.

        Layout: own velocity, own position, own type one-hot, then
        ``n_food`` landmark type one-hots and ``n_food`` landmark relative
        positions (visible live landmarks nearest first, zero-padded), then
        one type one-hot, relative position and velocity per other agent
        (nearest first, zeros when out of sight).

        Visibility compares the *squared* distance against ``agent.sight``.
        """
        zero_vec = np.zeros(world.dim_p)
        no_type = np.zeros(3)

        # Live landmarks, nearest first (stable sort keeps index order on ties).
        lm_dists = sorted(
            (
                (i, self._sq_dist(agent, lm))
                for i, lm in enumerate(world.landmarks)
                if lm.live
            ),
            key=lambda t: t[1],
        )
        entity_pos = []
        entity_tp = []
        for i, sq_d in lm_dists:
            if sq_d <= agent.sight:
                lm = world.landmarks[i]
                entity_pos.append(lm.state.p_pos - agent.state.p_pos)
                entity_tp.append(lm.tp_hot)
        # Pad to a fixed number of landmark slots so the vector length is constant.
        n_pad = self.n_food - len(entity_pos)
        entity_pos += [zero_vec] * n_pad
        entity_tp += [no_type] * n_pad

        # Every other agent gets a slot, nearest first.
        other_dists = sorted(
            (
                (i, self._sq_dist(agent, other))
                for i, other in enumerate(world.agents)
                if other is not agent
            ),
            key=lambda t: t[1],
        )
        other_pos = []
        other_vel = []
        other_tp = []
        for i, sq_d in other_dists:
            other = world.agents[i]
            if sq_d <= agent.sight:
                other_pos.append(other.state.p_pos - agent.state.p_pos)
                other_vel.append(other.state.p_vel)
                other_tp.append(other.tp_hot)
            else:
                other_pos.append(zero_vec)
                other_vel.append(zero_vec)
                other_tp.append(no_type)

        return np.concatenate(
            [agent.state.p_vel, agent.state.p_pos, agent.tp_hot]
            + entity_tp + entity_pos + other_tp + other_pos + other_vel
        )