import os
import math
import jax
import matplotlib.pyplot as plt
import seaborn as sns
import hydra
import numpy as np
import jax.numpy as jnp
from jax import random

from dist import GaussMixJax
from util import diffusion_sampler
from util import generate_inverse_problem_gm, sliced_wasserstein
from numpyro.distributions import MultivariateNormal, Categorical, MixtureSameFamily


def _build_gmm_target(key, dim, scale=1):
    """Build the 25-component Gaussian-mixture prior used by this experiment.

    Component means lie on a 5x5 grid with spacing ``8 * scale``; the 2-D
    grid coordinates are tiled ``dim // 2`` times to fill ``dim`` entries.
    Mixture weights are squared uniform draws normalised to sum to one.

    Args:
        key: JAX PRNG key used for the mixture weights.
        dim: ambient dimension of the target; must be even.
        scale: grid spacing factor, also passed to ``GaussMixJax``.

    Returns:
        The ``GaussMixJax`` prior.

    Raises:
        ValueError: if ``dim`` is odd (the tiled means would only have
            ``dim - 1`` coordinates and would not match the particles).
    """
    if dim % 2:
        raise ValueError(f"dim must be even to tile the 2-D grid of means, got {dim}")
    # Outer loop over i, inner over j -- same ordering as a nested for-loop.
    centers = jnp.array([
        [-8. * scale * i, -8. * scale * j] * (dim // 2)
        for i in range(-2, 3)
        for j in range(-2, 3)
    ])
    weights = random.uniform(key, (centers.shape[0],)) ** 2
    weights = weights / weights.sum()
    return GaussMixJax(centers, weights, scale)


def _boosted_likelihood(A, y, var_y, eps):
    """Return the rescaled linear-Gaussian likelihood at an intermediate time.

    With the thin SVD ``A = U diag(D) Vt`` and ``lambda_max = max(D)**2``, the
    intermediate time is ``T_denoise = log(1 + var_y / lambda_max) / 2`` and
    the likelihood is evaluated at ``tau = T_denoise - eps``.  Writing
    ``r = exp(2 * tau)``::

        A_singular   = diag(sqrt(r / (1 + var_y / D**2 - r))) @ Vt
        obs_singular = sqrt(r / ((1 - r) * D**2 + var_y)) * exp(-tau) * (U.T @ y)

    ``eps > 0`` keeps ``1 + var_y / D**2 - r`` strictly positive for every
    singular value.

    Returns:
        ``(T_denoise, A_singular, obs_singular)``.
    """
    U, D, Vt = jnp.linalg.svd(A, full_matrices=False)
    lambda_max = D.max() ** 2
    T_denoise = np.log(1 + var_y / lambda_max) / 2
    tau = T_denoise - eps
    r = jnp.exp(2 * tau)
    A_singular = jnp.diag((r / (1 + var_y / D**2 - r)) ** (1 / 2)) @ Vt
    aux_singular = (r / ((1 - r) * D**2 + var_y)) ** (1 / 2)
    obs_singular = jnp.diag(aux_singular) @ (jnp.exp(-tau) * (U.T @ y))
    return T_denoise, A_singular, obs_singular


def _split_particle_keys(key, n_groups, n_particles):
    """Split ``key`` into an ``(n_groups, n_particles)`` grid of PRNG keys.

    Equivalent to splitting ``key`` into ``n_groups`` keys and then splitting
    each of those into ``n_particles`` keys.
    """
    group_keys = jax.random.split(key, n_groups)
    return jax.vmap(lambda k: jax.random.split(k, n_particles))(group_keys)


def _plot_marginal(samples_full, samples_true_full, x0s, plot_dim):
    """KDE-plot one coordinate of sampler output vs. reference samples and init.

    Also prints mean/std of that coordinate for the sampler and the reference.
    """
    samples_langevin = samples_full[..., plot_dim].reshape([-1])
    samples_true = samples_true_full[:, plot_dim]
    _, ax = plt.subplots(1, 1, figsize=(6, 4))
    cmap = plt.get_cmap("tab10")
    sns.kdeplot(samples_langevin, color='k', ax=ax, label='Langevin')
    sns.kdeplot(samples_true, color=cmap(0), ax=ax, label='True Posterior')
    # NOTE(review): x0s is the N(0, I) Langevin initialisation, labelled 'Prior'.
    sns.kdeplot(x0s[..., plot_dim].reshape([-1]), color=cmap(1), ax=ax, label='Prior')
    plt.legend()
    plt.show()
    print(samples_langevin.mean(), samples_langevin.std())
    print(samples_true.mean(), samples_true.std())


@hydra.main(version_base=None, config_path="configs/", config_name="gmm")
def main(cfg):
    """Compare plain and boosted Langevin sampling on a GMM inverse problem.

    Pipeline:
      1. Build a 25-component Gaussian-mixture prior and generate a linear
         inverse problem ``(y, A, Sigma_y)`` with ``generate_inverse_problem_gm``.
      2. Draw reference samples with ``target.sample_posterior``.
      3. Run Langevin (``cfg.reverse = False``) from ``N(0, I)`` particles with
         the exact posterior score, and report its sliced Wasserstein distance.
      4. Run Langevin on the rescaled likelihood at time ``T_denoise - eps``,
         then run ``diffusion_sampler`` with ``cfg.reverse = True`` and
         ``target.score_fn`` over ``dts_denoise``, and report its distance.
      5. Save all three sample sets to ``<save_path>/dx{dim}dy{dim_y}[_suffix]/
         rdkey{random_key}.npz``.

    Note: mutates ``cfg.reverse``.
    """
    import warnings

    example_cfg = cfg.example
    key = jax.random.PRNGKey(cfg.random_key)
    n_groups = cfg.n_groups  # each group is individual
    n_particles = example_cfg.n_particles
    n_steps = example_cfg.n_steps
    T = example_cfg.T
    dim = example_cfg.dim
    dim_y = example_cfg.dim_y
    kappa_range = example_cfg.kappa_range
    snr_range = example_cfg.snr_range
    n_samples = cfg.n_samples
    scale = 1
    eps = 1e-2  # offset below T_denoise; keeps the rescaled likelihood well defined
    plot_dim = 0
    # Both Langevin phases use 5x the nominal step size T / n_steps.
    dts_langevin = jnp.ones(n_steps) * T / n_steps * 5

    key, subkey = random.split(key)
    target = _build_gmm_target(subkey, dim, scale)

    # ===== generate problem =====
    key, subkey = random.split(key)
    y, A, Sigma_y, _x_origin = generate_inverse_problem_gm(
        subkey, dim, dim_y, target, scale, kappa_range, snr_range)
    # Only Sigma_y[0][0] is read: assumes isotropic noise var_y * I -- TODO confirm.
    var_y = Sigma_y[0][0]
    T_denoise, A_singular, obs_singular = _boosted_likelihood(A, y, var_y, eps)
    n_step_denoise = math.ceil(T_denoise / T * n_steps)
    dts_denoise = jnp.ones(n_step_denoise) * (T_denoise - eps) / n_step_denoise
    print(f"T denoise: {T_denoise}, n_step_denoise: {n_step_denoise}")
    if T_denoise <= eps:
        # Otherwise this would silently run the denoising phase with non-positive steps.
        warnings.warn(
            f"T_denoise ({T_denoise}) <= eps ({eps}): denoising steps are "
            "non-positive; boosted LMC results are unreliable.")

    key, subkey = random.split(key)
    # Use the fresh subkey (not `key`, which is split again below) so the
    # reference samples are independent of the Langevin randomness.
    # sample_posterior takes only a key and shape, so it presumably reads the
    # observation stored on `target` by generate_inverse_problem_gm -- TODO confirm.
    samples_true_full = target.sample_posterior(subkey, (n_samples,))

    # ===== Langevin of posterior =====
    key, subkey = random.split(key)
    x0s = random.normal(subkey, (n_groups, n_particles, dim))

    def posterior_score(z, t):
        # Prior score at t=0 plus the Gaussian log-likelihood gradient; t unused.
        return target.score_fn(z, 0) - A.T @ (A @ z - y) / var_y

    cfg.reverse = False
    key, subkey = random.split(key)
    keys = _split_particle_keys(subkey, n_groups, n_particles)
    samples_langevin_full = diffusion_sampler(keys, x0s, dts_langevin, cfg, posterior_score)
    if cfg.plot:
        _plot_marginal(samples_langevin_full, samples_true_full, x0s, plot_dim)
    key, wass_subkey = random.split(key)
    sw_lmc = sliced_wasserstein(
        wass_subkey, samples_true_full, samples_langevin_full.reshape((-1, dim)), n_slices=100)
    print("Sliced wasserstein distance of LMC: {}".format(sw_lmc))

    # ===== Langevin of boosted posterior =====
    def boosted_score(z, t):
        # Prior score at the intermediate time plus the rescaled likelihood gradient.
        return target.score_fn(z, T_denoise - eps) - A_singular.T @ (A_singular @ z - obs_singular)

    cfg.reverse = False
    # NOTE(review): reuses `keys` and `x0s` from the plain LMC run, so both
    # chains share initialisation (and presumably noise) -- confirm intended.
    samples_langevin_boost_full = diffusion_sampler(keys, x0s, dts_langevin, cfg, boosted_score)
    cfg.reverse = True
    key, subkey = random.split(key)
    keys = _split_particle_keys(subkey, n_groups, n_particles)
    samples_langevin_boost_full = diffusion_sampler(
        keys, samples_langevin_boost_full, dts_denoise, cfg, target.score_fn)
    if cfg.plot:
        _plot_marginal(samples_langevin_boost_full, samples_true_full, x0s, plot_dim)
    # Same wass_subkey as above, so both distances are computed with the same slices.
    sw_boost = sliced_wasserstein(
        wass_subkey, samples_true_full, samples_langevin_boost_full.reshape((-1, dim)), n_slices=100)
    print("Sliced wasserstein distance of boosted LMC: {}".format(sw_boost))

    dir_name = "dx{}dy{}".format(dim, dim_y)
    if cfg.exp_suffix:
        dir_name += "_{}".format(cfg.exp_suffix)
    save_path = os.path.join(cfg.save_path, dir_name)
    os.makedirs(save_path, exist_ok=True)
    jnp.savez(
        os.path.join(save_path, 'rdkey{}'.format(cfg.random_key)),
        samples_true=samples_true_full,
        samples_langevin=samples_langevin_full,
        samples_langevin_boost=samples_langevin_boost_full,
    )


if __name__ == "__main__":
    # Hydra's decorator parses the command line and supplies `cfg`.
    main()
