import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import PyTree

from distribution_embedding import FlowEmbedding
from belief_model import RNNBeliefModel
from pomdp.utils import compute_approximate_beliefs
from pomdp.gridworld_jax import GridState


class ParticleSet(eqx.Module):
    """Weighted collection of environment states used as a particle belief.

    The set is immutable: every method returns a new ``ParticleSet``.
    """
    # Batched environment states; each array leaf is indexed by particle along
    # its leading axis (the methods below vmap / gather over that axis).
    particles: eqx.Module
    # One weight per particle, same leading length as the particle leaves.
    weights: jax.Array

    def simulate(self, key: jax.random.PRNGKey, env: eqx.Module, policy: PyTree, obs: PyTree) -> 'ParticleSet':
        """Propagate every particle one step and reweight it against ``obs``.

        Args:
            key: PRNG key, split into one sub-key per particle for the policy.
            env: Environment; ``env.step(particle, action)`` must return
                ``(next_state, observation)``.
            policy: Callable ``policy(key, particle)`` returning an action index;
                must also expose ``action_probs(particle)``.
            obs: The observation actually received.

        Returns:
            A new set holding the propagated states. A particle whose simulated
            observation differs from ``obs`` gets weight zero; the others keep
            their weight times the probability of the action they took. Weights
            are normalised, or reset to uniform when all of them are zero.
        """
        n = self.weights.shape[0]
        sim_keys = jax.random.split(key, n)
        actions = jax.vmap(policy)(sim_keys, self.particles)
        # Probability the policy assigns to the action each particle sampled.
        aw = jax.vmap(policy.action_probs)(self.particles)[jnp.arange(n), actions]
        state, sim_obs = jax.vmap(env.step, in_axes=(0, 0))(self.particles, actions)
        # Per-particle flag: does the simulated observation match the real one?
        eq = jax.vmap(eqx.tree_equal, in_axes=(0, None))(sim_obs, obs)
        w = jnp.where(eq, self.weights * aw, 0)
        w = jax.lax.cond(jnp.sum(w) > 0,
                         lambda w: w / jnp.sum(w),
                         # No particle explains the observation: fall back to uniform.
                         lambda _: jnp.ones(n) / n,
                         w)
        return ParticleSet(particles=state, weights=w)

    def neff(self) -> jax.Array:
        """Return the effective sample size ``1 / sum(w_i ** 2)``.

        The small constant guards against division by zero. It is added once to
        the sum; adding it to every squared weight would inflate the denominator
        by ``N * 1e-5`` and systematically under-report the effective sample
        size for large particle counts.
        """
        return 1. / (jnp.sum(jnp.square(self.weights)) + 1e-5)

    def simple_resample(self, key: jax.random.PRNGKey) -> 'ParticleSet':
        """Multinomial resampling: draw N particles i.i.d. according to the weights.

        Assumes the weights already sum to one (the last cumulative value is
        pinned to 1 to absorb rounding error). Returns a uniformly weighted set.
        """
        N = self.weights.shape[0]
        # Pin the last entry so that every uniform draw in [0, 1) lands in range.
        csum = jnp.cumsum(self.weights).at[-1].set(1.)
        indices = jnp.searchsorted(csum, jax.random.uniform(key, shape=(N,)))
        # Gather every leaf of the particle PyTree, as systematic_resample does,
        # instead of rebuilding one specific state type field by field.
        particles = jax.tree_util.tree_map(lambda x: x[indices], self.particles)
        return ParticleSet(particles=particles, weights=jnp.ones(N) / N)

    def systematic_resample(self, key: jax.random.PRNGKey) -> "ParticleSet":
        """Systematic resampling: one uniform offset, N evenly spaced positions.

        Weights are normalised internally. Returns a uniformly weighted set.
        """
        N = self.weights.shape[0]
        w = self.weights / jnp.sum(self.weights)
        cdf = jnp.cumsum(w)
        u0 = jax.random.uniform(key, ()) / N
        positions = u0 + jnp.arange(N, dtype=w.dtype) / N
        idx = jnp.searchsorted(cdf, positions, side="left")
        # Rounding can leave cdf[-1] slightly below the largest position, in which
        # case searchsorted returns N; clamp explicitly to the last particle.
        idx = jnp.minimum(idx, N - 1)
        particles = jax.tree_util.tree_map(lambda x: x[idx], self.particles)
        return ParticleSet(particles=particles, weights=jnp.ones(N) / N)

# JIT-compiling filter methods because the caller is typically not JIT-compiled
class BaseFilter(eqx.Module):
    """Filter that hands belief tracking over to the environment.

    The belief state is passed through unchanged by ``reset`` and
    ``compute_beliefs``; ``update`` and ``sample`` forward to the environment's
    ``update_beliefs`` and ``sample_beliefs``. Subclasses override these four
    methods to plug in their own belief representation.
    """
    env: eqx.Module

    def reset(self, key: jax.random.PRNGKey, initial_beliefs: jax.Array, state: eqx.Module) -> jax.Array:
        """Return the initial belief state; here the given beliefs as-is."""
        return initial_beliefs

    @eqx.filter_jit
    def update(self, key: jax.random.PRNGKey, beliefs: jax.Array, state: eqx.Module,
               policy: PyTree, obs: jax.Array) -> jax.Array:
        """Return the beliefs after observing ``obs``; ``key`` is not used here."""
        updated = self.env.update_beliefs(beliefs, state, policy, obs)
        return updated

    @eqx.filter_jit
    def sample(self, key: jax.random.PRNGKey, beliefs: jax.Array, nsamples: int) -> jax.Array:
        """Draw ``nsamples`` samples from ``beliefs`` via the environment."""
        draws = self.env.sample_beliefs(key, beliefs, nsamples)
        return draws

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.random.PRNGKey, beliefs: jax.Array) -> jax.Array:
        """Convert the belief state to explicit beliefs; here the identity."""
        return beliefs


class ParticleFilter(BaseFilter):
    """Filter whose belief state is a ``ParticleSet`` of ``nparticles`` grid states."""
    nparticles: int

    def reset(self, key: jax.random.PRNGKey, init_beliefs: jax.Array,state: eqx.Module) -> ParticleSet:
        """Build a uniformly weighted particle set from the initial beliefs.

        Agent positions are sampled from ``init_beliefs`` through the
        environment; every particle receives a copy of ``state.walls``.
        """
        n = self.nparticles
        positions = self.env.sample_beliefs(key, init_beliefs, n)
        # Give each particle its own copy of the wall layout along a new leading axis.
        walls = jnp.repeat(state.walls[None], n, axis=0)
        uniform = jnp.ones(n) / n
        return ParticleSet(particles=GridState(agent_pos=positions, walls=walls),
                           weights=uniform)

    @eqx.filter_jit
    def update(self, key: jax.random.PRNGKey,  particle_set: ParticleSet, state: eqx.Module,
               policy: PyTree, obs: jax.Array) -> ParticleSet:
        """Propagate and reweight the particles, resampling when they degenerate.

        Resampling (systematic) happens when the effective sample size drops
        below half the particle count.
        """
        sim_key, resample_key = jax.random.split(key)
        propagated = particle_set.simulate(sim_key, self.env, policy, obs)

        def _resampled():
            return propagated.systematic_resample(resample_key)

        def _unchanged():
            return propagated

        degenerate = propagated.neff() < self.nparticles / 2
        return jax.lax.cond(degenerate, _resampled, _unchanged)

    @eqx.filter_jit
    def sample(self, key: jax.random.PRNGKey, particle_set: ParticleSet, nsamples: int) -> jax.Array:
        """Draw ``nsamples`` agent positions with probability given by the weights."""
        weights = particle_set.weights
        probs = weights / (jnp.sum(weights) + 1e-5)
        positions = particle_set.particles.agent_pos
        return jax.random.choice(key, positions, shape=(nsamples,), p=probs)

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.random.PRNGKey, particle_set: ParticleSet) -> jax.Array:
        """Accumulate particle weights into a normalised grid over agent positions."""
        grid_shape = (self.env.size,) * self.env.ndim
        # Flatten each particle's position to a single cell index (clipped to the grid).
        cells = jnp.ravel_multi_index(particle_set.particles.agent_pos.T,
                                      dims=grid_shape, mode='clip')
        histogram = jnp.zeros((self.env.size ** self.env.ndim,)).at[cells].add(particle_set.weights)
        grid = histogram.reshape(grid_shape)
        return grid / jnp.sum(grid)


class NeuralGTFilter(BaseFilter):
    """Filter that keeps the environment's beliefs but reads them out through a flow model.

    ``reset`` and ``update`` are inherited from ``BaseFilter``; only sampling
    and belief read-out go through an embedding of ``nparticles`` belief samples.
    """
    model: FlowEmbedding
    nparticles: int

    def _get_embedding(self, key, beliefs):
        """Embed ``nparticles`` samples drawn from ``beliefs``."""
        return self.model.embed(self.env.sample_beliefs(key, beliefs, self.nparticles))

    @eqx.filter_jit
    def sample(self, key: jax.random.PRNGKey, beliefs: jax.Array, nsamples: int) -> jax.Array:
        """Generate ``nsamples`` samples from the model conditioned on the embedded beliefs."""
        embed_key, generate_key = jax.random.split(key)
        embedding = self._get_embedding(embed_key, beliefs)
        return self.model.generate(embedding, nsamples=nsamples, key=generate_key)

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.random.PRNGKey, beliefs: jax.Array) -> jax.Array:
        """Return the approximate beliefs computed from the embedded beliefs."""
        embed_key, approx_key = jax.random.split(key)
        embedding = self._get_embedding(embed_key, beliefs)
        return compute_approximate_beliefs(approx_key, self.env, self.model, embedding)


class NeuralBayesFilter(BaseFilter):
    """Filter whose belief state is a flow-model embedding ``z``.

    Each update decodes ``z`` into particles, runs one particle-filter step on
    them and re-embeds the weighted result.
    """
    model: FlowEmbedding
    nparticles: int
    log_prob_repeats: int = 20

    @eqx.filter_jit
    def reset(self, key: jax.random.PRNGKey, initial_beliefs: jax.Array, state: eqx.Module) -> jax.Array:
        """Embed ``nparticles`` samples drawn from the initial beliefs."""
        draws = self.env.sample_beliefs(key, initial_beliefs, self.nparticles)
        return self.model.embed(draws)

    @eqx.filter_jit
    def sample(self, key: jax.random.PRNGKey, z: jax.Array, nsamples: int) -> jax.Array:
        """Generate ``nsamples`` positions from ``z``, cast to int32 and clipped to the grid."""
        generated = self.model.generate(z, nsamples=nsamples, key=key)
        return jnp.clip(generated.astype(jnp.int32), 0, self.env.size - 1)

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.random.PRNGKey, z: jax.Array) -> jax.Array:
        """Return the approximate beliefs computed from the embedding ``z``."""
        return compute_approximate_beliefs(key, self.env, self.model, z)

    def _proposal_weights(self, key: jax.random.PRNGKey, z: jax.Array, x: jax.Array) -> jax.Array:
        """Average ``exp(log_prob(x | z))`` over ``log_prob_repeats`` independent keys.

        The mean is taken in log space (logsumexp minus log of the repeat count)
        before exponentiating.
        """
        repeats = self.log_prob_repeats
        keys = jax.random.split(key, repeats)
        # One log-probability evaluation of the same inputs per key.
        log_probs = jax.vmap(lambda k: self.model.log_prob(x, z, key=k))(keys)
        log_mean = jax.nn.logsumexp(log_probs, axis=0) - jnp.log(repeats)
        return jnp.exp(log_mean)

    @eqx.filter_jit
    def update(self, key: jax.random.PRNGKey, z: jax.Array, state: eqx.Module,
               policy: PyTree, obs: jax.Array) -> jax.Array:
        """Decode ``z`` into particles, simulate/reweight them against ``obs`` and re-embed."""
        n = self.nparticles
        gen_key, sim_key = jax.random.split(key)
        positions = self.sample(gen_key, z, n)
        # Every decoded particle shares the wall layout of the current state.
        walls = jnp.repeat(state.walls[None], n, axis=0)
        decoded = ParticleSet(particles=GridState(agent_pos=positions, walls=walls),
                              weights=jnp.ones(n) / n)
        stepped = decoded.simulate(sim_key, self.env, policy, obs)
        return self.model.embed(stepped.particles.agent_pos, weights=stepped.weights)


class RNNFilter(BaseFilter):
    """Filter whose belief state is the hidden state ``h`` of an RNN belief model."""
    model: RNNBeliefModel
    obs_shape: int

    @eqx.filter_jit
    def reset(self, key: jax.random.PRNGKey, initial_beliefs: jax.Array, state: eqx.Module) -> jax.Array:
        """Return the model's initial hidden state; key, beliefs and state are unused."""
        return self.model.reset(self.obs_shape)

    @eqx.filter_jit
    def update(self, key: jax.random.PRNGKey, h: jax.Array, state: eqx.Module, policy: PyTree,
               obs: jax.Array) -> jax.Array:
        """Advance the hidden state with the wall observation ``obs.wall``.

        The observation gets a trailing axis, ``h`` is squeezed, and the last
        element of the model's update output is kept as the new hidden state.
        """
        return self.model.update(jnp.expand_dims(obs.wall, axis=-1), jnp.squeeze(h))[-1]

    @eqx.filter_jit
    def sample(self, key: jax.random.PRNGKey, h: jax.Array, nsamples: int) -> jax.Array:
        """Draw ``nsamples`` grid cells from the model's predicted distribution.

        Returns the result of ``jnp.unravel_index``: one index array per grid
        dimension, each of length ``nsamples``.
        NOTE(review): other filters' ``sample`` return a single array; confirm
        callers expect the per-dimension form here.
        """
        # Squeeze h exactly as update/compute_beliefs do, so predict always sees
        # the same hidden-state shape and preds is a flat distribution over cells.
        preds = self.model.predict(jnp.squeeze(h))
        index = jax.random.choice(key, jnp.arange(preds.shape[0]), shape=(nsamples,), p=preds)
        return jnp.unravel_index(index, (self.env.size,) * self.env.ndim)

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.random.PRNGKey, h: jax.Array) -> jax.Array:
        """Return the model's prediction for ``h`` reshaped onto the grid; ``key`` is unused."""
        out = self.model.predict(jnp.squeeze(h))
        return jnp.reshape(out, (self.env.size,) * self.env.ndim)


# BaseFilter forwards update/sample to the environment's own update_beliefs /
# sample_beliefs and passes beliefs through unchanged, so it is also exported
# under this name.
GroundTruthFilter = BaseFilter