import os
import pickle

import jax
import jax.numpy as jnp

from experiments.logisticRegression.mnist.load_mnist import mnist_dataset
from experiments.logisticRegression.utils import get_tgt_log_density
from variational.exponential_family import GenericMeanFieldNormalDistribution, MeanFieldNormalDistribution
from variational.ngd import ngd

# Switch JAX to 64-bit precision so arrays this script creates afterwards default to float64.
# NOTE(review): the project modules imported above are loaded before this flag is set; if any
# of them build arrays at import time those stay 32-bit — consider moving this right after `import jax`.
jax.config.update("jax_enable_x64", True)

# Module-level defaults. NOTE(review): neither name is referenced in the visible code
# (experiment() takes its own OUTPUT_PATH argument); kept in case other modules import them.
OUTPUT_PATH = "./output"
OP_key = jax.random.PRNGKey(0)


def experiment(keys, n_iter, n_samples, lr, OUTPUT_PATH="./output_mean_field", s="0"):
    """Run mean-field Gaussian NGD on the MNIST logistic-regression posterior and pickle the result.

    One ``ngd`` run is performed per PRNG key in ``keys``; the runs are batched with
    ``jax.vmap`` and saved together in a single pickle file.

    Args:
        keys: Batch of PRNG keys (e.g. from ``jax.random.split``); one NGD run per key.
        n_iter: Number of iterations, forwarded to ``ngd``.
        n_samples: Sample count, forwarded to ``ngd``.
        lr: Forwarded to ``ngd`` as ``lr_schedule``. A ``float`` is written into the
            output filename; any other value is labelled ``Seq`` there.
        OUTPUT_PATH: Directory for the pickle file; created if it does not exist.
        s: Suffix (e.g. a seed index) appended to the output filename.

    Returns:
        The vmapped ``ngd`` result, also stored under the ``'res'`` key of the pickle.
    """
    flipped_predictors = mnist_dataset(return_test=False)
    _, dim = flipped_predictors.shape

    # Gaussian prior: zero mean, diagonal covariance 25 * I. Only the diagonal is passed to
    # the mean-field distribution, so build it directly instead of a dense dim x dim matrix.
    # (Variant previously tried: a wider prior on the first coordinate, i.e. `.at[0].set(400)`.)
    prior_cov_diag = 25 * jnp.ones(dim)
    prior_log_density = MeanFieldNormalDistribution(jnp.zeros(dim), prior_cov_diag).log_density
    tgt_log_density = get_tgt_log_density(flipped_predictors, prior_log_density)

    # Mean-field Gaussian variational family.
    variational_family = GenericMeanFieldNormalDistribution(dimension=dim)
    sampling = variational_family.sampling_method
    sufficient_statistic = variational_family.sufficient_statistic
    sanity = variational_family.sanity

    # Initial parameters: zero mean and exp(-2) per coordinate for the second argument
    # (presumably the initial variance — confirm against get_upsilon).
    upsilon_init = variational_family.get_upsilon(jnp.zeros(dim), jnp.ones(dim) * jnp.exp(-2))

    @jax.vmap
    def run_ngd(key):
        # Everything except the key is closed over, so vmap batches over the keys only.
        return ngd(key, sampling, sufficient_statistic, tgt_log_density, upsilon_init, n_iter, n_samples,
                   lr_schedule=lr, sanity=sanity)

    res = run_ngd(keys)

    PARAMS = {'n_iter': n_iter, 'n_samples': n_samples, 'lr': lr}
    desc = "MNIST dataset, mean-field Gaussian, NGD"
    # Computed outside the f-string: nesting the f-string's own quote type inside a
    # replacement field is a SyntaxError before Python 3.12.
    lr_tag = lr if isinstance(lr, float) else 'Seq'
    os.makedirs(OUTPUT_PATH, exist_ok=True)
    out_file = os.path.join(OUTPUT_PATH, f"gaussian_meanfield_ngd_{n_iter}_{n_samples}_{lr_tag}_{s}.pkl")
    with open(out_file, "wb") as fh:
        pickle.dump({'desc': desc, 'PARAMS': PARAMS, 'res': res, 'all': None}, fh)
    return res


if __name__ == "__main__":
    # Run settings forwarded unchanged to experiment().
    n_iter, n_samples, lr = 100, 10_000, 1e-3
    n_repetitions = 25
    for seed in range(4):
        # One base key per seed, split into n_repetitions keys; experiment() vmaps one NGD run
        # per key and writes one pickle whose filename ends with this seed.
        seed_keys = jax.random.split(jax.random.PRNGKey(seed), n_repetitions)
        experiment(seed_keys, n_iter, n_samples, lr, OUTPUT_PATH="./output_mean_field", s=seed)
