import jax
import jax.numpy as jnp
import equinox as eqx
from jax.scipy.special import kl_div

import smarter_jax as sj
from distribution_embedding import FlowEmbedding


def resample_model(env: eqx.Module, sj_key: sj.Subkeys, model: FlowEmbedding,
                   z: jax.Array, nsamples: int = 256):
    """Generate samples from ``model`` and clamp them to the environment grid.

    Args:
        env: Environment object; only ``env.size`` is read here.
        sj_key: Key iterator; one key is consumed via ``next(sj_key)``.
        model: Embedding model whose ``generate`` method produces the samples.
        z: Conditioning value forwarded as the first argument to
            ``model.generate``.
        nsamples: Number of samples requested from ``model.generate``.

    Returns:
        The output of ``model.generate`` with every value clipped to the
        closed interval ``[0, env.size - 1]``.
    """
    subkey = next(sj_key)
    raw_samples = model.generate(z, nsamples=nsamples, key=subkey)
    upper_bound = env.size - 1
    return jnp.clip(raw_samples, 0, upper_bound)


def compute_approximate_beliefs(key: jax.random.PRNGKey, env: eqx.Module, model: eqx.Module,
                                z: jax.Array, repeats: int = 20) -> jax.Array:
    """Approximate the model's belief over every cell of the environment grid.

    ``model.log_prob`` is evaluated at every integer grid coordinate once per
    PRNG key, the per-key results are averaged in probability space, and the
    averaged log-values are normalised over all cells with a softmax.

    Args:
        key: PRNG key that is split into ``repeats`` independent keys.
        env: Environment object; ``env.size`` and ``env.ndim`` define the grid.
        model: Object exposing ``log_prob(x, z, key=...)``.
        z: Conditioning value forwarded to ``model.log_prob``.
        repeats: Number of keys (and therefore ``log_prob`` evaluations) to
            average over.

    Returns:
        Array of shape ``(env.size,) * env.ndim`` holding the normalised
        beliefs.
    """
    grid_shape = (env.size,) * env.ndim

    # Enumerate every grid coordinate as one row of an (n_cells, ndim) array.
    axis_values = jnp.arange(env.size)
    mesh = jnp.meshgrid(*([axis_values] * env.ndim), indexing="ij")
    cells = jnp.stack(mesh, axis=-1).reshape((-1, env.ndim))

    # One log_prob evaluation per key; the cells and z are shared by all keys.
    # NOTE(review): presumably log_prob is stochastic in ``key``, which is why
    # several evaluations are averaged -- confirm against the model.
    subkeys = jax.random.split(key, repeats)
    per_key_log_prob = jax.vmap(lambda k: model.log_prob(cells, z, key=k))
    log_probs = per_key_log_prob(subkeys)

    # log-mean-exp across the key axis, then normalise over the cells.
    log_mean = jax.nn.logsumexp(log_probs, axis=0) - jnp.log(repeats)
    beliefs = jax.nn.softmax(log_mean, axis=0)
    return beliefs.reshape(grid_shape)

def compute_empirical_dist(samples: jax.Array, grid_size: int) -> jax.Array:
    """Build a normalised 2-D histogram of integer grid samples.

    Args:
        samples: Integer array indexed as ``samples[:, 0]`` (row) and
            ``samples[:, 1]`` (column); only these two columns are used.
        grid_size: Side length of the square grid. Must be a concrete Python
            integer (mark it static when wrapping this function in
            ``jax.jit``).

    Returns:
        Array of shape ``(grid_size, grid_size)`` with the relative frequency
        of each cell. If no sample is counted the total is zero and the
        division yields NaN.
    """
    n_cells = grid_size * grid_size
    flat_indices = samples[:, 0] * grid_size + samples[:, 1]
    # ``length`` (rather than ``minlength``) fixes the output size statically:
    # the function becomes traceable under jit/vmap, and the reshape below can
    # no longer fail because an out-of-range index enlarged the count vector.
    freq = jnp.bincount(flat_indices, length=n_cells)
    dist = freq.reshape((grid_size, grid_size))
    return dist / jnp.sum(dist)

def top_k_correct(dist: jax.Array, k: int, true_pos: jax.Array) -> bool:
    """Report whether ``true_pos`` is among the ``k`` highest-valued cells.

    Args:
        dist: Array of scores over a grid; it is flattened before ranking.
        k: Number of top entries to consider.
        true_pos: Multi-dimensional integer index into ``dist``, one
            component per axis. Components outside the grid are clipped to
            the nearest valid index.

    Returns:
        A boolean scalar that is true when the flattened index of
        ``true_pos`` appears in the top-``k`` indices of ``dist``.
    """
    target = jnp.ravel_multi_index(true_pos, dist.shape, mode='clip')
    flat_scores = dist.ravel()
    top_indices = jax.lax.top_k(flat_scores, k)[1]
    hits = top_indices == target
    return jnp.any(hits)

def js_divergence(p: jax.Array, q: jax.Array) -> float:
    """Compute the Jensen-Shannon divergence between ``p`` and ``q``.

    The midpoint ``m = 0.5 * (p + q)`` is formed first; ``p``, ``q`` and ``m``
    are then clipped to ``[1e-12, 1.0]`` before the two KL terms are
    evaluated, and the elementwise results are summed over all entries.

    Args:
        p: First distribution.
        q: Second distribution; combined elementwise with ``p``.

    Returns:
        Scalar array equal to the sum of
        ``0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)`` on the clipped inputs.
    """
    floor = 1e-12
    mixture = 0.5 * (p + q)
    # Clip after forming the mixture so the midpoint uses the raw inputs.
    p_safe, q_safe, mix_safe = (
        jnp.clip(arr, floor, 1.0) for arr in (p, q, mixture)
    )
    half_kl_p = 0.5 * kl_div(p_safe, mix_safe)
    half_kl_q = 0.5 * kl_div(q_safe, mix_safe)
    return jnp.sum(half_kl_p + half_kl_q)

def nll(dist: jax.Array, sample: jax.Array) -> float:
    """Return the negative log of ``dist`` at the cell indexed by ``sample``.

    Args:
        dist: Array of probabilities over a grid.
        sample: Multi-dimensional integer index into ``dist``, one component
            per axis. Components outside the grid are clipped to the nearest
            valid index.

    Returns:
        Scalar array ``-log(dist[sample])``; infinite when the selected
        probability is zero.
    """
    cell = jnp.ravel_multi_index(sample, dist.shape, mode='clip')
    probability = dist.ravel()[cell]
    return -jnp.log(probability)

def base_dir_string(exp_args):
    """Build the base directory name that encodes an experiment's settings.

    Args:
        exp_args: Object exposing the attributes ``grid_size``, ``ndim``,
            ``max_walk_depth``, ``ncubes``, ``cube_width`` and
            ``policy_temp``.

    Returns:
        A string of the form
        ``g_<grid_size>_<ndim>_<max_walk_depth>_c_<ncubes>_<cube_width>_p_<policy_temp>``.
    """
    grid_part = f'g_{exp_args.grid_size}_{exp_args.ndim}_{exp_args.max_walk_depth}'
    cube_part = f'c_{exp_args.ncubes}_{exp_args.cube_width}'
    policy_part = f'p_{exp_args.policy_temp}'
    return '_'.join((grid_part, cube_part, policy_part))