#!/usr/bin/env python3

from __future__ import annotations

import equinox as eqx
import jax
import jax.numpy as jnp

from envs.triangulation import Triangulation, TriangulationObservation, TriangulationState
from flows.distribution_embedding import FlowEmbedding


class ParticleSet(eqx.Module):
    """A weighted set of particles approximating a belief over `TriangulationState`.

    `particles` is a batched `TriangulationState` pytree whose leaves are indexed by particle
    along their leading axis; `weights` holds one weight per particle along that same axis.
    """

    # Batched states; every leaf is indexed along axis 0 by particle.
    particles: TriangulationState
    # One weight per particle, shape `(num_particles,)`.
    weights: jax.Array

    @eqx.filter_jit
    def simulate(
        self, env: Triangulation, action: int, observation: TriangulationObservation, key: jax.Array
    ) -> ParticleSet:
        """Propagate every particle through `env.step` and reweight it by `observation`.

        Args:
            env: Environment providing the transition (`step`) and the likelihood
                (`observation_weights`).
            action: Action applied to every particle.
            observation: Observation actually received after taking `action`.
            key: PRNG key; split into one key per particle.

        Returns:
            A new `ParticleSet` with the stepped particles and normalized weights. If the prior
            weights times the observation weights sum to zero, the prior is discarded and the
            observation weights alone are normalized instead. If those also sum to zero, the
            all-zero weights are returned unchanged so `neff()` still reports 0.
        """
        num_particles = self.weights.shape[0]
        sim_keys = jax.random.split(key, num_particles)
        particles, timesteps = jax.vmap(env.step, in_axes=(0, None, 0))(
            self.particles, action, sim_keys
        )
        obs_weights = env.observation_weights(particles, timesteps.observation, observation)
        posterior = self.weights * obs_weights
        posterior_sum = jnp.sum(posterior)

        # Prefer the Bayesian update; fall back to the raw likelihood when the product collapses.
        has_posterior = posterior_sum > 0.0
        raw = jnp.where(has_posterior, posterior, obs_weights)
        total = jnp.where(has_posterior, posterior_sum, jnp.sum(obs_weights))

        # Substitute 1.0 for a zero total so the division never produces NaN/inf.
        has_mass = total > 0.0
        weights = jnp.where(has_mass, raw / jnp.where(has_mass, total, 1.0), raw)
        return ParticleSet(particles, weights)

    @eqx.filter_jit
    def neff(self) -> float:
        """Return the effective sample size `(sum w)^2 / sum(w^2)` as a scalar array.

        For normalized weights this equals the usual `1 / sum(w^2)`; the numerator makes the
        estimate independent of the weights' scale, so unnormalized weights are handled too.
        Returns 0 when the weights carry no mass or are not finite.
        """
        total = jnp.sum(self.weights)
        sum_sq = jnp.sum(self.weights**2)
        is_defined = (sum_sq > 0.0) & jnp.isfinite(sum_sq) & jnp.isfinite(total)

        # Safe denominator keeps the unselected branch free of division by zero.
        safe_sum_sq = jnp.where(is_defined, sum_sq, 1.0)
        return jnp.where(is_defined, total**2 / safe_sum_sq, 0.0)

    @eqx.filter_jit
    def systematic_resample(self, key: jax.Array) -> ParticleSet:
        """Resample the particles systematically and reset the weights to uniform.

        A single uniform offset in [0, 1/N) generates N evenly spaced positions. Each position
        selects the first particle whose cumulative weight reaches it. The CDF is normalized by
        its final value, so unnormalized weights resample correctly. Indices are clipped, so
        floating-point round-off in the cumulative sum cannot produce an out-of-range index.
        If the weights carry no finite mass, every particle is treated as equally likely.

        Args:
            key: PRNG key for the single systematic offset.

        Returns:
            A new `ParticleSet` of the selected particles with weights `1 / N`.
        """
        num_particles = self.weights.shape[0]

        cdf = jnp.cumsum(self.weights)
        total = cdf[-1]
        has_mass = (total > 0.0) & jnp.isfinite(total)
        uniform_cdf = jnp.arange(1, num_particles + 1) / num_particles
        cdf = jnp.where(has_mass, cdf / jnp.where(has_mass, total, 1.0), uniform_cdf)

        offset = jax.random.uniform(key, (), minval=0.0, maxval=1.0 / num_particles)
        positions = offset + jnp.arange(num_particles) / num_particles
        indices = jnp.searchsorted(cdf, positions, side='left')
        indices = jnp.clip(indices, 0, num_particles - 1)

        particles = jax.tree.map(lambda x: x[indices], self.particles)
        weights = jnp.full(num_particles, 1 / num_particles)

        return ParticleSet(particles, weights)


class ParticleFilter(eqx.Module):
    """Implement the standard SIR filter with threshold-based resampling.

    Each `update` propagates and reweights the particle set. It then resamples systematically
    whenever the effective sample size drops below `resample_factor` times the set's size.
    """

    # Number of particles drawn at `reset`.
    num_particles: int
    # Resampling threshold as a fraction of the particle count.
    resample_factor: float = 0.5

    def __repr__(self) -> str:
        return f'PF ({self.num_particles})'

    @eqx.filter_jit
    def sample(
        self, particle_set: ParticleSet, state: TriangulationState, num_samples: int, key: jax.Array
    ) -> TriangulationState:
        """Draw `num_samples` particles with replacement, proportionally to their weights.

        Args:
            particle_set: Weighted particles to draw from.
            state: Unused by this filter.
            num_samples: Number of states to draw.
            key: PRNG key for the categorical draw.

        Returns:
            A batched `TriangulationState` with `num_samples` entries. If the weights sum to zero,
            particles are drawn uniformly instead of from NaN probabilities.
        """
        weights = particle_set.weights
        count = weights.shape[0]
        total = jnp.sum(weights)
        has_mass = total > 0.0
        probs = jnp.where(has_mass, weights / jnp.where(has_mass, total, 1.0), 1.0 / count)
        indices = jax.random.choice(key, count, (num_samples,), p=probs)

        return jax.tree.map(lambda x: x[indices], particle_set.particles)

    @eqx.filter_jit
    def update(
        self,
        env: Triangulation,
        particle_set: ParticleSet,
        state: TriangulationState,
        observation: TriangulationObservation,
        action: int,
        next_observation: TriangulationObservation,
        key: jax.Array,
    ) -> tuple[ParticleSet, jax.Array]:
        """Run one SIR step: simulate `action`, reweight by `next_observation`, maybe resample.

        `state` and `observation` are unused by this filter.

        Returns:
            The updated particle set and its effective sample size (scalar array).
        """
        key, simulate_key = jax.random.split(key, 2)
        particle_set = particle_set.simulate(env, action, next_observation, simulate_key)

        key, resample_key = jax.random.split(key, 2)
        # The threshold is relative to the size of the set actually being filtered.
        threshold = particle_set.weights.shape[0] * self.resample_factor
        particle_set = jax.lax.cond(
            particle_set.neff() < threshold,
            lambda: particle_set.systematic_resample(resample_key),
            lambda: particle_set,
        )

        return particle_set, particle_set.neff()

    @eqx.filter_jit
    def reset(
        self,
        env: Triangulation,
        state: TriangulationState,
        observation: TriangulationObservation,
        key: jax.Array,
    ) -> ParticleSet:
        """Draw `num_particles` initial states from `env.reset` and weight them by `observation`.

        The weights are returned exactly as produced by `env.observation_weights`; they are not
        normalized here.
        """
        particles, timesteps = jax.vmap(env.reset)(jax.random.split(key, self.num_particles))

        # Rebuild the particles through `state.from_flat_features`; presumably this gives them
        # the same concrete structure as `state` — TODO confirm.
        features = jax.vmap(lambda p: p.to_flat_features())(particles)
        particles = jax.vmap(lambda f: state.from_flat_features(f))(features)
        weights = env.observation_weights(particles, timesteps.observation, observation)

        return ParticleSet(particles, weights)


class NeuralBayesFilter(eqx.Module):
    """Filter whose belief is an embedding produced by a `FlowEmbedding` model.

    Each update decodes particles from the current embedding with `model.generate`. It then
    reweights and propagates them through the environment and re-embeds the resulting weighted
    set with `model.embed`.
    """

    # Model exposing `generate` (embedding -> flat-feature samples) and `embed`
    # (flat features + weights -> output whose first element is used as the embedding).
    model: FlowEmbedding
    # Number of particles decoded and re-embedded at each step.
    num_particles: int

    def __repr__(self) -> str:
        return f'NBF ({self.num_particles})'

    @eqx.filter_jit
    def sample(
        self, embedding: jax.Array, state: TriangulationState, num_samples: int, key: jax.Array
    ) -> TriangulationState:
        """Decode `num_samples` states from `embedding`.

        The samples come from the model's generator; `state.from_flat_features` rebuilds each one.
        """
        return jax.vmap(state.from_flat_features)(
            self.model.generate(embedding, num_samples, key=key)
        )

    @eqx.filter_jit
    def compare_state_obs(
        self,
        env: Triangulation,
        particles: TriangulationState,
        observation: TriangulationObservation,
    ) -> jax.Array:
        """Weight `particles` by how well their own observations match `observation`.

        When `env.has_state_observation` is false, uniform weights are returned instead.
        """
        observations = jax.vmap(lambda s: s.observation_from_state())(particles)
        obs_weights = env.observation_weights(particles, observations, observation)
        # Match size and dtype of `obs_weights` so both `lax.cond` branches agree.
        uniform = jnp.full_like(obs_weights, 1.0 / obs_weights.shape[0])

        return jax.lax.cond(env.has_state_observation, lambda: obs_weights, lambda: uniform)

    @eqx.filter_jit
    def step(
        self,
        env: Triangulation,
        embedding: jax.Array,
        state: TriangulationState,
        observation: TriangulationObservation,
        action: int,
        next_observation: TriangulationObservation,
        key: jax.Array,
    ) -> ParticleSet:
        """Build a particle set from `embedding` plus `state`, then simulate one step.

        `num_particles - 1` particles are decoded from the embedding and `state` is appended as
        the last particle. The set is weighted against `observation` and advanced with `action`
        and `next_observation`.
        """
        sample_key, simulate_key = jax.random.split(key, 2)

        particles = self.sample(embedding, state, self.num_particles - 1, sample_key)
        particles = jax.tree.map(
            lambda x, y: jnp.concatenate([x, y[jnp.newaxis]], axis=0), particles, state
        )
        weights = self.compare_state_obs(env, particles, observation)
        particle_set = ParticleSet(particles, weights)

        return particle_set.simulate(env, action, next_observation, simulate_key)

    def _embed_particle_set(self, particle_set: ParticleSet) -> tuple[jax.Array, jax.Array]:
        """Embed a weighted particle set and return it with its effective sample size."""
        features = jax.vmap(lambda p: p.to_flat_features())(particle_set.particles)
        return self.model.embed(features, particle_set.weights)[0], particle_set.neff()

    @eqx.filter_jit
    def update(
        self,
        env: Triangulation,
        embedding: jax.Array,
        state: TriangulationState,
        observation: TriangulationObservation,
        action: int,
        next_observation: TriangulationObservation,
        key: jax.Array,
    ) -> tuple[jax.Array, jax.Array]:
        """Advance the belief embedding one step without using the true state.

        Returns:
            The new embedding and the effective sample size of the intermediate particle set.
        """
        # Generate a new gold state from the embedding so we can add it to the particle set
        # for the update without cheating. `update_with_ground_truth` puts the gold state in
        # directly, which is useful during search when we already have a guess of the gold
        # state that matches the observation sequence in the tree.

        key, sample_key = jax.random.split(key, 2)
        state = jax.tree.map(lambda x: x[0], self.sample(embedding, state, 1, sample_key))

        particle_set = self.step(env, embedding, state, observation, action, next_observation, key)
        return self._embed_particle_set(particle_set)

    @eqx.filter_jit
    def update_with_ground_truth(
        self,
        env: Triangulation,
        embedding: jax.Array,
        state: TriangulationState,
        observation: TriangulationObservation,
        action: int,
        next_observation: TriangulationObservation,
        key: jax.Array,
    ) -> tuple[jax.Array, jax.Array]:
        """Advance the belief embedding one step, inserting the given `state` as a particle.

        Returns:
            The new embedding and the effective sample size of the intermediate particle set.
        """
        # Same as `update()`, except we don't replace the gold state with a sample. This helps
        # because we already have a state from simulation that matches the observation sequence
        # in the tree. If the model is good, inserting this state only adds bias, so whether to
        # call this at all is worth revisiting.

        particle_set = self.step(env, embedding, state, observation, action, next_observation, key)
        return self._embed_particle_set(particle_set)

    @eqx.filter_jit
    def reset(
        self,
        env: Triangulation,
        state: TriangulationState,
        observation: TriangulationObservation,
        key: jax.Array,
    ) -> jax.Array:
        """Embed `num_particles` initial states from `env.reset`, weighted by `observation`.

        The weights come straight from `env.observation_weights` and are not normalized here.
        """
        particles, timesteps = jax.vmap(env.reset)(jax.random.split(key, self.num_particles))

        # Rebuild the particles through `state.from_flat_features`; presumably this gives them
        # the same concrete structure as `state` — TODO confirm.
        features = jax.vmap(lambda p: p.to_flat_features())(particles)
        particles = jax.vmap(lambda f: state.from_flat_features(f))(features)
        weights = env.observation_weights(particles, timesteps.observation, observation)

        return self.model.embed(features, weights)[0]
