from typing import Callable

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import matplotlib.pyplot as plt
import optax

from priorg.sim.distributions import Mixture, MultivariateNormal


def init_gmm_params(key: jnp.ndarray, num_components: int, dim: int) -> dict:
    """Create randomly initialised parameters for a diagonal Gaussian mixture.

    Args:
        key: JAX PRNG key used to draw the initial means and log-stds.
        num_components: Number of mixture components (K).
        dim: Dimensionality of each component (D).

    Returns:
        A dict with three entries:
        ``"log_weights"`` of shape (K,), the log of uniform weights ``1 / K``;
        ``"means"`` of shape (K, D), drawn from a standard normal;
        ``"log_stds"`` of shape (K, D), drawn from a standard normal.
    """
    # The three-way split is kept so the draws stay reproducible; the first
    # subkey is not used.
    _, means_key, log_stds_key = jax.random.split(key, 3)
    component_shape = (num_components, dim)
    uniform_weights = jnp.ones(num_components) / num_components
    return {
        "log_weights": jnp.log(uniform_weights),
        "means": jax.random.normal(means_key, component_shape),
        "log_stds": jax.random.normal(log_stds_key, component_shape),
    }


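# NOTE: superseded earlier version kept for reference only. It expects
# "weights" / "log_vars" entries, which init_gmm_params does not produce.
# def gmm_log_prob(params: dict, x: jnp.ndarray):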
#     weights = params["weights"]
#     means = params["means"]
#     vars = jnp.exp(params["log_vars"])

#     dim = x.shape[1]

#     def _log_prob_component(mu, var):
#         diff = x - mu  # (N, D)
#         quad = jnp.sum((diff**2) / var, axis=1)  # (N,)
#         log_det = jnp.sum(jnp.log(var))  # scalar
#         return -0.5 * (quad + log_det + dim * jnp.log(2 * jnp.pi))

#     comp_log_probs = jax.vmap(_log_prob_component)(means, vars)  # (K, N)
#     comp_log_probs = comp_log_probs.T  # (N, K)
#     log_weights = jnp.log(weights)
#     weighted = comp_log_probs + log_weights  # (N, K)
#     return jsp.special.logsumexp(weighted, axis=1)  # (N,)


def gmm_log_prob(params: dict, x: jnp.ndarray):
    """Evaluate the log of a weighted sum of diagonal Gaussians at ``x``.

    Computes ``logsumexp_k(log_weights[k] + log N(x; means[k], stds[k]))``
    where each component is a product of independent 1-D normals. The
    ``log_weights`` are used as given (no normalisation is applied here), so
    the result is a normalised log-density only if ``exp(log_weights)`` sums
    to one.

    Args:
        params: Dict with ``"log_weights"`` of shape (K,), ``"means"`` of
            shape (K, D) and ``"log_stds"`` of shape (K, D).
        x: Points to evaluate, of shape (..., D). Any number of leading batch
            axes (including none) is supported.

    Returns:
        Array of shape ``x.shape[:-1]`` with the mixture log-value per point.
    """
    log_weights = params["log_weights"]
    means = params["means"]
    stds = jnp.exp(params["log_stds"])

    # Insert a component axis next to the feature axis so that x (..., 1, D)
    # broadcasts against means/stds (K, D). This keeps the batch axes in their
    # original order; a plain `.T` on a (K, ...) result would reverse all of
    # them when x has more than one batch axis.
    x = jnp.asarray(x)
    per_dim_log_prob = jsp.stats.norm.logpdf(x[..., None, :], means, stds)
    component_log_prob = jnp.sum(per_dim_log_prob, axis=-1)  # (..., K)
    return jsp.special.logsumexp(component_log_prob + log_weights, axis=-1)


def fit_gmm(
    key: jnp.ndarray,
    log_target: Callable,
    data_sampler: Callable,
    num_components: int,
    dim: int,
    num_iters: int = 10000,
    learning_rate: float = 0.01,
    batch_size: int = 1000,
):
    """Fit GMM parameters so that ``gmm_log_prob`` regresses onto ``log_target``.

    Each iteration draws a fresh batch from ``data_sampler`` and takes one
    Adam step on the mean squared difference between ``log_target(x)`` and
    ``gmm_log_prob(params, x)``. Progress is printed every 1000 iterations.

    Args:
        key: JAX PRNG key; split once for initialisation and once per
            iteration for batch sampling.
        log_target: Callable mapping a batch ``x`` to target log-values.
        data_sampler: Callable invoked as ``data_sampler(key, batch_size)``
            returning a batch of points (presumably of shape
            ``(batch_size, dim)`` -- confirm against the sampler used).
        num_components: Number of mixture components.
        dim: Dimensionality of the data.
        num_iters: Number of optimisation steps.
        learning_rate: Adam learning rate.
        batch_size: Batch size passed to ``data_sampler``.

    Returns:
        The parameter dict after the final update.
    """
    key, key_init = jax.random.split(key)
    optimizer = optax.adam(learning_rate=learning_rate)
    params = init_gmm_params(key=key_init, num_components=num_components, dim=dim)
    opt_state = optimizer.init(params)

    def squared_error(gmm_params: dict, batch: jnp.ndarray):
        # Mean squared residual between target and model log-values.
        residual = log_target(batch) - gmm_log_prob(gmm_params, batch)
        return jnp.mean(residual**2)

    loss_and_grad = jax.value_and_grad(squared_error)

    @jax.jit
    def step(gmm_params: dict, state: optax.OptState, batch: jnp.ndarray):
        loss_value, grads = loss_and_grad(gmm_params, batch)
        updates, next_state = optimizer.update(grads, state, gmm_params)
        next_params = optax.apply_updates(gmm_params, updates)
        return loss_value, next_params, next_state

    for i in range(num_iters):
        # Advance the key every iteration so each batch is sampled afresh.
        key, key_batch = jax.random.split(key)
        batch = data_sampler(key_batch, batch_size)
        loss, params, opt_state = step(params, opt_state, batch)

        if i % 1000 == 0:
            print(f"fitting GMM, iteration {i}, loss: {loss}", flush=True)

    return params


if __name__ == "__main__":
    # Demo: fit a 20-component GMM to q.log_prob(x) - p.log_prob(x) using
    # batches sampled from p, then plot the fitted and target values
    # (exponentiated) side by side on a 2-D grid.
    key = jax.random.PRNGKey(0)
    p = MultivariateNormal(loc=jnp.zeros(2), cov=jnp.eye(2))
    # Alternative choice of p, left disabled (Independent and Uniform are not
    # imported in this file):
    # p = Independent(
    #     Uniform(low=jnp.array([0.0, -2.0]), high=jnp.array([2.0, 2.0])),
    #     reinterpreted_batch_ndims=1,
    # )
    component_cov = jnp.array([[0.06, 0.01], [0.01, 0.06]])
    q = Mixture(
        mixing_probs=jnp.array([0.5, 0.5]),
        component_distributions=[
            MultivariateNormal(loc=jnp.array([0.5, -1.0]), cov=component_cov),
            MultivariateNormal(loc=jnp.array([1.3, 0.5]), cov=component_cov),
        ],
    )

    def log_target(x):
        """Return q.log_prob(x) - p.log_prob(x), the regression target."""
        return q.log_prob(x) - p.log_prob(x)

    def data_sampler(sample_key, batch_size):
        """Draw ``batch_size`` training points from p."""
        return p.sample(sample_key, (batch_size,))

    params = fit_gmm(
        key=key,
        log_target=log_target,
        data_sampler=data_sampler,
        num_components=20,
        dim=2,
        num_iters=50000,
        learning_rate=0.01,
        batch_size=1000,
    )

    # Evaluation grid: 100 x 100 points over [0, 2] x [-2, 2].
    grid_size = 100
    X1, X2 = jnp.meshgrid(
        jnp.linspace(0, 2, grid_size), jnp.linspace(-2, 2, grid_size)
    )
    X = jnp.stack([X1.ravel(), X2.ravel()], axis=-1)
    Y_pred = jnp.exp(gmm_log_prob(params, X)).reshape(grid_size, grid_size)
    Y_true = jnp.exp(log_target(X)).reshape(grid_size, grid_size)

    plt.figure(figsize=(12, 5))
    panels = (("Predicted", Y_pred), ("True", Y_true))
    for panel_index, (title, values) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_index)
        plt.contourf(X1, X2, values)
        plt.colorbar()
        plt.title(title)
    plt.show()
