import numpy as np
import numba
from numba import jit
from numba import int32, float32, boolean, float64, deferred_type
from numba.experimental import jitclass
import copy

# Numba jitclass mirroring the physical state of an entity: two 1-D float64
# arrays, one for position and one for velocity.
@jitclass([('p_pos', float64[:]), ('p_vel', float64[:])])
class EntityState_nb(object):
    def __init__(self):
        # physical position
        # NOTE(review): this statement only reads the attribute; nothing is
        # assigned here, so p_pos has no meaningful value after construction.
        # Presumably callers assign it before use -- confirm.
        self.p_pos
        # physical velocity
        # NOTE(review): same as above -- p_vel is referenced but never set here.
        self.p_vel


# Deferred Numba type bound to the EntityState_nb instance type, so that the
# state object can be named as a field type in another jitclass spec.
EntityState_type = deferred_type()
EntityState_type.define(EntityState_nb.class_type.instance_type)

@jitclass([
    ('size', float32),
    ('movable', boolean),
    ('collide', boolean),
    ('state', EntityState_type),
])
class Entity_nb(object):
    def __init__(self):
        # entity size (two sizes are summed to get the minimum contact distance)
        self.size = 0.050
        # entities cannot be moved / pushed unless this is switched on
        self.movable = False
        # entity collides with others
        self.collide = True
        # position / velocity container
        self.state = EntityState_nb()



@jit(nopython=True)
def get_collision_force_nb(entity_a, entity_b, contact_margin, contact_force):
    """Compute the soft contact forces acting between two entities.

    Args:
        entity_a, entity_b: entities exposing ``state.p_pos``, ``size``,
            ``movable`` and ``collide``.
        contact_margin: softness ``k`` of the penetration term.
        contact_force: scale factor applied to the resulting force.

    Returns:
        ``(force_a, force_b)``: the force on each entity. An entity that is
        not movable, or a pair in which either entity does not collide,
        receives an all-zero array (never ``None``) so that the return type
        is the same on every path.
    """
    delta_pos = entity_a.state.p_pos - entity_b.state.p_pos
    # Non-colliding entities exchange no force. Zero arrays of the same dtype
    # as the real force keep the compiled return type uniform.
    if (not entity_a.collide) or (not entity_b.collide):
        return np.zeros_like(delta_pos), np.zeros_like(delta_pos)
    # NOTE: the self-collision case (entity_a is entity_b) is not checked
    # here; the pair loop in this file only passes distinct pairs.
    # compute actual distance between entities
    dist = np.sqrt(np.sum(np.square(delta_pos)))
    # minimum allowable distance
    dist_min = entity_a.size + entity_b.size
    # softmax penetration
    k = contact_margin
    penetration = np.logaddexp(0, -(dist - dist_min) / k) * k
    force = contact_force * delta_pos / dist * penetration
    force_a = +force if entity_a.movable else np.zeros(force.shape)
    force_b = -force if entity_b.movable else np.zeros(force.shape)
    return force_a, force_b


@jit(nopython=True)
def apply_environment_force_nb(p_force, entities, contact_margin, contact_force):
    """Accumulate pairwise collision forces into ``p_force``.

    Every unordered pair of entities is visited once; the force computed for
    each member of the pair is added to that member's slot in ``p_force``,
    which is then returned.
    """
    for a, first in enumerate(entities):
        for b, second in enumerate(entities):
            # only handle pairs with b > a so each pair is processed once
            if b > a:
                f_a, f_b = get_collision_force_nb(first, second, contact_margin, contact_force)
                p_force[a] = p_force[a] + f_a
                p_force[b] = p_force[b] + f_b
    return p_force


@jit(nopython=True)
def min_nb(a):
    """Return the minimum of ``a`` via a JIT-compiled ``np.min``."""
    smallest = np.min(a)
    return smallest


# physical/external base state of all entities
class EntityState(object):
    """Base physical state of an entity: a position and a velocity."""

    def __init__(self):
        # physical position and physical velocity both start out unset
        self.p_pos, self.p_vel = None, None


# state of agents (including communication and internal/mental state)
class AgentState(EntityState):
    """Agent state: the physical base state plus a communication utterance."""

    def __init__(self):
        super().__init__()
        # communication utterance (unset until the world updates it)
        self.c = None


# action of the agent
class Action(object):
    """Container for an agent's action: a physical part and a communication part."""

    def __init__(self):
        # physical action (u) and communication action (c) both start unset
        self.u, self.c = None, None


# properties and state of physical world entity
class Entity(object):
    """Properties and state shared by every physical entity in the world."""

    def __init__(self):
        # identification and appearance
        self.name = ''
        self.color = None
        # physical properties: size, material density and mass
        self.size = 0.050
        self.density = 25.0
        self.initial_mass = 1.0
        # whether the entity can move / be pushed, and whether it collides
        self.movable = False
        self.collide = True
        # max speed and accel (None until configured)
        self.max_speed = None
        self.accel = None
        # physical state container
        self.state = EntityState()

    @property
    def mass(self):
        """Mass of the entity; this implementation returns ``initial_mass``."""
        return self.initial_mass


# properties of landmark entities
class Landmark(Entity):
    """Landmark entity; adds nothing on top of the ``Entity`` defaults."""

    def __init__(self):
        super().__init__()


# properties of agent entities
class Agent(Entity):
    """Agent entity: a movable entity with an action, noise and control settings."""

    def __init__(self, action_callback=None):
        super().__init__()
        # unlike a plain entity, agents are movable by default
        self.movable = True
        # capability flags: silent = cannot send communication signals,
        # blind = cannot observe the world
        self.silent = False
        self.blind = False
        # noise amounts for the physical motor and for communication
        self.u_noise = None
        self.c_noise = None
        # control range
        self.u_range = 1.0
        # agent-specific state (adds the communication utterance) and action
        self.state = AgentState()
        self.action = Action()
        # script behavior to execute; None when no script is attached
        self.action_callback = action_callback
        self.id = None


# multi-agent world
class World(object):
    def __init__(self, scripted_agents=None, obs_callback=None, use_numba=False):
        """Create an empty world with the default physics parameters.

        Args:
            scripted_agents: stored unchanged as ``s_agents``.
            obs_callback: stored unchanged as ``obs_callback``.
            use_numba: stored unchanged as ``use_numba``.
        """
        # constructor arguments are kept as-is
        self.s_agents = scripted_agents
        self.obs_callback = obs_callback
        self.use_numba = use_numba
        # agents and landmarks (these lists can change at execution-time!)
        self.agents, self.landmarks = [], []
        # dimensionality of the communication channel, positions and colors
        self.dim_c = 0
        self.dim_p = 2
        self.dim_color = 3
        # simulation timestep and physical damping
        self.dt = 0.1
        self.damping = 0.25
        # contact response parameters
        self.contact_force = 1e+2
        self.contact_margin = 1e-3
        # values filled in lazily by other methods of this class
        self.dist_min = None
        self.force_mask = None

    def get_force_mask(self):
        n_entities = len(self.entities)
        mask = np.ones((n_entities, n_entities, 2))
        for a, entity_a in enumerate(self.entities):
            for b, entity_b in enumerate(self.entities):
                if not entity_a.movable or (entity_a is entity_b) or not entity_a.collide or not entity_b.collide:
                    mask[a, b, :] = 0
        return mask

    # return all entities in the world
    @property
    def entities(self):
        self.dist_min = 2 * self.agents[0].size
        return self.agents + self.landmarks


    # return all agents controllable by external policies
    @property
    def policy_agents(self):
        return [agent for agent in self.agents if agent.action_callback is None]

    # return all agents controlled by world scripts
    @property
    def scripted_agents(self):
        return [agent for agent in self.agents if agent.action_callback is not None]

    # update state of the world
    def step(self):
        # set actions for scripted agents 
        for agent in self.scripted_agents:
            agent.action = agent.action_callback(agent, self)
        # gather forces applied to entities
        p_force = [None] * len(self.entities)
        # apply agent physical controls
        p_force = self.apply_action_force(p_force)
        # apply environment forces
        p_force = self.apply_environment_force_vec(p_force)
        # integrate physical state
        self.integrate_state(p_force)
        # update agent state
        for agent in self.agents:
            self.update_agent_state(agent)

    # gather agent action forces
    def apply_action_force(self, p_force):
        # set applied forces
        for i, agent in enumerate(self.agents):
            if agent.movable:
                noise = np.random.randn(*agent.action.u.shape) * agent.u_noise if agent.u_noise else 0.0
                p_force[i] = agent.action.u + noise
        return p_force

    # gather physical forces acting on entities

    def apply_environment_force(self, p_force):
        # simple (but inefficient) collision response
        for a,entity_a in enumerate(self.entities):
            for b,entity_b in enumerate(self.entities):
                if(b <= a): continue
                [f_a, f_b] = self.get_collision_force(entity_a, entity_b)
                if(f_a is not None):
                    if(p_force[a] is None): p_force[a] = 0.0
                    p_force[a] = f_a + p_force[a] 
                if(f_b is not None):
                    if(p_force[b] is None): p_force[b] = 0.0
                    p_force[b] = f_b + p_force[b]        
        return p_force







    # integrate physical state
    def integrate_state(self, p_force):
        for i, entity in enumerate(self.entities):
            if not entity.movable: continue
            entity.state.p_vel = entity.state.p_vel * (1 - self.damping)
            if (p_force[i] is not None):
                entity.state.p_vel += (p_force[i] / entity.mass) * self.dt
            if entity.max_speed is not None:
                speed = np.sqrt(np.square(entity.state.p_vel[0]) + np.square(entity.state.p_vel[1]))
                if speed > entity.max_speed:
                    entity.state.p_vel = entity.state.p_vel / np.sqrt(np.square(entity.state.p_vel[0]) +
                                                                      np.square(
                                                                          entity.state.p_vel[1])) * entity.max_speed
            entity.state.p_pos += entity.state.p_vel * self.dt

    def update_agent_state(self, agent):
        # set communication state (directly for now)
        if agent.silent:
            agent.state.c = np.zeros(self.dim_c)
        else:
            noise = np.random.randn(*agent.action.c.shape) * agent.c_noise if agent.c_noise else 0.0
            agent.state.c = agent.action.c + noise

    # get collision forces for any contact between two entities

    def get_collision_force(self, entity_a, entity_b):
        if (not entity_a.collide) or (not entity_b.collide):
            return [None, None]  # not a collider
        if (entity_a is entity_b):
            return [None, None]  # don't collide against itself
        # compute actual distance between entities
        delta_pos = entity_a.state.p_pos - entity_b.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        # minimum allowable distance
        dist_min = entity_a.size + entity_b.size
        # softmax penetration
        k = self.contact_margin
        penetration = np.logaddexp(0, -(dist - dist_min) / k) * k
        force = self.contact_force * delta_pos / dist * penetration
        force_a = +force if entity_a.movable else None
        force_b = -force if entity_b.movable else None
        return [force_a, force_b]

    def apply_environment_force_vec(self, p_force):

        if self.force_mask is None:
            self.force_mask = self.get_force_mask()
        n_entities = len(self.entities)
        e_pos = np.array([[e.state.p_pos for e in self.entities]])
        e_pos1 = e_pos.repeat(n_entities, axis=0)
        e_pos1 = np.transpose(e_pos1, axes=(1, 0, 2))
        e_pos2 = e_pos.repeat(n_entities, axis=0)
        delta_pos = e_pos1 - e_pos2
        dist = np.sqrt(np.sum(np.square(delta_pos), axis=2)) + np.eye(n_entities)
        k = self.contact_margin
        penetration = np.logaddexp(0, -(dist - self.dist_min) / k) * k
        force = self.contact_force * delta_pos / dist.reshape(n_entities, n_entities, 1) * penetration.reshape(n_entities, n_entities, 1)

        masked_force = force * self.force_mask
        sum_force = np.sum(masked_force, axis=1)

        for i, f in enumerate(p_force):
            if f is None: p_force[i] = 0
            p_force[i] += sum_force[i]


        """
        for a, entity_a in enumerate(self.entities):
            for b, entity_b in enumerate(self.entities):
                if (b <= a): continue
                #[f_a, f_b] = self.get_collision_force(entity_a, entity_b)
                f_a = force[a, b, :] if entity_a.movable and entity_a is not entity_b and (entity_a.collide and entity_b.collide) else None
                f_b = -force[a, b, :] if entity_b.movable and entity_a is not entity_b and (entity_a.collide and entity_b.collide) else None

                if (f_a is not None):
                    if (p_force[a] is None): p_force[a] = 0.0
                    p_force[a] = f_a + p_force[a]
                if (f_b is not None):
                    if (p_force[b] is None): p_force[b] = 0.0
                    p_force[b] = f_b + p_force[b]
        """

        return p_force