import numpy as np
from core import World, Agent, Landmark
from scenario import BaseScenario


import traceback
class Scenario(BaseScenario):
    """Delivery scenario: drones shuttle between landmarks in nested regions.

    Agents and landmarks are split round-robin (index ``i % 3``) into three
    queues. Each queue's landmarks are placed in a region around the origin
    determined by ``QUEUE_BOUNDS`` / ``get_idx_queue``. Every agent holds two
    delivery targets (indices into its own queue's landmark list) and is
    rewarded for touching one of them, and penalised for colliding with other
    agents and for running out of fuel.
    """

    # Half-widths of the nested squares delimiting landmark queues 0, 1, 2.
    # A position belongs to queue k (k < 2) when its largest absolute
    # coordinate is below QUEUE_BOUNDS[k] and not below the previous bound;
    # everything else is queue 2. The last entry is also the half-width of
    # the square that agents and queue-2 landmarks are drawn from.
    QUEUE_BOUNDS = (1.905, 2.69, 3.3)

    def make_world(self, args=None):
        """Create a world with 9 drones and 9 landmarks split into 3 queues.

        Args:
            args: optional namespace. ``agent_view_radius`` (default -1) is
                copied to every agent's ``view_radius``; ``score_function``
                (default ``"sum"``) is stored on ``self.score_function``.
                Missing attributes fall back to the defaults, so ``None``
                is accepted.

        Returns:
            The populated ``World``, already initialised by ``reset_world``.
        """
        world = World()
        # set any world properties first
        world.dim_c = 2
        num_drones = 9
        num_queues = len(self.QUEUE_BOUNDS)
        self.agent_queues = [[] for _ in range(num_queues)]
        self.landmark_queues = [[] for _ in range(num_queues)]

        # Same value for every agent, so read (and report) it only once.
        view_radius = getattr(args, "agent_view_radius", -1)
        world.agents = [Agent() for _ in range(num_drones)]
        for i, agent in enumerate(world.agents):
            agent.idx_queue = i % num_queues
            self.agent_queues[agent.idx_queue].append(agent)
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            agent.size = 0.075
            agent.accel = 1.1
            agent.max_speed = 1.0
            agent.fuel = 1.0
            agent.carrying = 1.
            agent.action_callback = None
            agent.view_radius = view_radius
        print("AGENT VIEW RADIUS set to: {}".format(view_radius))

        # One landmark per drone, assigned to queues the same round-robin way.
        world.landmarks = [Landmark() for _ in range(num_drones)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.idx_queue = i % num_queues
            self.landmark_queues[landmark.idx_queue].append(landmark)
            landmark.collide = True
            landmark.movable = False
            landmark.size = 0.5
            landmark.boundary = False

        # make initial conditions
        self.reset_world(world)
        # Not read anywhere in this class; presumably consumed by the caller.
        self.score_function = getattr(args, "score_function", "sum")
        return world

    def get_idx_queue(self, landmark):
        """Return the queue index (0, 1 or 2) of the region ``landmark`` is in.

        Compares the largest absolute coordinate of the landmark's position
        against ``QUEUE_BOUNDS`` with a strict ``<``.
        """
        extent = np.max(np.abs(landmark.state.p_pos))
        for idx_queue, bound in enumerate(self.QUEUE_BOUNDS[:-1]):
            if extent < bound:
                return idx_queue
        return len(self.QUEUE_BOUNDS) - 1

    @staticmethod
    def _dist(entity_a, entity_b):
        """Euclidean distance between the positions of two entities."""
        return np.sqrt(np.sum(np.square(entity_a.state.p_pos - entity_b.state.p_pos)))

    def _draw_targets(self, idx_queue):
        """Draw two random landmark indices (with replacement) from queue ``idx_queue``."""
        return list(np.random.randint(0, len(self.landmark_queues[idx_queue]), size=2))

    def _respawn_landmark(self, landmark, idx_queue, world):
        """Move ``landmark`` to a random spot in queue ``idx_queue``'s region.

        The position is drawn uniformly from the square of half-width
        ``QUEUE_BOUNDS[idx_queue]`` and redrawn until ``get_idx_queue``
        agrees, so it never lands inside an inner queue's square. Then
        ``landmark.avg_dist`` is refreshed.

        Raises:
            ValueError: if ``idx_queue`` is not a valid queue index.
        """
        if not 0 <= idx_queue < len(self.QUEUE_BOUNDS):
            raise ValueError("invalid queue index: {}".format(idx_queue))
        bound = self.QUEUE_BOUNDS[idx_queue]
        landmark.state.p_pos = np.random.uniform(-bound, +bound, world.dim_p)
        while self.get_idx_queue(landmark) != idx_queue:
            landmark.state.p_pos = np.random.uniform(-bound, +bound, world.dim_p)
        # Despite its name, avg_dist is the distance to the *nearest* agent of
        # the same queue; reward() uses it as the delivery travel distance.
        landmark.avg_dist = np.min(
            [self._dist(other, landmark) for other in self.agent_queues[idx_queue]])

    def reset_world(self, world):
        """Re-randomise agent and landmark states at the start of an episode.

        Each agent gets a uniform position in the outermost square, zero
        velocity and communication, full fuel, cleared counters and two
        fresh delivery targets. Agent state is set first because each
        landmark's ``avg_dist`` is measured against the agent positions.
        Each non-boundary landmark is then respawned in its queue's region.
        """
        for agent in world.agents:
            agent.color = np.array([0.35, 0.85, 0.35])
        for landmark in world.landmarks:
            landmark.color = np.array([0.25, 0.25, 0.25])

        arena = self.QUEUE_BOUNDS[-1]
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-arena, +arena, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
            agent.timeout = 0
            agent.fuel = 1.
            agent.fuel_empties = 0.
            agent.delivery_targets = self._draw_targets(agent.idx_queue)

        for landmark in world.landmarks:
            if not landmark.boundary:
                landmark.state.p_vel = np.zeros(world.dim_p)
                self._respawn_landmark(landmark, landmark.idx_queue, world)

    def benchmark_data(self, agent, world):
        """Return benchmarking data; this scenario reports a constant 0."""
        return 0

    def is_collision(self, agent1, agent2):
        """Return True when the two entities' discs (radius ``size``) overlap."""
        return bool(self._dist(agent1, agent2) < agent1.size + agent2.size)

    def _delivery_reward(self, agent, dist_travelled):
        """Reward for a delivery whose landmark spawned ``dist_travelled`` away.

        The base reward grows with an estimated travel time derived from the
        distance (floored at 0.1). If the agent ran out of fuel at least once
        since its last delivery, the reward is divided by the fuel it used
        plus the number of empty tanks, but never by less than an expected
        number of empties. Reads ``agent.fuel`` and ``agent.fuel_empties``,
        so it must be called before they are reset.
        """
        dist_travelled = max(dist_travelled, 0.1)

        max_v = (1.0 / dist_travelled) / (0.020 * 10)
        max_v = min(max_v, agent.max_speed)
        ticks = dist_travelled * 10 / max_v
        rew = 4 + (0.6 * ticks * (1 + dist_travelled * 0.4))

        if agent.fuel_empties > 0.:
            expected_empties = 0.1 + 1.0 / max_v
            used_fuel = 1.0 - agent.fuel
            return rew / max(used_fuel + agent.fuel_empties, expected_empties)
        return rew

    def reward(self, agent, world):
        """Return ``agent``'s reward for this step, advancing the simulation.

        This method has side effects:

        * drains ``agent.fuel`` in proportion to its squared speed;
        * on collision with another agent (while ``agent.timeout == 0``),
          stops both agents, pushes them apart by a random offset and
          charges 8;
        * when fuel is exhausted (while ``agent.timeout == 0``), refuels and
          stops the agent, increments ``agent.fuel_empties`` and charges 8;
        * on touching one of its delivery targets, pays ``_delivery_reward``,
          resets fuel bookkeeping, draws new targets and respawns the touched
          landmark.
        """
        total = 0

        # Fuel cost grows with the square of the current speed.
        agent.fuel -= np.linalg.norm(agent.state.p_vel)**2 * (0.020)

        for other in world.agents:
            if other.name != agent.name and self.is_collision(agent, other) and agent.timeout == 0:
                agent.movable = True
                agent.state.p_vel = np.zeros(world.dim_p)
                other.state.p_vel = np.zeros(world.dim_p)
                # Push the pair apart along the sign of their offset, by a random
                # 1.0x-2.5x body size per axis. NOTE(review): an axis on which the
                # offset is exactly zero receives no push.
                direction = np.sign(other.state.p_pos - agent.state.p_pos)
                other.state.p_pos = other.state.p_pos + np.random.uniform(
                    1.0 * other.size, 2.5 * other.size, world.dim_p) * direction
                agent.state.p_pos = agent.state.p_pos + np.random.uniform(
                    1.0 * agent.size, 2.5 * agent.size, world.dim_p) * -direction
                total -= 8.

        if agent.fuel <= 0. and agent.timeout == 0:
            # Out of fuel: penalise, stop the agent and hand it a full tank,
            # remembering the empty tank for the next delivery's reward.
            agent.movable = True
            total -= 8.
            agent.fuel = 1.
            agent.fuel_empties += 1.
            agent.state.p_vel = np.zeros(world.dim_p)

        # Only landmarks in the agent's own queue can be delivery targets.
        idx_queue = agent.idx_queue
        for i, landmark in enumerate(self.landmark_queues[idx_queue]):
            if not (self.is_collision(agent, landmark) and i in agent.delivery_targets):
                continue
            # NOTE(review): carrying is never set back to 1 in this class.
            agent.carrying = 0.
            total += self._delivery_reward(agent, landmark.avg_dist)
            agent.fuel = 1.
            agent.fuel_empties = 0.
            # New targets are drawn before the landmark moves. A later landmark
            # in this same loop can therefore also count as delivered this step.
            agent.delivery_targets = self._draw_targets(idx_queue)
            self._respawn_landmark(landmark, idx_queue, world)
        return total

    def observation(self, agent, world):
        """Build ``agent``'s observation vector.

        Layout (concatenated): own velocity, own absolute position, position
        of the nearest other agent relative to ``agent``, that agent's
        velocity, remaining fuel, then for each entry of
        ``agent.delivery_targets`` the target landmark's position relative
        to ``agent`` followed by its ``avg_dist``. If there is no other
        agent, the neighbour fields are zeros.

        NOTE(review): ``agent.view_radius`` is not consulted here.
        """
        others = [other for other in world.agents if other is not agent]
        if others:
            nearest = min(others, key=lambda other: self._dist(other, agent))
            neighbour_pos = nearest.state.p_pos - agent.state.p_pos
            neighbour_vel = nearest.state.p_vel
        else:
            neighbour_pos = np.zeros(world.dim_p)
            neighbour_vel = np.zeros(world.dim_p)

        queue = self.landmark_queues[agent.idx_queue]
        target_obs = []
        for tg in agent.delivery_targets:
            target = queue[tg]
            target_obs.append(target.state.p_pos - agent.state.p_pos)
            target_obs.append([target.avg_dist])

        return np.concatenate([agent.state.p_vel, agent.state.p_pos,
                               neighbour_pos, neighbour_vel,
                               [agent.fuel], *target_obs])

    def full_observation(self, agent, world):
        """Return the same vector as ``observation``."""
        return self.observation(agent, world)
