import jax
# Ask JAX to use its highest matmul precision setting for the whole script.
jax.config.update('jax_default_matmul_precision', 'highest')

import os
# Turn off XLA's up-front memory preallocation and select the "platform"
# allocator.
# NOTE(review): these are set after `import jax`; the JAX GPU-memory docs
# suggest setting them before JAX is initialised -- confirm they take effect.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

import argparse
# Command-line interface: a single integer --seed option (default 13).
# Parsing happens at import time, so `args` is a module-level global.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', dest='seed', type=int, default=13)
args = parser.parse_args()

import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp

from functools import partial

import argparse
import time
import os
import numpy as np

from src.elk.algs.qdeer_leapfrog import seq1d
from src.elk.algs.elk import elk_alg, quasi_elk

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def accept(x):
  """Return the hard accept indicator ``x > 0`` (non-differentiable)."""
  is_positive = x > 0
  return is_positive

def sigmoid_accept(x):
  """Step function in the forward pass, sigmoid in the backward pass.

  The returned value equals ``accept(x)``; the only term that is not wrapped
  in ``stop_gradient`` is ``jax.nn.sigmoid(x)``, so that is what gradients
  flow through.
  """
  soft = jax.nn.sigmoid(x)
  # Evaluates to zero, but keeps the sigmoid's dependence on x for autodiff.
  gradient_carrier = soft - jax.lax.stop_gradient(soft)
  hard = jax.lax.stop_gradient(accept(x))
  return gradient_carrier + hard
import tensorflow_datasets as tfds

def load_dataset():
    """Load the German Credit (numeric) data as standardised features and labels.

    Dataset loading adapted from the inference_gym package:
    https://github.com/tensorflow/probability/blob/main/spinoffs/inference_gym/inference_gym/internal/data.py

    Returns:
        A ``(train_features, train_labels)`` pair of NumPy arrays. Features
        are standardised column-wise with the training mean and standard
        deviation; labels are ``1 - label`` from TFDS, cast to ``int32``.
    """
    def _standardise(train, test):
        # Both splits are scaled with statistics of the training split only.
        train, test = np.asarray(train), np.asarray(test)
        mean = train.mean(0, keepdims=True)
        std = train.std(0, keepdims=True)
        return (train - mean) / std, (test - mean) / std

    train_fraction = 1.
    num_points = 1000
    num_train = int(num_points * train_fraction)

    dataset = tfds.load('german_credit_numeric:1.*.*')
    records = list(tfds.as_numpy(dataset)['train'])
    features = np.stack([record['features'] for record in records], axis=0)
    # Labels are reversed to match the original dataset rather than the
    # TFDS encoding.
    labels = np.stack([1 - record['label'] for record in records], axis=0)

    train_features, _ = _standardise(features[:num_train], features[num_train:])
    return train_features, labels[:num_train].astype(np.int32)

# Load the data once at import time; X (features) and y (labels) are
# module-level globals that the code below rebinds and closes over.
X, y = load_dataset()

def whiten(X):
    """Whiten ``X`` (rows are samples) so its sample covariance is ~identity.

    The columns are centred and then multiplied by the symmetric inverse
    square root of the sample covariance, ``V diag(1/sqrt(lam + eps)) V^T``,
    built from the eigendecomposition of that covariance.
    """
    centered = X - np.mean(X, axis=0)
    eigvals, eigvecs = np.linalg.eigh(np.cov(centered, rowvar=False))
    jitter = 1e-10  # added to the eigenvalues before the inverse square root
    inv_sqrt = np.diag(1.0 / np.sqrt(eigvals + jitter))
    transform = eigvecs @ inv_sqrt @ eigvecs.T
    return centered @ transform
# Whiten the features, move them to JAX, and append a column of ones so the
# last coefficient multiplies a constant (bias) feature.
X = whiten(X)
X = jnp.asarray(X)
# Size the ones column from the data rather than hard-coding 1000 rows, so
# this keeps working if the number of loaded points changes.
X = jnp.hstack((X, jnp.ones((X.shape[0], 1))))
y = jnp.asarray(y)

def joint_log_prob(x, y, beta):
    """Log joint density of the logistic-regression model.

    Adds a ``Normal(0, 1)`` log prior summed over the entries of ``beta`` to
    a Bernoulli log likelihood of ``y`` with logits ``x @ beta``, summed over
    observations.
    """
    prior_lp = tfd.Normal(0., 1.).log_prob(beta).sum()
    likelihood_lp = tfd.Bernoulli(x @ beta).log_prob(y).sum()
    return prior_lp + likelihood_lp
def unconstrained_joint_log_prob(x, y, z):
    """Evaluate ``joint_log_prob`` at ``z``.

    ``z`` is forwarded unchanged as ``beta``; no transform is applied here.
    """
    return joint_log_prob(x, y, beta=z)
# Bind the dataset so that logp is a function of the parameter vector only.
logp = partial(unconstrained_joint_log_prob, X, y)
# Number of parameters: the size of the last axis of X.
D = X.shape[-1]

# Returns (log density, gradient w.r.t. the parameter vector) in one call.
target_log_prob_and_grad = jax.value_and_grad(logp)

# Parallel Leapfrog for DEER
def deer_leapfrog(state, driver, params):
    """Advance a concatenated (position, momentum) state by one leapfrog update.

    The update is a full position step followed by a full momentum step. The
    caller is therefore expected to add a half momentum step before the first
    call and subtract a half step after the last one.

    Args:
        state: array whose first half (along axis 0) is the position and
            whose second half is the momentum.
        driver: not read by this function.
        params: mapping; only ``params["epsilon"]`` (the step size) is read.

    Returns:
        The updated state, concatenated in the same (position, momentum)
        order.
    """
    eps = params["epsilon"]  # discretization step size
    position, momentum = jnp.split(state, 2)
    position = position + eps * momentum
    # Uses the module-level target_log_prob_and_grad; the value is discarded.
    _, grad = target_log_prob_and_grad(position)
    momentum = momentum + eps * grad
    return jnp.concatenate((position, momentum))

def hmc_step_seq(target_log_prob_and_grad, num_leapfrog_steps, step_size, z, seed):
    """Run one HMC transition with a sequential (``lax.scan``) leapfrog.

    Args:
        target_log_prob_and_grad: callable mapping a position to
            ``(log_prob, grad)``. Note that ``deer_leapfrog`` reads the
            module-level function of the same name, not this argument.
        num_leapfrog_steps: trajectory length; because the scan runs over
            ``drivers[1:]`` it performs ``num_leapfrog_steps - 1`` leapfrog
            updates.
        step_size: leapfrog step size.
        z: current position.
        seed: PRNG key for this transition.

    Returns:
        ``(z, hmc_output)``: the position after the accept/reject decision
        and a dict with keys ``"z"``, ``"is_accepted"`` and
        ``"log_accept_ratio"``.
    """
    # Split three ways, exactly like hmc_step_par (the third key is unused
    # here). With JAX's original threefry split the derived keys depend on
    # how many are requested, so a two-way split would give this sampler
    # different momentum / uniform draws than the parallel one for the same
    # seed, making a sample-by-sample comparison of the two meaningless.
    m_seed, mh_seed, _ = jax.random.split(seed, 3)
    tlp, tlp_grad = target_log_prob_and_grad(z)
    m = jax.random.normal(m_seed, z.shape)
    energy = 0.5 * jnp.square(m).sum() - tlp
    # start with half-step of momentum
    m += 0.5 * step_size * tlp_grad
    init_state = jnp.concatenate((z, m))
    params = {'epsilon': step_size}

    def fxn_for_scan(state, driver):
        state = deer_leapfrog(state, driver, params)
        return state, None

    # Zero-width inputs: they only set the number of scan iterations.
    drivers = jnp.zeros((num_leapfrog_steps, 0))
    new_state, _ = jax.lax.scan(fxn_for_scan, init_state, drivers[1:])
    new_z, new_m = jnp.split(new_state, 2)
    new_tlp, new_tlp_grad = target_log_prob_and_grad(new_z)
    # end with backward half-step of momentum
    new_m -= 0.5 * step_size * new_tlp_grad
    new_energy = 0.5 * jnp.square(new_m).sum() - new_tlp
    log_accept_ratio = energy - new_energy
    is_accepted = jnp.log(jax.random.uniform(mh_seed, [])) < log_accept_ratio
    # select the proposed state if accepted
    z = jnp.where(is_accepted, new_z, z)
    hmc_output = {"z": z,
                  "is_accepted": is_accepted,
                  "log_accept_ratio": log_accept_ratio}
    # hmc_output["z"] has shape [num_dimensions]
    return z, hmc_output

def hmc_step_par(target_log_prob_and_grad, num_leapfrog_steps, step_size, z, seed):
    """Run one HMC transition whose leapfrog trajectory is computed by ``seq1d``.

    Args:
        target_log_prob_and_grad: callable mapping a position to
            ``(log_prob, grad)``. Note that ``deer_leapfrog`` reads the
            module-level function of the same name, not this argument.
        num_leapfrog_steps: trajectory length; ``seq1d`` is given
            ``num_leapfrog_steps - 1`` inputs and this value as ``max_iter``.
        step_size: leapfrog step size.
        z: current position.
        seed: PRNG key for this transition.

    Returns:
        ``(z, hmc_output)``: the position after the accept/reject decision
        and a dict with keys ``"z"``, ``"is_accepted"`` and
        ``"log_accept_ratio"``.
    """
    momentum_key, accept_key, solver_key = jax.random.split(seed, 3)
    log_prob, grad = target_log_prob_and_grad(z)
    momentum = jax.random.normal(momentum_key, z.shape)
    initial_energy = 0.5 * jnp.square(momentum).sum() - log_prob

    # Leading half-step of momentum; deer_leapfrog then takes full steps.
    momentum = momentum + 0.5 * step_size * grad
    start_state = jnp.concatenate((z, momentum))

    dim = z.shape[0]
    solver_params = {
        'epsilon': step_size,
        'key': solver_key,
        'mass_diag': jnp.ones((dim,)),
    }
    # Zero-width inputs: they only fix the trajectory length.
    inputs = jnp.zeros((num_leapfrog_steps, 0))[1:]
    # Initial guess for the trajectory: the start state repeated on every row.
    trajectory_guess = start_state * jnp.ones((num_leapfrog_steps - 1, 2 * dim))
    # NOTE(review): seq1d is defined elsewhere; presumably it solves for the
    # whole trajectory at once -- confirm its return layout against [0][-1].
    solved = seq1d(deer_leapfrog, start_state, inputs, solver_params, logp,
                   yinit_guess=trajectory_guess,
                   max_iter=num_leapfrog_steps,
                   qmem_efficient=True,
                   full_trace=False)
    final_state = solved[0][-1]

    proposal_z, proposal_m = jnp.split(final_state, 2)
    proposal_lp, proposal_grad = target_log_prob_and_grad(proposal_z)
    # Trailing backward half-step of momentum.
    proposal_m = proposal_m - 0.5 * step_size * proposal_grad
    proposal_energy = 0.5 * jnp.square(proposal_m).sum() - proposal_lp

    log_accept_ratio = initial_energy - proposal_energy
    log_u = jnp.log(jax.random.uniform(accept_key, []))
    is_accepted = log_u < log_accept_ratio
    # Keep the proposal when accepted, otherwise stay at the current position.
    z = jnp.where(is_accepted, proposal_z, z)
    return z, {"z": z,
               "is_accepted": is_accepted,
               "log_accept_ratio": log_accept_ratio}

# Batch the single-chain transitions over chains: the first three arguments
# are shared, while z and seed are mapped over their leading axis.
vmap_hmc_step_seq = jax.vmap(hmc_step_seq, in_axes=(None, None, None, 0, 0))
vmap_hmc_step_par = jax.vmap(hmc_step_par, in_axes=(None, None, None, 0, 0))

def hmc(target_log_prob_and_grad, num_leapfrog_steps, step_size, num_steps, num_chains, z, seed, use_parallel=False):
    """Run ``num_steps`` HMC transitions for a batch of chains.

    Args:
        target_log_prob_and_grad: callable mapping a position to
            ``(log_prob, grad)``.
        num_leapfrog_steps: forwarded to the per-step function.
        step_size: leapfrog step size.
        num_steps: number of HMC transitions to scan over.
        num_chains: number of chains; must equal ``z.shape[0]``.
        z: initial positions, one row per chain.
        seed: PRNG key, split into one key per (step, chain).
        use_parallel: if true use ``vmap_hmc_step_par``, otherwise
            ``vmap_hmc_step_seq``.

    Returns:
        The per-step outputs stacked by ``lax.scan``; every entry (e.g.
        ``hmc_output["z"]``) gains leading axes ``[num_steps, num_chains]``.
    """
    assert num_chains == z.shape[0]
    # One PRNG key for every (step, chain) pair.
    step_keys = jax.random.split(seed, (num_steps, num_chains))
    batched_step = vmap_hmc_step_par if use_parallel else vmap_hmc_step_seq
    transition = partial(batched_step, target_log_prob_and_grad,
                         num_leapfrog_steps, step_size)
    # Repeatedly apply the transition and accumulate its outputs.
    _, hmc_output = jax.lax.scan(transition, z, step_keys)
    return hmc_output

# Sweep grid: the loop below runs every (chains, leapfrog steps, step size)
# combination.
chains_sweep = [4,]
leapfrog_sweep = [4, 8, 12, 16, 20, 24, 32, 40, 64, 96]
step_sweep = [0.005, 0.0075, 0.01, 0.02, 0.03, 0.04, 0.05, 0.075, 0.1]
# Root PRNG key for the whole sweep, seeded from the --seed argument.
key = jr.PRNGKey(args.seed)
# One result dict per configuration is appended here.
outputs = []

import arviz as az
import time

import gc

# Benchmark sweep: for every configuration, run the sequential and the
# parallel sampler from the same initial state and seed, time both, and
# record acceptance rate, ESS, and sample statistics.
for num_chains in chains_sweep:
    for num_leapfrog_steps in leapfrog_sweep:
        for step_size in step_sweep:
            print("# Chains: ", num_chains, ", Leapfrog: ", num_leapfrog_steps, ", Step Size: ", step_size)
            # Draw the initial positions from the fresh subkey. Drawing from
            # `key` itself would reuse it, because it is split again below.
            key, skey = jr.split(key)
            z_init = 0.1 * jr.normal(skey, (num_chains, D,))
            key, seed = jr.split(key)
            num_samples = 2000

            run_seq = jax.jit(partial(hmc, target_log_prob_and_grad, num_leapfrog_steps, step_size, num_samples, num_chains, z_init, seed))
            run_par = jax.jit(partial(hmc, target_log_prob_and_grad, num_leapfrog_steps, step_size, num_samples, num_chains, z_init, seed, use_parallel=True))

            # Warm-up: compile and run both programs once, waiting for each
            # to finish so no leftover work leaks into the timed runs.
            jax.block_until_ready(run_seq())
            jax.block_until_ready(run_par())

            # perf_counter is a monotonic clock intended for interval timing.
            t0 = time.perf_counter()
            hmc_output_seq = run_seq()
            jax.block_until_ready(hmc_output_seq)
            time_elapsed_seq = time.perf_counter() - t0

            t0 = time.perf_counter()
            hmc_output_par = run_par()
            jax.block_until_ready(hmc_output_par)
            time_elapsed_par = time.perf_counter() - t0

            # Swap the first two axes so the chain axis comes first, i.e.
            # [num_steps, num_chains, D] -> [num_chains, num_steps, D].
            seq_samples = jnp.swapaxes(hmc_output_seq['z'], 0, 1)
            par_samples = jnp.swapaxes(hmc_output_par['z'], 0, 1)
            accept_ratio_seq = jnp.sum(hmc_output_seq['is_accepted']) / (num_samples*num_chains)
            accept_ratio_par = jnp.sum(hmc_output_par['is_accepted']) / (num_samples*num_chains)
            print("Accept Ratio Seq: ", accept_ratio_seq)
            print("Accept Ratio Par: ", accept_ratio_par)

            print(f"seq time: {time_elapsed_seq:.3e} s")
            print(f"deer time: {time_elapsed_par:.3e} s")

            # Largest absolute difference between the two samplers' draws.
            seq_error = jnp.max(jnp.abs(hmc_output_seq['z'] - hmc_output_par['z']))
            # Effective sample size, averaged over the values arviz returns.
            dataset = az.convert_to_inference_data(np.array(seq_samples))
            ess_seq = np.mean(az.ess(dataset).x.values)
            dataset = az.convert_to_inference_data(np.array(par_samples))
            ess_par = np.mean(az.ess(dataset).x.values)

            print("Seq ESS/s: ", ess_seq/time_elapsed_seq)
            print("Par ESS/s: ", ess_par/time_elapsed_par)

            # Correlation between the first two coordinates in chain 0.
            corr01_par = jnp.corrcoef(par_samples[0][:,0], par_samples[0][:,1])[0,1]
            print("Corr: ", corr01_par)
            corr01_seq = jnp.corrcoef(seq_samples[0][:,0], seq_samples[0][:,1])[0,1]

            outputs.append({
                'D': D,
                'num_chains': num_chains,
                'num_leapfrog_steps': num_leapfrog_steps,
                'step_size': step_size,
                'time_elapsed_seq': time_elapsed_seq,
                'time_elapsed_par': time_elapsed_par,
                'ess_seq': ess_seq,
                'ess_par': ess_par,
                'seq_error': seq_error,
                'accept_ratio_seq': accept_ratio_seq,
                'accept_ratio_par': accept_ratio_par,
                'par_corr': corr01_par,
                'seq_corr': corr01_seq,
            })

            # Drop Python garbage and JAX's compilation caches before the
            # next configuration builds two new jitted programs.
            gc.collect()
            jax.clear_caches()

# Disabled result dump: the triple-quoted string below is a bare expression
# and is never executed. `filename` is empty and must be set before enabling.
"""
import pickle
filename = ""
with open(filename, 'wb') as file:
    pickle.dump(outputs, file)
"""
