from functools import partial

import jax
from jax import numpy as np
from jaxtyping import PRNGKeyArray, PyTreeDef, Array

from exptax.models.base import BaseExperiment

# `exp_model` and `inner_samples` are static: the model is a plain Python
# object, and `inner_samples` is used as an array *shape*, which must be a
# concrete int at trace time.  `n_meas` is only used in a comparison, so it is
# traced; changing it does not trigger a recompilation.
@partial(jax.jit, static_argnums=(0, 5))
def bounds_eig_fix_shape(
    exp_model: BaseExperiment,
    true_theta: Array,
    hist: PyTreeDef,
    rng_key: PRNGKeyArray,
    n_meas,
    inner_samples=int(1e7),
):
    """
    Compute SPCE and SNMC bounds for EIG on a fixed-shape (padded) history.

    Only the first ``n_meas`` entries of the history contribute to the
    likelihoods; later entries are masked to a log-probability of zero, so
    the history arrays can keep the same shape from call to call.

    Args:
        exp_model: Experiment model. This function reads ``ground_truth``
            (used only for its pytree structure), ``params_distrib`` (one
            object per parameter exposing ``sample(key, shape)``) and calls
            ``log_prob(theta, y, xi)``. Must be hashable (static argument).
        true_theta: Parameters under which the history is scored for the
            numerator term; passed as ``theta`` to ``exp_model.log_prob``.
        hist: Mapping with ``"meas"`` and ``"xi"`` entries whose leading axis
            indexes the measurements.
        rng_key: PRNG key used to draw the contrastive parameter samples.
        n_meas: Number of leading history entries to use. May be a Python
            int or a scalar integer array.
        inner_samples: Number of contrastive parameter samples ``L``. Must be
            a Python int (static). The default is very large; pass a smaller
            value when memory is limited.

    Returns:
        Tuple ``(spce, snmc)`` where, with ``l0`` the summed log-likelihood
        of the history under ``true_theta`` and ``l_1..l_L`` the same under
        the contrastive samples:

        * ``spce = l0 - (logsumexp([l_1..l_L, l0]) - log(L + 1))``
        * ``snmc = l0 - (logsumexp([l_1..l_L]) - log(L))``

        NOTE(review): these are the per-history terms; presumably the caller
        averages them over simulated histories to obtain the bounds.
    """
    # Draw each parameter with its own key.  Reusing one key for every leaf
    # would make the draws of different parameters correlated instead of
    # independent.  `ground_truth` is mapped first so that each entry of
    # `params_distrib` is handed over whole, as in the original mapping.
    truth_leaves, truth_def = jax.tree.flatten(exp_model.ground_truth)
    leaf_keys = jax.tree.unflatten(
        truth_def, list(jax.random.split(rng_key, len(truth_leaves)))
    )
    thetas = jax.tree.map(
        lambda _, key, dis: dis.sample(key, (inner_samples,)),
        exp_model.ground_truth,
        leaf_keys,
        exp_model.params_distrib,
    )

    def masked_log_prob(index, theta, y, xi):
        # Entries at or beyond `n_meas` are padding: contribute exactly zero.
        out_spec = jax.eval_shape(exp_model.log_prob, theta, y, xi)
        return jax.lax.cond(
            index < n_meas,
            exp_model.log_prob,
            lambda *_: np.zeros_like(out_spec),
            theta,
            y,
            xi,
        )

    indices = np.arange(hist["meas"].shape[0])
    # Map over the measurement axis only; `theta` is shared by all entries.
    per_meas_log_prob = jax.vmap(masked_log_prob, in_axes=(0, None, 0, 0))

    def history_log_lik(theta):
        # Sum the per-measurement log-probabilities over the history axis.
        return (
            per_meas_log_prob(indices, theta, hist["meas"], hist["xi"])
            .sum(axis=0)
            .squeeze()
        )

    log_liks_c = history_log_lik(thetas)
    log_lik_0 = history_log_lik(true_theta)

    # SPCE: the true parameter is included among the contrastive terms.
    all_liks = np.hstack([log_liks_c, log_lik_0])
    contr_log_lik = jax.scipy.special.logsumexp(all_liks) - np.log(inner_samples + 1)
    spce = log_lik_0 - contr_log_lik

    # SNMC: only the freshly drawn samples appear in the denominator.
    contr_log_lik = jax.scipy.special.logsumexp(log_liks_c) - np.log(inner_samples)
    snmc = log_lik_0 - contr_log_lik

    return spce, snmc