import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
import os
import random
# Module-level sight constant.
# NOTE(review): nothing in the visible part of this file reads SIGHT; the
# per-agent sight values are assigned directly in Scenario.make_world.
# Confirm there are no external users before changing or removing it.
SIGHT = 0.5


class Scenario(BaseScenario):
    def __init__(self):
        self.n_good = 2
        self.n_adv = 4
        self.n_landmarks = 0
        self.n_food = 0
        self.n_forests = 0
        self.num_agents = self.n_adv + self.n_good
        self.max_neighbor = 3
        self.alpha = 0
        self.good_neigh_dist = 100.0
        self.adv_neigh_dist = 100.0
        self.ratio = 1
        self.size = self.ratio
        self.no_wheel = 0

    def make_world(self):
        """Create the world, its agents and food landmarks, then reset it.

        Agents are created adversaries-first: indices ``0 .. n_adv - 1`` are
        adversaries and the remaining ``n_good`` agents are good agents. Each
        agent also lands in ``world.advs`` or ``world.goods`` accordingly.

        Returns:
            The populated world, already passed through ``reset_world``.
        """
        world = World()

        # World-level settings first.
        world.size = self.ratio
        world.good_neigh_dist = self.good_neigh_dist
        world.adv_neigh_dist = self.adv_neigh_dist
        world.max_neighbor = self.max_neighbor
        world.neigh_comm = 0.5
        world.sight = 1.0
        world.dim_c = 2
        world.num_good_agents = self.n_good
        world.num_adversaries = self.n_adv

        # Agents: every agent gets one row of a shared identity matrix as
        # its one-hot id.
        total_agents = self.n_adv + self.n_good
        world.agents = [Agent() for _ in range(total_agents)]
        world.goods = []
        world.advs = []
        one_hot_ids = np.eye(total_agents)

        for idx, agent in enumerate(world.agents):
            is_adversary = idx < self.n_adv

            agent.name = 'agent %d' % idx
            agent.id = idx
            agent.id_hot = one_hot_ids[idx]
            agent.collide = True
            agent.silent = True
            agent.adversary = is_adversary

            if is_adversary:
                agent.sight = world.sight
                world.advs.append(agent)
                agent.tp = 0
                agent.accel = 4
                agent.max_speed = 3
            else:
                agent.sight = 0.6
                world.goods.append(agent)
                agent.tp = 9
                agent.accel = 6
                agent.max_speed = 5

            # Identical for both teams.
            agent.size = 0.06
            agent.live = 1
            agent.occupy = 0

        # Food landmarks: static, non-colliding markers.
        world.food = [Landmark() for _ in range(self.n_food)]
        for idx, food in enumerate(world.food):
            food.sight = 1
            food.name = 'food %d' % idx
            food.collide = False
            food.movable = False
            food.size = 0.03
            food.boundary = False

        # The landmark list is the food list itself (same object).
        world.landmarks = world.food

        # Put everything into its initial state.
        self.reset_world(world)
        return world


    def reset_world(self, world):
        for i, agent in enumerate(world.agents):
            agent.color = np.array([0.95, 0.45, 0.45]) if not agent.adversary else np.array([0.1, 0.45, 0.1])
            if i==0 or i==1: agent.color = np.array([0.05, 0.45, 0.55]) #blue
            agent.live = 1 if agent.adversary else 1
            agent.last_step = False
            agent.occ_steps = 0
            agent.col_steps = 0
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])
        # set random initial states
        for i, agent in enumerate(world.agents):
            agent.state.p_pos = np.random.uniform(-0.9*self.ratio, +0.9*self.ratio, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.neib = np.eye(self.num_agents)[i].astype(int)
        for i, landmark in enumerate(world.landmarks):
            landmark.state.p_pos = np.random.uniform(-1*0.6, 1*0.6, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)
        self.occ_steps = 0

    def benchmark_data(self, agent, world):
        return 0
    
    def done(self, agent, world):
        # return False
        return False if agent.live else True


    def is_inrange(self, agent1, tar):
        delta_pos = agent1.state.p_pos - tar.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = 0.2
        return True if dist < dist_min else False
    
    def is_collision(self, agent1, agent2):
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist <= dist_min else False


    def neighb(self, agent, world):
        if not agent.adversary:
            return agent.neib
        else:
            dists = []
            num_neighbor = 0
            for a in world.agents:
                if a is agent: continue
                if not a.adversary: continue
                if not a.live: continue
                dists.append((a.id, np.sum(np.square(agent.state.p_pos - a.state.p_pos))))
            for other_id, other_dist in dists:
                if other_dist <= world.neigh_comm and num_neighbor < world.max_neighbor-1:
                    agent.neib[other_id] = 1
                    num_neighbor += 1
            return agent.neib


    def reward(self, agent, world):
        # Agents are rewarded based on minimum agent distance to each landmark
        # main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)
        rew = 0
        
        if agent.live:
            if not agent.adversary: # Good agents
                # if np.abs(agent.state.p_pos[0])>=1 or np.abs(agent.state.p_pos[1])>=1:
                #     rew -= 15

                # for adv_agnt in world.advs:
                #     for good_agnt in world.goods:
                #         if self.is_collision(adv_agnt, good_agnt) and agent is good_agnt:
                #             rew -= 2

                occ = 0
                for a in world.advs:
                    if self.is_inrange(agent, a) and (a.id==0 or a.id==1):
                        occ += 1
                if occ > 0:
                    if agent.last_step:
                        agent.occ_steps += 1
                    agent.last_step = True
                else:
                    agent.last_step = False
                    agent.occ_steps = 0

                if agent.occ_steps > 5: agent.live = 0

                
                distances = [np.sqrt(np.sum(np.square(agent.state.p_pos - other_agent.state.p_pos))) for other_agent in world.advs if other_agent.live>0]
                if (len(distances) > 0):
                    rew += np.mean(distances)*1


            if agent.adversary:
                
                if agent.id==0 or agent.id==1:
                    if agent.id == 0: other = world.agents[1]
                    if agent.id == 1: other = world.agents[0]
                    if (self.is_inrange(agent, world.goods[0]) and world.goods[0].live and not self.is_inrange(other, world.goods[0])) or \
                        (self.is_inrange(agent, world.goods[1]) and world.goods[1].live and not self.is_inrange(other, world.goods[1])):
                        if agent.last_step:
                            agent.occ_steps += 1
                            rew += agent.occ_steps
                        agent.last_step = True
                    else:
                        agent.last_step = False
                        agent.occ_steps = 0

                else:
                    col = 0
                    col_o = 0
                    for g in world.goods:
                        for a in world.advs:
                            if self.is_inrange(a, g):
                                if a is agent: col += 1
                                else: col_o += 1
                    if self.is_inrange(agent, world.goods[0]) or self.is_inrange(agent, world.goods[1]):
                        if col > col_o and agent.last_step:
                            agent.col_steps += 1
                        agent.last_step = True    
                    else:
                        agent.last_step = False
                        agent.col_steps = 0

                    if agent.col_steps > 5: 
                        agent.live = 0


                distances = [np.sqrt(np.sum(np.square(agent.state.p_pos - other_agent.state.p_pos))) for other_agent in world.agents if not other_agent.adversary and other_agent.live>0]
                if (len(distances) > 0):
                    rew -= min(distances)*0.2

        return rew


    def observation(self, agent, world):
        return self.adv_obs(agent, world) if agent.adversary else self.good_obs(agent, world)

    def adv_obs(self, agent, world):
        other_pos = []
        other_vel = []
        other_liv = []
        other_id = []
        for i, other in enumerate(world.agents):
            if other is agent: continue
            other_pos.append(world.agents[i].state.p_pos - agent.state.p_pos)
            other_vel.append(world.agents[i].state.p_vel)
            other_liv.append(np.array([other.live]))
            other_id.append(world.agents[i].id_hot)

        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + [np.array([agent.live])] +\
            other_id + other_liv + other_pos + other_vel)


    def good_obs(self, agent, world):
        dist = []
        for i, other in enumerate(world.advs):
            if other is agent: continue
            dist.append((i, np.sum(np.square(agent.state.p_pos - other.state.p_pos))))
        dist = sorted(dist, key = lambda t: t[1])
        other_pos = []
        other_vel = []
        other_liv = []
        other_id = []
        for i, other_dist in dist:
            if other_dist <= agent.sight:
                other_pos.append(world.agents[i].state.p_pos - agent.state.p_pos)
                other_vel.append(world.agents[i].state.p_vel)
                other_liv.append(np.array([other.live]))
                other_id.append(world.agents[i].id_hot)
            else:
                other_pos.append([0,0])
                other_vel.append([0,0])
                other_liv.append([0])
                other_id.append(np.zeros(self.num_agents))

        for i, other in enumerate(world.goods):
            if other is agent: continue
            other_pos.append(world.agents[i].state.p_pos - agent.state.p_pos)
            other_vel.append(world.agents[i].state.p_vel)
            other_liv.append(np.array([other.live]))
            other_id.append(world.agents[i].id_hot)
        
        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + [np.array([agent.live])] +\
            other_id + other_liv + other_pos + other_vel)