import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
import os
import random
# Module-level sight constant. It is not referenced anywhere in the visible
# part of this file; make_world assigns world.sight = 0.8 directly instead.
# NOTE(review): possibly dead, or read by importers -- confirm before removing.
SIGHT = 0.5


class Scenario(BaseScenario):
    def __init__(self):
        self.n_good = 2
        self.n_adv = 4
        self.n_landmarks = 0
        self.n_food = 0
        self.n_forests = 0
        self.num_agents = self.n_adv + self.n_good
        self.alpha = 0
        self.good_neigh_dist = 100.0
        self.adv_neigh_dist = 100.0
        self.ratio = 1
        self.size = self.ratio
        self.no_wheel = 0
        self.max_good_neighbor = self.n_good
        self.max_adv_neighbor = self.n_adv

    def make_world(self):
        """Create a World, populate it with agents and food, and reset it.

        The first ``self.n_adv`` agents are adversaries and the remaining
        ``self.n_good`` agents are good agents. Each agent is also appended to
        ``world.advs`` or ``world.goods``. ``world.landmarks`` is the same list
        object as ``world.food``.

        Returns:
            The populated world, after ``self.reset_world(world)`` has run.
        """
        n_adv = self.n_adv
        n_good = self.n_good
        n_total = n_adv + n_good
        n_food = self.n_food

        world = World()
        # set any world properties first
        world.size = self.ratio
        world.good_neigh_dist = self.good_neigh_dist
        world.adv_neigh_dist = self.adv_neigh_dist
        world.max_good_neighbor = self.max_good_neighbor
        world.max_adv_neighbor = self.max_adv_neighbor
        world.neigh_comm = 1.0
        world.sight = 0.8
        world.dim_c = 2
        world.num_good_agents = n_good
        world.num_adversaries = n_adv

        # add agents
        world.agents = [Agent() for _ in range(n_total)]
        world.goods = []
        world.advs = []
        one_hot_ids = np.eye(n_total)

        for idx, agent in enumerate(world.agents):
            is_adv = idx < n_adv
            agent.name = 'agent %d' % idx
            agent.id = idx
            agent.id_hot = one_hot_ids[idx]
            agent.collide = True
            agent.silent = True
            agent.adversary = is_adv
            # Both teams currently share the same sight, size and live flag.
            agent.sight = world.sight
            if is_adv:
                world.advs.append(agent)
                agent.tp = 0
                agent.accel = 4
                agent.max_speed = 3
            else:
                world.goods.append(agent)
                agent.tp = 9
                agent.accel = 6
                agent.max_speed = 5
            agent.size = 0.06
            agent.live = 1
            agent.occupy = 0

        # food landmarks: static, non-colliding
        world.food = [Landmark() for _ in range(n_food)]
        for idx, food in enumerate(world.food):
            food.sight = 1
            food.name = 'food %d' % idx
            food.collide = False
            food.movable = False
            food.size = 0.03
            food.boundary = False

        world.landmarks = world.food

        # make initial conditions
        self.reset_world(world)
        return world


    def reset_world(self, world):
        for i, agent in enumerate(world.agents):
            agent.color = [0.45, 0.45, 0.95] if not agent.adversary else np.array([0.95, 0.45, 0.45])
            # if i == 4: agent.color = np.array([0.1, 0.45, 0.1])  #green
            agent.live = 1 if agent.adversary else 1
            agent.occupy = 0
            agent.be_collide = 0
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])
        # set random initial states
        for i, agent in enumerate(world.agents):
            agent.state.p_pos = np.random.uniform(-0.9*self.ratio, +0.9*self.ratio, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.neib = np.eye(self.num_agents)[i].astype(int)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1*0.6, 1*0.6, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)
        self.occ_steps = 0

    def benchmark_data(self, agent, world):
        """Return the benchmark statistics as one concatenated array.

        NOTE(review): ``time_grass`` and ``time_live`` are not defined in the
        visible part of this file. Unless they exist as module-level names
        elsewhere, calling this method raises NameError -- TODO confirm where
        they are supposed to come from.
        """
        grass_part = np.array(time_grass)
        live_part = np.array(time_live)
        return np.concatenate([grass_part, live_part])
    
    def done(self, agent, world):
        return False

    def is_inrange(self, agent1, tar):
        delta_pos = agent1.state.p_pos - tar.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = tar.sight
        return True if dist < dist_min else False
    
    def is_collision(self, agent1, agent2):
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist <= dist_min else False


    def neighb(self, agent, world):
        if not agent.adversary:
            return agent.neib
        else:
            dists = []
            num_neighbor = 0
            for a in world.agents:
                if a is agent: continue
                if not a.adversary: continue
                dists.append((a.id, np.sum(np.square(agent.state.p_pos - a.state.p_pos))))
            for other_id, other_dist in dists:
                if other_dist <= world.neigh_comm:
                    agent.neib[other_id] = 1
                    num_neighbor += 1
            return agent.neib


    def reward(self, agent, world):
        # Agents are rewarded based on minimum agent distance to each landmark
        # main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)
        rew = 0
        
        if agent.live:
            if not agent.adversary: # Good agents
                if np.abs(agent.state.p_pos[0])>=1 or np.abs(agent.state.p_pos[1])>=1:
                    rew -= 15

                for adv_agnt in world.advs:
                    for good_agnt in world.goods:
                        if self.is_collision(adv_agnt, good_agnt) and agent is good_agnt:
                            rew -= 2

                distances = [np.sqrt(np.sum(np.square(agent.state.p_pos - other_agent.state.p_pos))) for other_agent in world.advs if other_agent.live>0]
                if (len(distances) > 0):
                    rew += np.mean(distances)*1

            if agent.adversary:
                for adv_agnt in world.advs:
                    for good_agnt in world.goods:
                        if self.is_collision(adv_agnt, good_agnt) and agent is adv_agnt:
                            rew += 1

                distances = [np.sqrt(np.sum(np.square(agent.state.p_pos - other_agent.state.p_pos))) for other_agent in world.agents if not other_agent.adversary and other_agent.live>0]
                if (len(distances) > 0):
                    rew -= min(distances)*1

        return rew


    def observation(self, agent, world):
        return self.adv_obs(agent, world) if agent.adversary else self.good_obs(agent, world)

    def adv_obs(self, agent, world):
        dist = []
        for i, other in enumerate(world.agents):
            if other is agent: continue
            dist.append((i, np.sum(np.square(agent.state.p_pos - other.state.p_pos))))
        dist = sorted(dist, key = lambda t: t[1])
        other_pos = []
        other_vel = []
        other_liv = []
        other_id = []
        for i, other_dist in dist:
            if other_dist <= agent.sight:
                other_pos.append(world.agents[i].state.p_pos - agent.state.p_pos)
                other_vel.append(world.agents[i].state.p_vel)
                other_liv.append(np.array([other.live]))
                other_id.append(world.agents[i].id_hot)
            else:
                other_pos.append([0,0])
                other_vel.append([0,0])
                other_liv.append([0])
                other_id.append(np.zeros(self.num_agents))

        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + [np.array([agent.live])] +\
            other_id + other_liv + other_pos + other_vel)
    

    def good_obs(self, agent, world):
        dist = []
        for i, other in enumerate(world.advs):
            if other is agent: continue
            dist.append((i, np.sum(np.square(agent.state.p_pos - other.state.p_pos))))
        dist = sorted(dist, key = lambda t: t[1])
        other_pos = []
        other_vel = []
        other_liv = []
        other_id = []
        for i, other_dist in dist:
            if other_dist <= agent.sight:
                other_pos.append(world.agents[i].state.p_pos - agent.state.p_pos)
                other_vel.append(world.agents[i].state.p_vel)
                other_liv.append(np.array([other.live]))
                other_id.append(world.agents[i].id_hot)
            else:
                other_pos.append([0,0])
                other_vel.append([0,0])
                other_liv.append([0])
                other_id.append(np.zeros(self.num_agents))

        for i, other in enumerate(world.goods):
            if other is agent: continue
            other_pos.append(world.agents[i].state.p_pos - agent.state.p_pos)
            other_vel.append(world.agents[i].state.p_vel)
            other_liv.append(np.array([other.live]))
            other_id.append(world.agents[i].id_hot)
        
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + [np.array([agent.live])] +\
            other_id + other_liv + other_pos + other_vel)