import equinox as eqx
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal


class LDSState(eqx.Module):
    """
    Represents particle hypotheses (x, k) (Physics View).
    Used by Particle Filters to simulate individual trajectories.

    The same container holds either:
      * a single particle: ``x`` of shape (ndim,), ``k`` a scalar mode index.
        This is the form ``propagate_particle`` and ``compute_log_likelihood`` operate on.
      * a batch of N particles: ``x`` of shape (N, ndim), ``k`` of shape (N,).
        This is what ``sample_from_belief`` and ``sample_initial_particles`` return.
    """

    x: jax.Array  # Continuous state: (ndim,) for one particle, (N, ndim) for a batch
    k: jax.Array  # Integer mode index: () for one particle, (N,) for a batch


class MixtureBeliefState(eqx.Module):
    """
    Closed-form Gaussian-mixture posterior of the environment (God View).

    Component ``i`` is the Gaussian N(means[i], covs[i]) with log mixture weight
    ``log_weights[i]``. The training loop consumes this to build optimal labels.
    """

    # Per-component mean vectors, shape (K, ndim).
    means: jax.Array
    # Per-component covariance matrices, shape (K, ndim, ndim).
    covs: jax.Array
    # Per-component log mixture weights, shape (K,). Modes outside the active
    # subset carry -inf.
    log_weights: jax.Array


class GenerativeBeliefEnvironment(eqx.Module):
    """
    Multi-mode linear-Gaussian environment with a closed-form mixture posterior.

    Each mode ``k`` in ``range(num_components)`` has its own model:
        dynamics:     x' = A[k] x + w,  w ~ N(0, Q[k])
        observation:  y  = C[k] x + v,  v ~ N(0, R[k])

    ``reset`` activates a random subset of ``active_subset_size`` modes.
    The posterior over (x, k) is then tracked as a K-component Gaussian mixture,
    with one Kalman filter per mode. The predict step in ``step`` never changes the
    mixture weights, so no mode switching is modelled within an episode.

    Two views are exposed:
      * Generative / God view (``reset``, ``step``, ``sample_from_belief``), which
        works with the analytical mixture posterior.
      * Physics / Agent view (``sample_initial_particles``, ``propagate_particle``,
        ``compute_log_likelihood``, ``compute_conditional_log_likelihood``) for
        particle-based filters.
    """

    ndim: int
    obs_dim: int
    num_components: int
    active_subset_size: int  # Number of "True" modes per episode

    # System Dynamics Matrices (The "Physics Engine")
    A: jax.Array  # (K, ndim, ndim)
    Q: jax.Array  # (K, ndim, ndim)
    C: jax.Array  # (K, obs_dim, ndim)
    R: jax.Array  # (K, obs_dim, obs_dim)

    # Initialization Parameters
    mu_init: jax.Array  # (K, ndim)
    scale_init: jax.Array  # (K, ndim)

    def __init__(
        self, ndim: int, obs_dim: int, num_components: int, active_subset_size: int, key: jax.Array
    ):
        """
        Build the global K-mode model.

        Args:
            ndim: Dimension of the latent state x.
            obs_dim: Dimension of the observation y.
            num_components: Number of modes K in the global model.
            active_subset_size: Number of modes active per episode.
                Must satisfy ``1 <= active_subset_size <= num_components``.
            key: PRNG key used to draw A, C and the initial cluster means.

        Raises:
            ValueError: If ``num_components < 1`` or ``active_subset_size`` is out of range.
        """
        # Fail fast on bad sizes:
        # - An empty active subset makes every log-weight -inf, so sampling and the
        #   mixture update produce NaNs.
        # - A subset larger than K yields prior weights that do not sum to one.
        if num_components < 1:
            raise ValueError(f"num_components must be >= 1, got {num_components}")
        if not 1 <= active_subset_size <= num_components:
            raise ValueError(
                "active_subset_size must be in [1, num_components], "
                f"got {active_subset_size} with num_components={num_components}"
            )

        self.ndim = ndim
        self.obs_dim = obs_dim
        self.num_components = num_components
        self.active_subset_size = active_subset_size

        # Split into 5 even though only 3 keys are used. Changing the count would
        # change the derived keys, so existing seeds would produce different environments.
        keys = jax.random.split(key, 5)

        # 1. Initialize Stable Dynamics (spectral norm <= 0.95, so |eigenvalues| < 1.0)
        self.A = self._init_stable_dynamics(keys[0])
        self.C = jax.random.normal(keys[1], (num_components, obs_dim, ndim))

        # 2. Hard Difficulty Settings (High Drift, Low Observation Noise)
        # High Q: Particles diffuse quickly, making tracking harder
        # Low R: Likelihood peaks are sharp; "wrong" particles get near-zero weight immediately
        self.Q = jnp.tile(jnp.eye(ndim), (num_components, 1, 1)) * 1.0
        self.R = jnp.tile(jnp.eye(obs_dim), (num_components, 1, 1)) * 0.1

        # 3. Initialize Starting Clusters (Spread far apart to reduce accidental overlap)
        self.mu_init = jax.random.normal(keys[2], (num_components, ndim)) * 10.0
        self.scale_init = jnp.ones((num_components, ndim))

    # ==========================================
    #  PART 1: GENERATIVE API (God View / Training)
    # ==========================================

    def reset(self, key: jax.Array) -> tuple[MixtureBeliefState, jax.Array]:
        """
        Start a new episode.

        1. Selects a random SUBSET of components to be 'active' (the "True" physics).
        2. Deactivates the rest (log_weight = -inf).
        3. Samples y0 from the prior predictive of the active subset.
        4. Conditions the prior mixture on y0.

        Returns:
            (analytical posterior after conditioning on y0, observation y0 of shape (obs_dim,))
        """
        k_subset, k_obs, k_sample = jax.random.split(key, 3)

        # --- A. Sparse Initialization Logic ---
        # Randomly permute indices; the first `active_subset_size` entries of `perm` are active.
        perm = jax.random.permutation(k_subset, self.num_components)
        is_active = jnp.arange(self.num_components) < self.active_subset_size
        # Unshuffle: mask[i] is True iff component i sits in the first slots of `perm`.
        mask = is_active[jnp.argsort(perm)]

        # Uniform weights over the active subset, -inf for inactive components.
        log_w_active = -jnp.log(self.active_subset_size)
        log_weights = jnp.where(mask, log_w_active, -jnp.inf)

        # Prior over x0 for each mode: N(mu_init[k], diag(scale_init[k] ** 2)).
        means = self.mu_init
        covs = jax.vmap(lambda s: jnp.diag(s**2))(self.scale_init)

        # --- B. Sample Observation y0 (inactive modes have probability 0) ---
        y0 = self._sample_observation(k_obs, k_sample, means, covs, log_weights)

        # --- C. Update to Analytical Posterior ---
        # -inf weights stay -inf through the update, so the subset is preserved.
        posterior_belief = self._mixture_update(means, covs, log_weights, y0)

        return posterior_belief, y0

    def step(
        self,
        belief: MixtureBeliefState,
        key: jax.Array,
    ) -> tuple[MixtureBeliefState, jax.Array]:
        """
        Evolve the ground-truth posterior by one time step.

        Predicts every mixture component through its own dynamics. Then samples y
        from the resulting predictive mixture (weighted by ``belief.log_weights``)
        and conditions the mixture on that y.

        Returns:
            (posterior belief after conditioning on y, observation y of shape (obs_dim,))
        """
        k_obs, k_sample = jax.random.split(key)

        # 1. Predict (Analytic Evolution of Gaussians); mixture weights are unchanged.
        mu_pred, cov_pred = jax.vmap(self._predict_single)(
            belief.means, belief.covs, self.A, self.Q
        )

        # 2. Sample Observation y from the predictive mixture (respects the sparse subset).
        y = self._sample_observation(k_obs, k_sample, mu_pred, cov_pred, belief.log_weights)

        # 3. Update (Analytic Correction)
        new_belief = self._mixture_update(mu_pred, cov_pred, belief.log_weights, y)

        return new_belief, y

    def sample_from_belief(
        self,
        belief: MixtureBeliefState,
        num_particles: int,
        key: jax.Array,
    ) -> LDSState:
        """
        **NBF Training Helper**:
        Samples 'Ground Truth' particles (x, k) from the current analytical belief.
        Use this to generate the regression targets for your neural filter.

        Returns:
            LDSState: A batch of particles, with ``x`` of shape (num_particles, ndim)
                      and ``k`` of shape (num_particles,).
        """
        k_choice, k_val = jax.random.split(key)

        # 1. Sample the component indices 'k' based on belief weights
        probs = jnp.exp(belief.log_weights)
        indices = jax.random.choice(
            k_choice, jnp.arange(self.num_components), shape=(num_particles,), p=probs
        )

        # 2. Sample continuous states 'x' given the chosen 'k'
        def _sample(idx, rng):
            return jax.random.multivariate_normal(rng, belief.means[idx], belief.covs[idx])

        keys = jax.random.split(k_val, num_particles)
        x_samples = jax.vmap(_sample)(indices, keys)

        # 3. Return a batched LDSState. propagate_particle handles a single particle,
        #    so vmap it over this batch.
        return LDSState(x=x_samples, k=indices)

    # ==========================================
    #  PART 2: PHYSICS API (Agent View / Filtering)
    # ==========================================

    def sample_initial_particles(
        self,
        num_particles: int,
        key: jax.Array,
    ) -> LDSState:
        """
        Samples initial particles (x, k) from the *Global* Environment Prior.

        CRITICAL FOR DIFFICULTY:
        This samples from ALL components (uniform prior), not just the active subset.
        The Particle Filter does not know which subset is active yet.

        Returns:
            LDSState: A batch with ``x`` of shape (num_particles, ndim) and
                      ``k`` of shape (num_particles,).
        """
        k_choice, k_val = jax.random.split(key)

        # 1. Sample latent modes k (Uniform across ALL components)
        probs = jnp.ones(self.num_components) / self.num_components
        k_indices = jax.random.choice(
            k_choice, jnp.arange(self.num_components), shape=(num_particles,), p=probs
        )

        # 2. Sample continuous states x given k: x = mu_init[k] + scale_init[k] * eps
        def _sample(idx, rng):
            mu = self.mu_init[idx]
            scale = self.scale_init[idx]
            noise = jax.random.normal(rng, (self.ndim,))
            return mu + scale * noise

        keys = jax.random.split(k_val, num_particles)
        x_samples = jax.vmap(_sample)(k_indices, keys)

        return LDSState(x=x_samples, k=k_indices)

    def propagate_particle(
        self,
        particle: LDSState,
        key: jax.Array,
    ) -> LDSState:
        """
        Simulate the physics for a single particle hypothesis.

        Computes x' = A[k] x + w with w ~ N(0, Q[k]); the mode k is carried unchanged.
        Expects one particle (x of shape (ndim,), scalar k); vmap for batches.
        """
        k = particle.k

        # Evolve x: x' = A_k x + noise
        noise = jax.random.multivariate_normal(key, jnp.zeros(self.ndim), self.Q[k])
        x_new = self.A[k] @ particle.x + noise

        return LDSState(x=x_new, k=k)

    def compute_log_likelihood(
        self, particle: LDSState, belief: MixtureBeliefState, y: jax.Array
    ) -> jax.Array:
        """
        Computes log sum_k exp(belief.log_weights[k]) * N(y; C[k] x, R[k]).

        This is the likelihood of y given the particle's x, marginalised over modes
        using the belief's weights. NOTE: ``particle.k`` is ignored here.
        For the likelihood under the particle's own mode only, use
        ``compute_conditional_log_likelihood``.

        Note for Baselines:
        A standard PF usually doesn't have access to 'belief' (the true weights).
        Use ``compute_conditional_log_likelihood``, which evaluates P(y | x, k_particle).
        """
        # 1. Compute predicted observation means for ALL components: (K, obs_dim)
        means = self.C @ particle.x

        # 2. Compute log likelihood for each component: log P(y | x, k)
        def _get_component_log_prob(mu, cov):
            return multivariate_normal.logpdf(y, mu, cov)

        log_probs_y_given_x_k = jax.vmap(_get_component_log_prob)(means, self.R)

        # 3. Marginalize over modes with the belief's weights
        return logsumexp(log_probs_y_given_x_k + belief.log_weights)

    def compute_conditional_log_likelihood(self, particle: LDSState, y: jax.Array) -> jax.Array:
        """
        Baseline-PF likelihood log P(y | x, k) under the particle's own mode only.

        Needs no access to the analytical belief. It evaluates N(y; C[k] x, R[k])
        with k = particle.k. Expects one particle (x of shape (ndim,), scalar k);
        vmap for batches.
        """
        k = particle.k
        return multivariate_normal.logpdf(y, self.C[k] @ particle.x, self.R[k])

    # ==========================================
    #  PART 3: INTERNAL HELPERS
    # ==========================================

    def _sample_observation(self, key_mode, key_noise, means, covs, log_weights):
        """
        Draw y from the mixture predictive sum_k w_k N(C[k] m_k, C[k] P_k C[k]^T + R[k]).

        A single mode is drawn with ``key_mode``, then y is drawn from that mode's
        Gaussian with ``key_noise``. Modes with weight -inf have probability 0 and
        are never chosen.
        """
        probs = jnp.exp(log_weights)
        k_active = jax.random.choice(key_mode, jnp.arange(self.num_components), p=probs)

        mu_k = means[k_active]
        cov_k = covs[k_active]
        C_k = self.C[k_active]
        R_k = self.R[k_active]

        y_mean = C_k @ mu_k
        y_cov = C_k @ cov_k @ C_k.T + R_k
        # Symmetrize and add jitter so the factorisation inside multivariate_normal is stable.
        y_cov = 0.5 * (y_cov + y_cov.T) + jnp.eye(self.obs_dim) * 1e-6

        return jax.random.multivariate_normal(key_noise, y_mean, y_cov)

    def _predict_single(self, mean, cov, A, Q):
        """Kalman predict for one mode: (A m, A P A^T + Q), symmetrized."""
        mu_p = A @ mean
        cov_p = A @ cov @ A.T + Q
        cov_p = 0.5 * (cov_p + cov_p.T)
        return mu_p, cov_p

    def _mixture_update(self, mu_pred, cov_pred, log_weights, y):
        """
        Condition every mixture component on y and reweight by its evidence.

        New weights are w_k * p(y | k) normalised in log space.
        Inactive components (-inf) stay at -inf.
        """
        new_means, new_covs, log_liks = jax.vmap(
            self._kalman_update_single, in_axes=(0, 0, 0, 0, None)
        )(mu_pred, cov_pred, self.C, self.R, y)

        unnorm_weights = log_weights + log_liks
        log_evidence = logsumexp(unnorm_weights)
        new_log_weights = unnorm_weights - log_evidence

        return MixtureBeliefState(new_means, new_covs, new_log_weights)

    def _kalman_update_single(self, mean, cov, C, R, y):
        """
        Kalman correction for one mode.

        Returns:
            (posterior mean, posterior covariance, log N(y; C m, S)),
            where S = C P C^T + R plus jitter.
        """
        y_pred = C @ mean
        S = C @ cov @ C.T + R
        S = 0.5 * (S + S.T) + jnp.eye(S.shape[0]) * 1e-6

        # Solve S K^T = C P instead of inverting S. Because P and S are symmetric,
        # K = (S^{-1} C P)^T = P C^T S^{-1}, the standard Kalman gain.
        kt = jax.scipy.linalg.solve(S, C @ cov, assume_a='pos')
        K = kt.T

        mu_new = mean + K @ (y - y_pred)

        # Joseph-form covariance update keeps the result symmetric positive semi-definite.
        I = jnp.eye(mean.shape[0])
        diff = I - K @ C
        cov_new = diff @ cov @ diff.T + K @ R @ K.T
        cov_new = 0.5 * (cov_new + cov_new.T)

        log_lik = multivariate_normal.logpdf(y, y_pred, S)
        return mu_new, cov_new, log_lik

    def _init_stable_dynamics(self, key: jax.Array) -> jax.Array:
        """
        Draw K random (ndim, ndim) matrices with singular values clipped to <= 0.95.

        The spectral norm bounds every eigenvalue magnitude, so the dynamics are stable.
        """
        keys = jax.random.split(key, self.num_components)

        def _make(k):
            mat = jax.random.normal(k, (self.ndim, self.ndim))
            u, s, vt = jnp.linalg.svd(mat)
            s = jnp.clip(s, 0, 0.95)
            return u @ jnp.diag(s) @ vt

        return jax.vmap(_make)(keys)
