import os
import pickle
from typing import Callable

import jax
import jax.numpy as jnp
from chex import PRNGKey, Array
import tensorflow_probability.substrates.jax as tfp
from matplotlib import pyplot as plt

from sves.benchmarks.simple_motion_plan import RamosPaper


def sample_mcmc(
    rng: PRNGKey,
    objective_fn: Callable,
    num_samples: int,
    num_dims: int,
    storage_path: str,
    thinning_factor: int = 6_000,
    num_burnin_steps: int = int(1e6),
    num_chains: int = 20,
    noise_scale: float = .225,
    min: float = -1.,
    max: float = 1.,
) -> Array:
    """Draw reference samples from ``exp(-objective_fn)`` with random-walk Metropolis.

    ``num_chains`` chains are started uniformly inside ``[min, max)``. They are
    burnt in for ``num_burnin_steps`` steps. After that, ``num_samples`` states
    are kept per chain, with ``thinning_factor`` steps between kept states.
    All kept states are flattened into one array. The array is pickled to
    ``<storage_path>/gt_samples<N>.pkl``, where N is its number of rows, and
    then returned.

    Args:
        rng: PRNG key, split into an initialisation key and an MCMC seed.
        objective_fn: Objective to *minimise*. The MCMC target log-density is
            ``-objective_fn(x)``. It is evaluated on the whole batch of chain
            states; presumably it maps ``(num_chains, num_dims)`` to
            ``(num_chains,)``. TODO confirm against the benchmarks.
        num_samples: Number of states kept per chain.
        num_dims: Dimensionality of one state.
        storage_path: Directory for the pickle. It is created if missing.
        thinning_factor: MCMC steps between two kept states.
        num_burnin_steps: MCMC steps discarded before the first kept state.
        num_chains: Number of chains run in parallel.
        noise_scale: Scale of the Gaussian random-walk proposal.
        min: Lower bound of the uniform initialisation box.
        max: Upper bound of the uniform initialisation box.
            Both bounds may be arrays broadcastable to
            ``(num_chains, num_dims)``. They only affect the initial state.

    Returns:
        Array of shape ``(num_samples * num_chains, num_dims)`` holding the
        kept states of all chains.
    """
    # Every benchmark is minimised in this set-up, so the MCMC target
    # log-density is the negated objective.
    mh = tfp.mcmc.RandomWalkMetropolis(
        lambda x: -objective_fn(x),
        new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=noise_scale),
    )

    def trace_fn(_, pkr):
        # Keep track of the log-prob and acceptance for diagnostics.
        return {
            'target_log_prob': pkr.accepted_results.target_log_prob,
            'is_accepted': pkr.is_accepted,
        }

    # `min`/`max` shadow the builtins. They keep those names for keyword
    # callers, but are aliased locally so the body reads unambiguously.
    lower, upper = min, max

    @jax.jit
    def run_chain(key: PRNGKey):
        rng_mcmc, rng_init = jax.random.split(key)
        # The bounds are only used to draw starting points; random-walk
        # proposals are not clipped to this box.
        initial_state = jax.random.uniform(
            rng_init, (num_chains, num_dims), minval=lower, maxval=upper
        )
        return tfp.mcmc.sample_chain(
            num_results=num_samples,
            current_state=initial_state,
            kernel=mh,
            num_steps_between_results=thinning_factor,
            num_burnin_steps=num_burnin_steps,
            seed=rng_mcmc,
            trace_fn=trace_fn,
        )

    # The diagnostic trace is currently discarded by this function.
    samples, _trace = run_chain(rng)
    # `samples` is (num_samples, num_chains, num_dims). Merge the first two
    # axes so every kept state becomes one row. asarray avoids a needless copy.
    gt_samples = jnp.asarray(samples).reshape(-1, num_dims)

    # Store samples. os.path.join keeps the path construction platform-neutral.
    os.makedirs(storage_path, exist_ok=True)
    out_path = os.path.join(storage_path, f"gt_samples{gt_samples.shape[0]}.pkl")
    with open(out_path, "wb") as handler:
        pickle.dump(gt_samples, handler, protocol=pickle.HIGHEST_PROTOCOL)

    return gt_samples


if __name__ == "__main__":
    # Single source of truth for where samples are written and re-read.
    storage_path = "../sves/data/ramos_eval"

    bench = RamosPaper(n_via=5)
    bench.sample(jax.random.PRNGKey(0), 256)
    objective_fn, score_fn = bench.get_objective_derivative()
    ub = bench.upper_bounds
    lb = bench.lower_bounds
    num_samples = 1
    num_chains = 256
    gt_samples = sample_mcmc(
        jax.random.PRNGKey(123),
        # sample_mcmc uses -objective as its log-density, so this targets
        # log p(x) = -exp(objective_fn(x)).
        lambda x: jnp.exp(objective_fn(x)),
        num_samples=num_samples,
        num_chains=num_chains,
        num_dims=bench.dim,
        num_burnin_steps=int(1e5),
        storage_path=storage_path,
        min=lb,
        max=ub,
        noise_scale=.1
    )

    # Reload the stored samples (same file name scheme as sample_mcmc) and plot them.
    samples_file = os.path.join(storage_path, f"gt_samples{num_samples * num_chains}.pkl")
    try:
        with open(samples_file, "rb") as handler:
            gt_samples = pickle.load(handler)
        # Plot after the `with` so the file is closed before plt.show() blocks.
        bench.plot((gt_samples, None), lb - 2, ub + 2)
        plt.show()
    except Exception as e:
        # Best-effort: loading/plotting problems are reported, not fatal.
        print(f"Could not load/plot {samples_file}: {e!r}")
