import os
import math
import jax
import matplotlib.pyplot as plt
import seaborn as sns
import hydra
import numpy as np
import jax.numpy as jnp
from jax import random

from dist import GaussMixJax
from util import diffusion_sampler
from util import generate_inverse_problem_gm, sliced_wasserstein
from numpyro.distributions import MultivariateNormal, Categorical, MixtureSameFamily
from nuts import sample_nuts


def _plot_kde_comparison(curves):
    """Overlay 1-D kernel density estimates on a single figure and show it.

    Args:
        curves: iterable of ``(samples, color, label)`` triples; each
            ``samples`` entry is forwarded to ``sns.kdeplot``.
    """
    _, ax = plt.subplots(1, 1, figsize=(6, 4))
    for samples, color, label in curves:
        sns.kdeplot(samples, color=color, ax=ax, label=label)
    plt.legend()
    plt.show()


@hydra.main(version_base=None, config_path="configs/", config_name="gmm")
def main(cfg):
    """Compare plain NUTS with "boosted" NUTS on a Gaussian-mixture inverse problem.

    Pipeline:
      1. Build a 25-component Gaussian-mixture prior whose centers lie on a
         5x5 grid (pattern repeated ``dim // 2`` times along the feature axis).
      2. Generate a linear inverse problem ``(y, A, Sigma_y)`` from that prior.
      3. Draw reference samples via ``target.sample_posterior``.
      4. Run NUTS directly on the posterior.
      5. Run NUTS on the intermediate-time ("boosted") posterior at time
         ``T_denoise - eps`` and push the samples through ``diffusion_sampler``.
      6. Print sliced Wasserstein distances to the reference samples and save
         all sample sets to an ``.npz`` file.

    Args:
        cfg: Hydra config. Fields read here: ``random_key``, ``n_groups``,
            ``n_samples``, ``plot``, ``exp_suffix``, ``save_path`` and
            ``example.{n_particles, n_steps, T, dim, dim_y, kappa_range,
            snr_range}`` (plus optional ``example.nuts_steps``, default 100).
            ``cfg.reverse`` is set to ``True`` before the diffusion step.
    """
    example_cfg = cfg.example
    key = jax.random.PRNGKey(cfg.random_key)
    n_groups = cfg.n_groups  # each group is individual
    n_particles = example_cfg.n_particles
    n_steps = example_cfg.n_steps
    T = example_cfg.T
    dim = example_cfg.dim
    dim_y = example_cfg.dim_y
    kappa_range = example_cfg.kappa_range
    snr_range = example_cfg.snr_range
    n_samples = cfg.n_samples

    # ===== prior: Gaussian mixture on a 5x5 grid with random weights =====
    key, subkey = random.split(key)
    scale = 1
    # Each center repeats the 2-D grid point dim // 2 times, so the vector has
    # 2 * (dim // 2) entries (i.e. this assumes dim is even).
    center = [
        jnp.array([-8. * scale * i, -8. * scale * j] * (dim // 2))
        for i in range(-2, 3)
        for j in range(-2, 3)
    ]
    weights = random.uniform(subkey, (len(center),)) ** 2
    weights = weights / weights.sum()
    target = GaussMixJax(jnp.array(center), weights, scale)

    # ===== generate problem =====
    key, subkey = random.split(key)
    y, A, Sigma_y, x_origin = generate_inverse_problem_gm(
        subkey, dim, dim_y, target, scale, kappa_range, snr_range
    )
    # Only the first diagonal entry is used as the noise variance
    # (presumably Sigma_y is isotropic -- TODO confirm).
    var_y = Sigma_y[0][0]
    U, D, V = jnp.linalg.svd(A, full_matrices=False)
    V = V.T
    eps = 1e-2
    lambda_max = D.max() ** 2
    T_denoise = np.log(1 + var_y / lambda_max) / 2
    # At t = T_denoise the term 1 + var_y / D**2 - exp(2 t) vanishes for the
    # largest singular value; stepping back by eps keeps it strictly positive.
    # NOTE(review): nothing guards against T_denoise <= eps, in which case
    # t_mid (and the step sizes derived from it) would be non-positive.
    t_mid = T_denoise - eps

    covs = jnp.repeat(
        jnp.eye(dim)[None]
        * (target.var_scale ** 2 * jnp.exp(-t_mid * 2) + 1 - jnp.exp(-t_mid * 2)),
        axis=0,
        repeats=target.n_center,
    )
    # NOTE(review): prior_middle is not used below; kept as the explicit
    # NumPyro form of the mixture at time t_mid.
    prior_middle = MixtureSameFamily(
        mixing_distribution=Categorical(target.weights),
        component_distribution=MultivariateNormal(
            jnp.array(center) * jnp.exp(-t_mid), covariance_matrix=covs
        ),
    )

    # Rescaled operator and observation used by the boosted log-density below,
    # where the residual has no additional variance scaling.
    A_singular = (jnp.exp(2 * t_mid) / (1 + var_y / D ** 2 - jnp.exp(2 * t_mid))) ** (1 / 2)
    A_singular = jnp.diag(A_singular) @ V.T
    aux_singular = (jnp.exp(2 * t_mid) / ((1 - jnp.exp(2 * t_mid)) * D ** 2 + var_y)) ** (1 / 2)
    obs_singular = jnp.diag(aux_singular) @ (jnp.exp(-t_mid) * (U.T @ y))

    # Number of denoising steps is proportional to T_denoise / T.
    n_step_denoise = math.ceil(T_denoise / T * n_steps)
    dts_denoise = jnp.ones(n_step_denoise) * t_mid / n_step_denoise
    print(f"T denoise: {T_denoise}, n_step_denoise: {n_step_denoise}")

    # Give the reference posterior draw and NUTS their own subkeys. The
    # carried `key` must never be consumed directly, because it is split again
    # later and reusing it would correlate the random streams.
    key, nuts_subkey = random.split(key)
    key, posterior_subkey = random.split(key)
    samples_true_full = target.sample_posterior(posterior_subkey, (n_samples,))
    plot_dim = 0
    nuts_steps = getattr(example_cfg, 'nuts_steps', 100)
    cmap = plt.get_cmap("tab10")

    # ===== NUTS of posterior =====
    # Gaussian likelihood with variance var_y plus the prior log-density at time 0.
    posterior_logprob_nuts = lambda x: (
        - ((y - A @ x['loc']) ** 2).sum(axis=-1) / 2 / var_y
        + target.logprior_density(x['loc'], 0)
    )
    samples_nuts, initial_positions = sample_nuts(
        dim, nuts_subkey, target, n_groups * n_particles, posterior_logprob_nuts, nuts_steps
    )
    if cfg.plot:
        samples_nuts_slice = samples_nuts[..., plot_dim].reshape([-1])
        samples_true = samples_true_full[:, plot_dim]
        _plot_kde_comparison([
            (samples_nuts_slice, 'k', 'NUTS'),
            (samples_true, cmap(0), 'True Posterior'),
            (initial_positions[..., plot_dim].reshape([-1]), cmap(1), 'Initial'),
        ])
        print(samples_nuts_slice.mean(), samples_nuts_slice.std())
        print(samples_true.mean(), samples_true.std())
    sample1 = samples_true_full[:n_particles]
    sample2 = samples_nuts
    key, wass_subkey = random.split(key)
    print(f"Sliced wasserstein distance of HMC: {sliced_wasserstein(wass_subkey, sample1, sample2, n_slices=100)}")

    # ===== NUTS of boosted posterior =====
    # Same structure as above, but in the rescaled coordinates and with the
    # prior log-density evaluated at time t_mid.
    posterior_logprob_nuts = lambda x: (
        - ((obs_singular - A_singular @ x['loc']) ** 2).sum(axis=-1) / 2
        + target.logprior_density(x['loc'], t_mid)
    )
    key, subkey = random.split(key)
    samples_nuts_init, initial_positions = sample_nuts(
        dim, subkey, target, n_groups * n_particles, posterior_logprob_nuts, nuts_steps
    )
    # diffusion_sampler receives cfg, so the flag must be set before the call.
    cfg.reverse = True
    samples_nuts_init = samples_nuts_init.reshape((n_groups, n_particles, dim))
    # One PRNG key per (group, particle).
    key, subkey = random.split(key)
    keys = jax.random.split(subkey, n_groups)
    keys = jnp.array([jax.random.split(k, n_particles) for k in keys])
    samples_nuts_boost_full = diffusion_sampler(keys, samples_nuts_init, dts_denoise, cfg, target.score_fn)
    if cfg.plot:
        samples_nuts_slice = samples_nuts_boost_full[..., plot_dim].reshape([-1])
        samples_true = samples_true_full[:, plot_dim]
        _plot_kde_comparison([
            (samples_nuts_slice, 'k', 'NUTS'),
            (samples_true, cmap(0), 'True Posterior'),
            (samples_nuts_init[..., plot_dim].reshape([-1]), cmap(1), 'NUTS for boosted prior'),
            (initial_positions[..., plot_dim].reshape([-1]), cmap(2), 'Initial for NUTS boosted'),
        ])
        print(samples_nuts_slice.mean(), samples_nuts_slice.std())
        print(samples_true.mean(), samples_true.std())
    sample1 = samples_true_full[:n_particles]
    sample2 = samples_nuts_boost_full.reshape((-1, dim))
    # NOTE(review): wass_subkey is reused here, presumably so both distances
    # are evaluated with the same random slices -- confirm this is intended.
    print(f"Sliced wasserstein distance of boosted HMC: {sliced_wasserstein(wass_subkey, sample1, sample2, n_slices=100)}")

    # ===== save results =====
    run_dir = f"dx{dim}dy{dim_y}"
    if cfg.exp_suffix:
        run_dir = f"{run_dir}_{cfg.exp_suffix}"
    save_path = os.path.join(cfg.save_path, run_dir)
    os.makedirs(save_path, exist_ok=True)
    jnp.savez(
        os.path.join(save_path, f'rdkey{cfg.random_key}'),
        samples_true=samples_true_full,
        samples_nuts=samples_nuts,
        samples_nuts_boost=samples_nuts_boost_full,
    )



# Script entry point: main is called without arguments because the
# hydra.main decorator on it supplies the cfg argument.
if __name__ == "__main__":
    main()
