import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import kl_div

# def resample_model(
#     env: eqx.Module, key: jax.Array, model: FlowEmbedding, z: jax.Array, nsamples: int = 256
# ):
#     model_samples = model.generate(z, nsamples=nsamples, key=key)
#     return jnp.clip(model_samples, 0, env.size - 1)


@eqx.filter_jit
def compute_approximate_beliefs(
    key: jax.Array, env: eqx.Module, model: eqx.Module, z: jax.Array, repeats: int = 20
) -> jax.Array:
    """Approximate the model's belief distribution over every cell of the grid.

    Every grid coordinate (``env.ndim`` axes, each ranging over
    ``0 .. env.size - 1``) is scored with ``model.log_prob`` under ``repeats``
    PRNG subkeys split from ``key``. The per-key probabilities are averaged
    (computed in log space as ``logsumexp(...) - log(repeats)``). The averaged
    scores are then normalised with a softmax over all cells, so the result
    sums to 1.

    Args:
        key: PRNG key; split into ``repeats`` subkeys.
        env: Object exposing integer ``size`` and ``ndim`` attributes.
        model: Object whose ``log_prob(x, z, key=...)`` output is indexed with
            ``[0]``. That element must contain one log-probability per grid cell,
            since it is reshaped to the grid below. Presumably ``log_prob``
            returns a tuple whose first item is that vector — confirm.
        z: Conditioning input forwarded unchanged to ``model.log_prob``.
        repeats: Number of PRNG subkeys to average over.

    Returns:
        Array of shape ``(env.size,) * env.ndim`` holding normalised probabilities.
    """
    grid_shape = (env.size,) * env.ndim

    # Enumerate every cell as a row of ``env.ndim`` integer coordinates.
    axes = [jnp.arange(env.size) for _ in range(env.ndim)]
    mesh = jnp.meshgrid(*axes, indexing='ij')
    cells = jnp.stack(mesh, axis=-1).reshape((-1, env.ndim))

    subkeys = jax.random.split(key, repeats)

    def score_cells(subkey: jax.Array) -> jax.Array:
        # Only the key varies across the vmap; ``cells`` is closed over (unbatched).
        return model.log_prob(cells, z, key=subkey)[0]

    per_key_log_probs = jax.vmap(score_cells)(subkeys)
    # log(mean_k p_k) computed stably in log space.
    mean_log_probs = jax.nn.logsumexp(per_key_log_probs, axis=0) - jnp.log(repeats)
    normalised = jax.nn.softmax(mean_log_probs, axis=0)
    return normalised.reshape(grid_shape)


# def compute_empirical_dist(samples: jax.Array, grid_size: int) -> jax.Array:
#     flat_indices = samples[:, 0] * grid_size + samples[:, 1]
#     freq = jnp.bincount(flat_indices, minlength=grid_size * grid_size)
#     dist = freq.reshape((grid_size, grid_size))
#     return dist / jnp.sum(dist)


@eqx.filter_jit
def top_k_accuracy(dist: jax.Array, k: int, true_pos: jax.Array) -> jax.Array:
    """Check whether ``true_pos`` is among the ``k`` highest-scoring cells of ``dist``.

    Args:
        dist: Scores or probabilities over a grid, of any shape. It is ranked
            after flattening.
        k: Number of top cells to consider; must be a Python int, since
            ``lax.top_k`` needs it static. Values larger than ``dist.size`` are
            clamped to ``dist.size``, so every cell then counts as top-k.
        true_pos: Integer coordinates, one per axis of ``dist``. Out-of-range
            coordinates are silently clipped to the grid boundary
            (``mode='clip'``) rather than raising.

    Returns:
        A 0-d boolean JAX array. It is True iff the flat index of ``true_pos``
        is one of the top-``k`` flat indices.
    """
    # lax.top_k rejects k larger than the operand length; clamp so callers can
    # sweep k without knowing the grid size. dist.size is static under jit.
    k = min(k, dist.size)
    _, top_idxs = jax.lax.top_k(dist.ravel(), k)
    flat_true_pos = jnp.ravel_multi_index(true_pos, dist.shape, mode='clip')
    return jnp.any(top_idxs == flat_true_pos)


@eqx.filter_jit
def js_divergence(p: jax.Array, q: jax.Array) -> float:
    """Jensen-Shannon divergence between two same-shaped distributions.

    The result is ``sum(0.5 * kl_div(p, m) + 0.5 * kl_div(q, m))`` over all
    entries, with mixture ``m = 0.5 * (p + q)``. ``kl_div`` is the elementwise
    ``jax.scipy.special.kl_div``.

    ``p``, ``q`` and ``m`` are each clamped to ``[1e-12, 1.0]`` before entering
    the KL terms. This is presumably to keep the logs and their gradients finite
    where entries are zero. The mixture itself is formed from the *unclamped*
    inputs.

    Returns:
        A scalar JAX array.
    """
    floor = 1e-12

    def clamp(a: jax.Array) -> jax.Array:
        return jnp.clip(a, floor, 1.0)

    mixture = 0.5 * (p + q)
    safe_mixture = clamp(mixture)
    kl_p = kl_div(clamp(p), safe_mixture)
    kl_q = kl_div(clamp(q), safe_mixture)
    return jnp.sum(0.5 * kl_p + 0.5 * kl_q)


@eqx.filter_jit
def negative_log_likelihood(dist: jax.Array, sample: jax.Array) -> float:
    """Negative log of the value ``dist`` assigns to the grid cell ``sample``.

    Args:
        dist: Array over the grid, of any shape. Presumably normalised
            probabilities — confirm with callers.
        sample: Integer coordinates, one per axis of ``dist``. Out-of-range
            values are clipped to the nearest valid index (``mode='clip'``).

    Returns:
        ``-log(dist[sample])`` as a 0-d JAX array. It is ``inf`` when that
        entry is 0.
    """
    flat_index = jnp.ravel_multi_index(sample, dist.shape, mode='clip')
    flat_dist = dist.ravel()
    prob = flat_dist[flat_index]
    return -jnp.log(prob)


# def create_dir_name(args: argparse.Namespace) -> str:
#     base_name = f'g_{args.grid_size}_{args.ndim}_{args.max_walk_depth}_'
#     base_name += f'c_{args.ncubes}_{args.cube_width}_p_{args.policy_temp}'
#     return base_name


def iqr_filter(
    data: np.ndarray | list[float], lower_factor: float = 1.5, upper_factor: float = 1.5
) -> np.ndarray:
    data = np.asarray(data)

    q1 = np.percentile(data, 25)
    q3 = np.percentile(data, 75)
    iqr = q3 - q1

    lower_bound = q1 - lower_factor * iqr
    upper_bound = q3 + upper_factor * iqr

    return data[(data >= lower_bound) & (data <= upper_bound)]
