import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import PyTree

from envs.gridworld import GridState
from flows.distribution_embedding import FlowEmbedding
from gridworld.rnn_belief_model import RNNBeliefModel
from gridworld.utils import compute_approximate_beliefs


class ParticleSet(eqx.Module):
    """Weighted collection of particles used for sequential importance sampling.

    ``particles`` is a pytree whose leaves carry a leading particle axis of
    length ``N`` (it is vmapped over and indexed leaf-wise below), and
    ``weights`` is the matching length-``N`` weight vector.
    """

    # Pytree of per-particle states; every leaf has a leading axis of size N.
    particles: eqx.Module
    # Length-N particle weights (kept normalized by the methods below).
    weights: jax.Array

    def simulate(
        self, key: jax.Array, env: eqx.Module, policy: PyTree, obs: PyTree
    ) -> 'ParticleSet':
        """Propagate every particle one step and reweight against ``obs``.

        For each particle an action is sampled from ``policy``; the particle's
        weight is multiplied by the probability ``policy.action_probs`` assigns
        to that action, the particle is stepped through ``env.step``, and its
        weight is zeroed when the simulated observation differs from ``obs``.
        Weights are then renormalized; if every weight became zero the set
        falls back to uniform weights instead of dividing by zero.
        """
        sim_keys = jax.random.split(key, self.weights.shape[0])
        actions = jax.vmap(policy)(sim_keys, self.particles)
        # Probability of the sampled action for each particle.
        aw = jax.vmap(policy.action_probs)(self.particles)[
            jnp.arange(self.weights.shape[0]), actions
        ]
        state, sim_obs = jax.vmap(env.step, in_axes=(0, 0))(self.particles, actions)
        # Hard observation likelihood: keep only particles whose simulated
        # observation matches the real one.
        eq = jax.vmap(eqx.tree_equal, in_axes=(0, None))(sim_obs, obs)
        w = jnp.where(eq, self.weights * aw, 0)
        w = jax.lax.cond(
            jnp.sum(w) > 0,
            lambda w: w / jnp.sum(w),
            lambda _: jnp.ones(self.weights.shape[0]) / self.weights.shape[0],
            w,
        )
        return ParticleSet(particles=state, weights=w)

    def neff(self) -> jax.Array:
        """Return the effective sample size ``1 / sum(w_i^2)``.

        The small epsilon guards against division by zero. It is added once,
        outside the sum; adding it to every squared weight would inflate the
        denominator by ``N * eps`` and cap the estimate near ``1 / (N * eps)``.
        """
        return 1.0 / (jnp.sum(jnp.square(self.weights)) + 1e-5)

    def simple_resample(self, key: jax.Array) -> 'ParticleSet':
        """Multinomial resampling: draw N i.i.d. indices proportional to weight.

        Returns a new set with uniform weights.
        """
        n = self.weights.shape[0]
        w = self.weights / jnp.sum(self.weights)
        # Pin the last CDF entry to exactly 1 so float drift in cumsum cannot
        # leave a uniform draw past the end of the CDF.
        csum = jnp.cumsum(w).at[-1].set(1.0)
        u = jax.random.uniform(key, shape=(n,))
        # side='right' selects the first i with csum[i] > u, so zero-weight
        # particles (empty CDF intervals) are never chosen.
        indices = jnp.clip(jnp.searchsorted(csum, u, side='right'), 0, n - 1)
        # Index every leaf, like systematic_resample, rather than rebuilding a
        # specific state type field by field.
        particles = jax.tree_util.tree_map(lambda x: x[indices], self.particles)
        return ParticleSet(particles=particles, weights=jnp.ones(n) / n)

    def systematic_resample(self, key: jax.Array) -> 'ParticleSet':
        """Systematic resampling with a single uniform offset.

        Uses N evenly spaced positions ``(u0 + k) / N`` over the weight CDF and
        returns a new set with uniform weights.
        """
        n = self.weights.shape[0]
        w = self.weights / jnp.sum(self.weights)
        # Pin the last CDF entry to exactly 1 (see simple_resample).
        cdf = jnp.cumsum(w).at[-1].set(1.0)
        u0 = jax.random.uniform(key, ()) / n
        positions = u0 + jnp.arange(n, dtype=w.dtype) / n
        # side='right' skips zero-weight particles; clip guards the rare case
        # where a position rounds up to exactly 1.0 in low precision.
        idx = jnp.clip(jnp.searchsorted(cdf, positions, side='right'), 0, n - 1)
        particles = jax.tree_util.tree_map(lambda x: x[idx], self.particles)
        return ParticleSet(particles=particles, weights=jnp.ones(n) / n)


class BaseFilter(eqx.Module):
    """Default filter that delegates belief maintenance to the environment.

    The belief representation is whatever array ``env.update_beliefs`` and
    ``env.sample_beliefs`` operate on; subclasses swap in other
    representations (particle sets, latent embeddings, RNN states).
    """

    env: eqx.Module

    def reset(self, key: jax.Array, initial_beliefs: jax.Array, state: eqx.Module) -> jax.Array:
        """Return ``initial_beliefs`` untouched; ``key`` and ``state`` are unused."""
        del key, state  # kept only for interface compatibility with subclasses
        return initial_beliefs

    @eqx.filter_jit
    def update(
        self,
        key: jax.Array,
        beliefs: jax.Array,
        state: eqx.Module,
        policy: PyTree,
        obs: jax.Array,
    ) -> jax.Array:
        """Advance the beliefs one step via ``env.update_beliefs``; ``key`` is unused."""
        updated = self.env.update_beliefs(beliefs, state, policy, obs)
        return updated

    @eqx.filter_jit
    def sample(self, key: jax.Array, beliefs: jax.Array, nsamples: int) -> jax.Array:
        """Draw ``nsamples`` samples from the beliefs via ``env.sample_beliefs``."""
        drawn = self.env.sample_beliefs(key, beliefs, nsamples)
        return drawn

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.Array, beliefs: jax.Array) -> jax.Array:
        """Return the beliefs unchanged; ``key`` is unused."""
        del key
        return beliefs


class ParticleFilter(BaseFilter):
    """Bootstrap-style particle filter over grid states.

    The belief representation is a ``ParticleSet`` of ``nparticles`` grid
    states. Each update simulates the particles one step and resamples
    systematically when the effective sample size falls below half the
    particle count.
    """

    nparticles: int

    def reset(self, key: jax.Array, init_beliefs: jax.Array, state: eqx.Module) -> ParticleSet:
        """Sample initial agent positions from ``init_beliefs`` with uniform weights.

        Every particle shares the walls of the provided ``state``.
        """
        positions = self.env.sample_beliefs(key, init_beliefs, self.nparticles)
        shared_walls = jnp.repeat(jnp.expand_dims(state.walls, axis=0), self.nparticles, axis=0)
        uniform = jnp.ones(self.nparticles) / self.nparticles
        return ParticleSet(particles=GridState(positions, shared_walls), weights=uniform)

    @eqx.filter_jit
    def update(
        self,
        key: jax.Array,
        particle_set: ParticleSet,
        state: eqx.Module,
        policy: PyTree,
        obs: jax.Array,
    ) -> ParticleSet:
        """Simulate one step against ``obs`` and resample if weights degenerate."""
        key_sim, key_resample = jax.random.split(key, 2)
        propagated = particle_set.simulate(key_sim, self.env, policy, obs)
        # Resample only once the effective sample size drops below N / 2.
        degenerate = propagated.neff() < self.nparticles / 2
        return jax.lax.cond(
            degenerate,
            lambda ps: ps.systematic_resample(key_resample),
            lambda ps: ps,
            propagated,
        )

    @eqx.filter_jit
    def sample(self, key: jax.Array, particle_set: ParticleSet, nsamples: int) -> jax.Array:
        """Draw ``nsamples`` agent positions from the particles, proportional to weight."""
        total = jnp.sum(particle_set.weights) + 1e-5
        probs = particle_set.weights / total
        positions = particle_set.particles.agent_position
        return jax.random.choice(key, positions, shape=(nsamples,), p=probs)

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.Array, particle_set: ParticleSet) -> jax.Array:
        """Histogram particle weights onto the grid and normalize to sum to 1.

        Positions outside the grid are clipped onto its border
        (``mode='clip'``). ``key`` is unused.
        """
        grid_shape = (self.env.size,) * self.env.ndim
        flat_idx = jnp.ravel_multi_index(
            particle_set.particles.agent_position.T,
            dims=grid_shape,
            mode='clip',
        )
        mass = jnp.zeros((self.env.size**self.env.ndim,)).at[flat_idx].add(particle_set.weights)
        grid = mass.reshape(grid_shape)
        return grid / jnp.sum(grid)


class GroundTruthFilter(BaseFilter):
    """Exact filter: every method is inherited unchanged from ``BaseFilter``."""


class NeuralGTFilter(BaseFilter):
    """Filter that keeps exact beliefs but reads them out through a flow model.

    Belief updates are inherited from ``BaseFilter``. ``sample`` and
    ``compute_beliefs`` first draw ``nparticles`` samples from the exact
    beliefs, embed them with ``model``, and then query the flow model
    conditioned on that embedding.
    """

    model: FlowEmbedding
    nparticles: int

    def _get_embedding(self, key, beliefs):
        # Monte Carlo embedding of the exact beliefs; only the first output
        # of model.embed is kept.
        drawn = self.env.sample_beliefs(key, beliefs, self.nparticles)
        embedded = self.model.embed(drawn)
        return embedded[0]

    @eqx.filter_jit
    def sample(self, key: jax.Array, beliefs: jax.Array, nsamples: int) -> jax.Array:
        """Generate ``nsamples`` samples from the flow conditioned on the belief embedding."""
        embed_key, flow_key = jax.random.split(key, 2)
        latent = self._get_embedding(embed_key, beliefs)
        # NOTE(review): unlike NeuralBayesFilter.sample, the raw flow output is
        # returned without an int32 cast or clipping to the grid — confirm
        # callers expect un-discretized samples here.
        return self.model.generate(latent, nsamples=nsamples, key=flow_key)

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.Array, beliefs: jax.Array) -> jax.Array:
        """Approximate the belief grid from the flow conditioned on the belief embedding."""
        embed_key, flow_key = jax.random.split(key, 2)
        latent = self._get_embedding(embed_key, beliefs)
        return compute_approximate_beliefs(flow_key, self.env, self.model, latent)


class NeuralBayesFilter(BaseFilter):
    """Filter whose belief state is a latent embedding ``z`` from a flow model.

    Each update decodes ``nparticles`` positions from ``z``, runs one
    ``ParticleSet.simulate`` step against the observation, and re-embeds the
    weighted particles to obtain the next ``z``.
    """

    model: FlowEmbedding
    nparticles: int
    # Number of stochastic log-prob evaluations averaged in _proposal_weights.
    log_prob_repeats: int = 20

    @eqx.filter_jit
    def reset(self, key: jax.Array, initial_beliefs: jax.Array, state: eqx.Module) -> jax.Array:
        """Embed ``nparticles`` samples of ``initial_beliefs``; ``state`` is unused."""
        drawn = self.env.sample_beliefs(key, initial_beliefs, self.nparticles)
        embedded = self.model.embed(drawn)
        return embedded[0]

    @eqx.filter_jit
    def sample(self, key: jax.Array, z: jax.Array, nsamples: int) -> jax.Array:
        """Generate ``nsamples`` positions from ``z``, truncated to int32 and clipped to the grid."""
        raw = self.model.generate(z, nsamples=nsamples, key=key)
        return jnp.clip(raw.astype(jnp.int32), 0, self.env.size - 1)

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.Array, z: jax.Array) -> jax.Array:
        """Approximate the belief grid from the flow conditioned on ``z``."""
        return compute_approximate_beliefs(key, self.env, self.model, z)

    def _proposal_weights(self, key: jax.Array, z: jax.Array, x: jax.Array) -> jax.Array:
        # Average exp(log_prob(x | z)) over log_prob_repeats independent keys,
        # computed stably in log space via logsumexp.
        repeat_keys = jax.random.split(key, self.log_prob_repeats)
        log_probs = jax.vmap(lambda k: self.model.log_prob(x, z, key=k)[0])(repeat_keys)
        mean_log = jax.nn.logsumexp(log_probs, axis=0) - jnp.log(self.log_prob_repeats)
        return jnp.exp(mean_log)

    @eqx.filter_jit
    def update(
        self,
        key: jax.Array,
        z: jax.Array,
        state: eqx.Module,
        policy: PyTree,
        obs: jax.Array,
    ) -> jax.Array:
        """Decode particles from ``z``, simulate against ``obs``, and re-embed them."""
        sample_key, simulate_key = jax.random.split(key, 2)
        positions = self.sample(sample_key, z, self.nparticles)
        walls = jnp.repeat(jnp.expand_dims(state.walls, axis=0), self.nparticles, axis=0)
        prior = ParticleSet(
            particles=GridState(positions, walls),
            weights=jnp.ones(self.nparticles) / self.nparticles,
        )
        posterior = prior.simulate(simulate_key, self.env, policy, obs)
        embedded = self.model.embed(
            posterior.particles.agent_position, weights=posterior.weights
        )
        return embedded[0]


class RNNFilter(BaseFilter):
    """Filter whose belief state is the hidden state ``h`` of an RNN belief model."""

    model: RNNBeliefModel
    obs_shape: int

    @eqx.filter_jit
    def reset(self, key: jax.Array, initial_beliefs: jax.Array, state: eqx.Module) -> jax.Array:
        """Return the model's initial hidden state; ``key``, ``initial_beliefs`` and ``state`` are unused."""
        return self.model.reset(self.obs_shape)

    @eqx.filter_jit
    def update(
        self,
        key: jax.Array,
        h: jax.Array,
        state: eqx.Module,
        policy: PyTree,
        obs: jax.Array,
    ) -> jax.Array:
        """Feed ``obs.wall`` into the RNN and return the new hidden state.

        A trailing feature axis is added to the wall observation, and the last
        element of ``model.update``'s output is kept (presumably the final
        hidden state — confirm against RNNBeliefModel).
        """
        return self.model.update(jnp.expand_dims(obs.wall, axis=-1), jnp.squeeze(h))[-1]

    @eqx.filter_jit
    def sample(self, key: jax.Array, h: jax.Array, nsamples: int) -> jax.Array:
        """Draw ``nsamples`` grid cells from the RNN's predicted distribution.

        Returns:
            Integer array of shape ``(nsamples, env.ndim)`` holding one cell
            coordinate per row.
        """
        # Squeeze h exactly as compute_beliefs does, so sampling and the belief
        # readout query the model with the same input.
        preds = self.model.predict(jnp.squeeze(h))
        index = jax.random.choice(key, jnp.arange(preds.shape[0]), shape=(nsamples,), p=preds)
        coords = jnp.unravel_index(index, (self.env.size,) * self.env.ndim)
        # unravel_index yields a tuple with one array per axis; stack it into a
        # single (nsamples, ndim) array to match the declared return type and the
        # position arrays returned by ParticleFilter.sample.
        return jnp.stack(coords, axis=-1)

    @eqx.filter_jit
    def compute_beliefs(self, key: jax.Array, h: jax.Array) -> jax.Array:
        """Reshape the model's flat prediction into the ``(size,) * ndim`` grid; ``key`` is unused."""
        out = self.model.predict(jnp.squeeze(h))
        return jnp.reshape(out, (self.env.size,) * self.env.ndim)
