from typing import Callable

import jax
import jax.numpy as jnp
import jax.scipy.special as jsp
import matplotlib.pyplot as plt
import optax
from sim.distributions import Independent, Mixture, MultivariateNormal, Uniform


def _init_gmm_params(key, num_components: int, dim: int):
    """
    Draw random initial parameters for a mixture with Gaussian-shaped bases.

    ``key`` is split into three subkeys, one per parameter group, and every
    entry is sampled from a standard normal.

    Returns a dict with:
        "w":         shape (num_components,)
        "mu":        shape (num_components, dim)
        "log_sigma": shape (num_components, dim)
    """
    w_key, mu_key, log_sigma_key = jax.random.split(key, 3)
    per_dim_shape = (num_components, dim)
    return {
        "w": jax.random.normal(w_key, (num_components,)),
        "mu": jax.random.normal(mu_key, per_dim_shape),
        "log_sigma": jax.random.normal(log_sigma_key, per_dim_shape),
    }


def _log_basis(x: jnp.ndarray, params: dict):
    """
    Evaluate the log of every Gaussian-shaped basis function at ``x``.

    For component ``k`` this returns
    ``-0.5 * sum_d ((x_d - mu[k, d]) / exp(log_sigma[k, d])) ** 2``;
    no normalising constant is added.

    Args:
        x: points of shape ``[..., D]``. Any number of leading batch axes is
            accepted, including none (a single point of shape ``[D]``).
        params: dict holding ``"mu"`` and ``"log_sigma"``, both of shape
            ``[K, D]``.

    Returns:
        Array of shape ``[..., K]``.
    """
    sigma = jnp.exp(params["log_sigma"])  # [K, D]
    # Insert the component axis just before the feature axis so that x
    # broadcasts against the [K, D] parameters for any leading batch shape.
    diff = x[..., None, :] - params["mu"]  # [..., K, D]
    return -0.5 * jnp.sum((diff / sigma) ** 2, axis=-1)


def _log_squared_subtractive_mixture(x: jnp.ndarray, params: dict, eps: float = 1e-10):
    """
    Log of the squared signed mixture ``(sum_k w_k * exp(_log_basis(x)_k)) ** 2``.

    The sum is evaluated in log space: components are grouped by the sign of
    their weight, each group is reduced with ``logsumexp``, and the log of the
    absolute difference of the two group sums is formed from the larger and
    the smaller of them.

    Args:
        x: input points forwarded to ``_log_basis``.
        params: dict with ``"w"`` (shape ``[K]``) plus whatever ``_log_basis``
            reads from it.
        eps: floor applied to ``|w|`` before taking its log, and to
            ``1 - exp(smaller - larger)`` so that near-cancellation of the two
            groups yields ``log(eps)`` instead of ``-inf``.

    Returns:
        One value per input point (the component axis is reduced away).
    """
    weights = params["w"]  # shape: [K]
    # log(|w_k| * basis_k(x)) for every point/component pair, shape: [N, K]
    log_terms = jnp.log(jnp.maximum(jnp.abs(weights), eps)) + _log_basis(x, params)

    def _masked_logsumexp(keep):
        # Components outside `keep` contribute exp(-inf) = 0 to the sum.
        return jsp.logsumexp(jnp.where(keep, log_terms, -jnp.inf), axis=-1)

    # A weight that is exactly zero belongs to neither group.
    log_pos_total = _masked_logsumexp(weights > 0)  # shape: [N]
    log_neg_total = _masked_logsumexp(weights < 0)  # shape: [N]

    larger = jnp.maximum(log_pos_total, log_neg_total)
    smaller = jnp.minimum(log_pos_total, log_neg_total)

    # log|a - b| = log(a) + log(1 - b / a) for a >= b > 0, with a clamp.
    ratio = jnp.exp(smaller - larger)
    log_abs_diff = larger + jnp.log(jnp.maximum(1 - jnp.abs(ratio), eps))
    return 2.0 * log_abs_diff


def fit_squared_subtractive_mixture(
    key: jax.random.PRNGKey,
    log_target: Callable,
    data_sampler: Callable,
    num_components: int,
    dim: int,
    num_iters: int = 100,
    learning_rate: float = 0.01,
    batch_size: int = 1000,
    log_every: int = 1000,
):
    """
    Fit a squared subtractive mixture to ``log_target`` with Adam.

    Each iteration draws a fresh batch from ``data_sampler`` and takes one
    optimizer step on the mean squared error between the model's log value
    and ``log_target`` on that batch.

    Args:
        key: JAX PRNG key; it is split internally for initialization and for
            every batch.
        log_target: callable mapping a batch ``x`` to target log values, one
            per point. It is traced inside a jitted step, so it must be
            jit-compatible.
        data_sampler: callable ``(key, batch_size) -> x`` producing a batch.
        num_components: number of mixture components.
        dim: dimensionality of the inputs.
        num_iters: number of optimizer steps.
        learning_rate: Adam learning rate.
        batch_size: number of points requested from ``data_sampler`` per step.
        log_every: print the loss every ``log_every`` iterations (starting at
            iteration 0); a non-positive value disables printing.

    Returns:
        The fitted parameter dict (same structure as ``_init_gmm_params``).
    """
    # A JAX key should be consumed only once. Carve out a dedicated key for
    # initialization so it is not also the key that gets split for batches;
    # otherwise, depending on the PRNG implementation, the initial parameters
    # and the first batch can be drawn from coinciding subkeys.
    key, init_key = jax.random.split(key)

    optimizer = optax.adam(learning_rate=learning_rate)
    params = _init_gmm_params(init_key, num_components, dim)
    opt_state = optimizer.init(params)

    def _loss_fn(params: dict, x: jnp.ndarray):
        log_prob = _log_squared_subtractive_mixture(x, params)
        gt = log_target(x)
        return jnp.mean((log_prob - gt) ** 2)

    @jax.jit
    def train_step(params: dict, opt_state: optax.OptState, x: jnp.ndarray):
        loss, grads = jax.value_and_grad(_loss_fn)(params, x)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return loss, params, opt_state

    for i in range(num_iters):
        key, subkey = jax.random.split(key)
        x = data_sampler(subkey, batch_size)
        loss, params, opt_state = train_step(params, opt_state, x)
        if log_every > 0 and i % log_every == 0:
            print(f"Iteration {i}, loss: {loss}")

    return params


def main() -> None:
    """
    Fit the mixture to a log density ratio and plot prediction vs. target.

    The target is ``q.log_prob(x) - p.log_prob(x)``, where ``q`` is a
    two-component Gaussian mixture and ``p`` is a uniform distribution over
    the box [0, 2] x [-2, 2]; training batches are sampled from ``p``.
    """
    grid_size = 100  # points per axis of the evaluation/plotting grid

    key = jax.random.PRNGKey(0)
    p = Independent(
        Uniform(low=jnp.array([0.0, -2.0]), high=jnp.array([2.0, 2.0])),
        reinterpreted_batch_ndims=1,
    )
    q = Mixture(
        mixing_probs=jnp.array([0.5, 0.5]),
        component_distributions=[
            MultivariateNormal(
                loc=jnp.array([0.5, -1.0]),
                cov=jnp.array([[0.06, 0.01], [0.01, 0.06]]),
            ),
            MultivariateNormal(
                loc=jnp.array([1.3, 0.5]),
                cov=jnp.array([[0.06, 0.01], [0.01, 0.06]]),
            ),
        ],
    )

    def log_target(x):
        # Log density ratio log(q(x) / p(x)).
        return q.log_prob(x) - p.log_prob(x)

    def data_sampler(sample_key, batch_size):
        # Training points are drawn from p.
        return p.sample(sample_key, (batch_size,))

    params = fit_squared_subtractive_mixture(
        key=key,
        log_target=log_target,
        data_sampler=data_sampler,
        num_components=10,
        dim=2,
        num_iters=50000,
        learning_rate=0.01,
        batch_size=1000,
    )

    # NOTE(review): the grid covers [-2, 2] on both axes, while p is built
    # with low=0.0 on the first axis. Whether log_target is finite for
    # x1 < 0 depends on how Uniform.log_prob treats points outside
    # [low, high] -- confirm, or restrict the first axis to [0, 2].
    linspace = jnp.linspace(-2, 2, grid_size)
    X1, X2 = jnp.meshgrid(linspace, linspace)
    X = jnp.stack([X1.ravel(), X2.ravel()], axis=-1)
    Y_pred = jnp.exp(_log_squared_subtractive_mixture(X, params)).reshape(
        grid_size, grid_size
    )
    Y_true = jnp.exp(log_target(X)).reshape(grid_size, grid_size)

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.contourf(X1, X2, Y_pred)
    plt.colorbar()
    plt.title("Predicted")
    plt.subplot(1, 2, 2)
    plt.contourf(X1, X2, Y_true)
    plt.colorbar()
    plt.title("True")
    plt.show()


if __name__ == "__main__":
    main()
