import pickle

import jax
from matplotlib import pyplot as plt

from sves.benchmarks.synthetic_benchmarks import DoubleBanana
from sample_motion_plan import sample_mcmc


def main() -> None:
    """Run MCMC sampling on the DoubleBanana benchmark and plot the result.

    Calls ``sample_mcmc`` with the benchmark objective, then loads
    ``gt_samples<num_samples * num_chains>.pkl`` from the storage directory
    (presumably written by ``sample_mcmc`` -- confirm against its
    implementation) and hands the samples to ``bench.plot``.
    """
    rng = jax.random.PRNGKey(123)
    # Keep the three-way split even though the last key is unused: asking for
    # a different number of keys would change the derived keys and therefore
    # the run's random stream.
    rng_init, rng_run, _ = jax.random.split(rng, 3)
    bench = DoubleBanana(rng_init)
    # Only the objective is needed here; the second returned function is
    # discarded.
    objective_fn, _ = bench.get_objective_derivative()
    ub = bench.upper_bounds
    lb = bench.lower_bounds
    num_samples = 1
    num_chains = 256
    # Single source of truth for the output directory, used both for the
    # sampler and for locating the samples file afterwards.
    storage_path = "../sves/data/banana"
    sample_mcmc(
        rng_run,
        lambda x: objective_fn(x),
        num_samples=num_samples,
        num_chains=num_chains,
        num_dims=bench.dim,
        num_burnin_steps=int(1e5),
        storage_path=storage_path,
        min=-.1,
        max=.1,
        noise_scale=5e-2,
    )

    # Load the stored samples. Only load-time failures are handled here, so
    # that errors raised while plotting surface with a full traceback instead
    # of being reduced to a one-line message.
    samples_file = f"{storage_path}/gt_samples{num_samples * num_chains}.pkl"
    try:
        # NOTE: pickle.load can execute arbitrary code; only load files
        # produced by trusted runs.
        with open(samples_file, "rb") as handler:
            gt_samples = pickle.load(handler)
    except (
        OSError,
        pickle.UnpicklingError,
        EOFError,
        AttributeError,
        ImportError,
        IndexError,
    ) as e:
        print(f"Could not load samples from {samples_file}: {e}")
        return

    # Plot samples (the file is already closed, so no handle is held while
    # the plot window is shown).
    bench.plot((gt_samples, None), lb, ub)
    plt.show()


if __name__ == "__main__":
    main()
