import jax
jax.config.update('jax_default_matmul_precision', 'highest')

import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp

from functools import partial

import argparse
import time
import os
import numpy as np

from src.elk.algs.deer import seq1d
from src.elk.algs.elk import elk_alg, quasi_elk

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import tensorflow_datasets as tfds

import gc
import arviz as az

def load_dataset():
    """
    Load the numeric German credit data set and standardise its features.

    Dataset loading copied from inference_gym package
    https://github.com/tensorflow/probability/blob/main/spinoffs/inference_gym/inference_gym/internal/data.py

    The TFDS ``german_credit_numeric`` train split is read in full, the labels
    are flipped (``1 - label``), and every feature column is shifted and
    scaled by the mean and standard deviation of the training rows.

    Returns:
        A ``(features, labels)`` tuple: the standardised training features and
        the matching labels cast to ``np.int32``.
    """
    train_fraction = 1.
    num_points = 1000
    num_train = int(num_points * train_fraction)

    splits = tfds.as_numpy(tfds.load('german_credit_numeric:1.*.*'))
    # Materialise the split once so features and labels come from one pass.
    records = list(splits['train'])
    features = np.stack([record['features'] for record in records], axis=0)
    # We're reversing the labels to match what's in the original dataset,
    # rather than the TFDS encoding.
    labels = np.stack([1 - record['label'] for record in records], axis=0)

    # Standardise with statistics of the training rows only.
    raw_train = np.asarray(features[:num_train])
    column_mean = raw_train.mean(0, keepdims=True)
    column_std = raw_train.std(0, keepdims=True)
    train_features = (raw_train - column_mean) / column_std

    train_labels = labels[:num_train].astype(np.int32)
    return train_features, train_labels

def whiten(X):
    """Return a centred, decorrelated (ZCA-whitened) copy of ``X``.

    Rows of ``X`` are observations and columns are features. The columns are
    centred, the sample covariance is eigendecomposed, and the centred data
    is multiplied by the symmetric inverse square root of that covariance.
    A small constant is added to the eigenvalues before the square root to
    avoid dividing by zero.
    """
    centered = X - np.mean(X, axis=0)
    evals, evecs = np.linalg.eigh(np.cov(centered, rowvar=False))

    jitter = 1e-10
    inv_sqrt = 1.0 / np.sqrt(evals + jitter)
    # Scaling the eigenvector columns is the same as right-multiplying by
    # diag(inv_sqrt); the resulting whitening matrix is symmetric.
    zca_matrix = (evecs * inv_sqrt) @ evecs.T
    return centered @ zca_matrix

def run_timing(X, y, B, L, key, num_leap, step_size, damp_factor=1.0):
    """Time sequential HMC against quasi-DEER on Bayesian logistic regression.

    The target is a logistic regression posterior with a standard normal
    prior on the weights. Each HMC transition (momentum draw, leapfrog
    integration, Metropolis correction) is a deterministic function of the
    previous state and a PRNG key, so a chain can be evaluated either with
    ``jax.lax.scan`` or with ``seq1d`` on the same keys.

    Args:
        X: Design matrix; ``X.shape[-1]`` is the number of weights.
        y: Binary labels scored under the Bernoulli likelihood.
        B: Number of independent chains (batch size).
        L: Number of driver keys per chain. The first key is skipped, so
            every chain performs ``L - 1`` transitions.
        key: JAX PRNG key used for initial states, burn-in and drivers.
        num_leap: Number of leapfrog steps per HMC transition.
        step_size: Leapfrog step size.
        damp_factor: Forwarded unchanged to ``seq1d``.

    Returns:
        dict with timings (``time_seq``, ``time_qdeer``), the solver
        diagnostics (``iters_deer``, ``abs_error_deer``), the run settings
        (``damp_factor``, ``batch_size``, ``step_size``, ``num_samples``,
        ``max_iter``), the empirical ``accept_ratio`` and the effective
        sample sizes ``ess_seq`` / ``ess_par``.
    """

    def joint_log_prob(x, y, beta):
        # Standard normal prior on the weights plus Bernoulli likelihood.
        lp = tfd.Normal(0., 1.).log_prob(beta).sum()
        logits = x @ beta
        lp += tfd.Bernoulli(logits).log_prob(y).sum()
        return lp

    # The weights are unconstrained, so no change of variables is needed.
    logp = partial(joint_log_prob, X, y)
    D = X.shape[-1]

    target_log_prob_and_grad = jax.value_and_grad(logp)

    def accept(x):
        return x > 0  # non-differentiable

    def sigmoid_accept(x):
        # This function is a step function in forward pass and sigmoid in backwards pass.
        zero = jax.nn.sigmoid(x) - jax.lax.stop_gradient(jax.nn.sigmoid(x))
        return zero + jax.lax.stop_gradient(accept(x))

    def scan_leapfrog(state, step_size):
        # Assumes you start and end
        # with half-step corrections to momentum
        # Add half step before running iteration
        # Subtract half step after running iteration
        z, m = jnp.split(state, 2)
        z += step_size * m
        _, tlp_grad = target_log_prob_and_grad(z)
        m += step_size * tlp_grad
        return jnp.concatenate((z, m))

    def fxn_for_deer(state, driver, params):
        """One HMC transition: ``state`` is the position, ``driver`` a PRNG key."""
        z = state
        step_size = params['epsilon']

        m_seed, mh_seed = jax.random.split(driver)
        tlp, tlp_grad = target_log_prob_and_grad(z)
        m = jax.random.normal(m_seed, z.shape)
        energy = 0.5 * jnp.square(m).sum() - tlp
        # start with half-step of momentum
        m += 0.5 * step_size * tlp_grad
        init_state = jnp.concatenate((z, m))
        new_state = jax.lax.fori_loop(
            0, params['num_leapfrog_steps'],
            lambda i, state: scan_leapfrog(state, step_size),
            init_state)
        new_z, new_m = jnp.split(new_state, 2)
        new_tlp, new_tlp_grad = target_log_prob_and_grad(new_z)
        # end with backward half-step of momentum
        new_m -= 0.5 * step_size * new_tlp_grad
        new_energy = 0.5 * jnp.square(new_m).sum() - new_tlp
        log_accept_ratio = energy - new_energy

        # accept-reject: g is exactly 0 or 1 in the forward pass, so a
        # rejected proposal returns the previous position unchanged.
        u = jax.random.uniform(mh_seed, [])
        g = sigmoid_accept(log_accept_ratio - jnp.log(u))
        return g * new_z + (1.0 - g) * z

    # Derive a separate key for every purpose. A JAX key must not be consumed
    # twice: previously the same key produced the drivers and every burn-in
    # step, so all burn-in iterations drew identical randomness.
    key, skey = jr.split(key)
    initial_state = 0. + 1.0 * jr.normal(skey, (B, D))
    driver_key, warmup_key = jr.split(key)
    drivers = jr.split(driver_key, (B, L))

    epsilon = step_size
    max_deer_iter = 500
    params = {"epsilon": epsilon, "num_leapfrog_steps": num_leap}

    # run sequential
    def fxn_for_scan(state, driver):
        state = fxn_for_deer(state, driver, params)
        return state, state

    # Burn in sequentially with fresh randomness for every step.
    num_burn_in = 5
    burn_in_step = jax.vmap(fxn_for_deer, in_axes=(0, 0, None))
    for burn_key in jr.split(warmup_key, num_burn_in):
        in_keys = jr.split(burn_key, (B,))
        initial_state = burn_in_step(initial_state, in_keys, params)

    def _run_model(initial_state, drivers):
        # The first driver is skipped, giving L - 1 transitions per chain.
        _, out_states = jax.lax.scan(fxn_for_scan, initial_state, drivers[1:])
        return out_states
    run_model = jax.jit(jax.vmap(_run_model))
    out_states = run_model(initial_state, drivers)

    # A transition was rejected iff the position did not change. Prepending
    # the start state makes the number of comparisons equal to the L - 1
    # transitions actually performed per chain (the previous code compared
    # only L - 2 pairs but normalised by L - 1).
    first_coord = np.concatenate(
        [np.asarray(initial_state)[:, None, 0],
         np.asarray(out_states)[:, :, 0]],
        axis=1)
    rejected = first_coord[:, 1:] == first_coord[:, :-1]
    accept_ratio = 1.0 - np.mean(rejected)
    print("Accept ratio: ", accept_ratio)

    # NOTE(review): fxn_for_deer never reads params["key"]; presumably seq1d
    # consumes it -- confirm against src.elk.algs.deer.
    params["key"] = jr.PRNGKey(13)
    # Initial guess for the whole trajectory: the start state at every step.
    yinit_guess = initial_state[:, None, :] * jnp.ones((B, L - 1, D))

    def _run_qdeer(initial_state, drivers, yinit_guess):
        return seq1d(
            fxn_for_deer, initial_state, drivers[1:], params,
            yinit_guess=yinit_guess, max_iter=max_deer_iter, quasi=True,
            qmem_efficient=True, full_trace=False, clip_val=1.0,
            damp_factor=damp_factor)
    run_qdeer = jax.jit(jax.vmap(_run_qdeer))

    # Correctness check: element 0 of the seq1d output is compared with the
    # sequential trajectory. NOTE(review): the iteration count is read via
    # index -1 here and index 1 below -- confirm both address the same entry.
    outputs_qdeer = run_qdeer(initial_state, drivers, yinit_guess)
    print("qDEER Iters: ", outputs_qdeer[-1])
    abs_error = jnp.max(jnp.abs(outputs_qdeer[0] - out_states), axis=(1, 2))
    print("qDEER Errors: ", abs_error)
    del outputs_qdeer, out_states

    nwarmups = 2
    nreps = 3

    def _benchmark(fn, *args):
        """Return (mean wall time per call, last result) of ``fn(*args)``."""
        # Warm-up phase: keep compilation out of the measurement.
        for _ in range(nwarmups):
            jax.block_until_ready(fn(*args))
        # Benchmark phase: perf_counter is monotonic, unlike time.time().
        start = time.perf_counter()
        for _ in range(nreps):
            result = jax.block_until_ready(fn(*args))
        return (time.perf_counter() - start) / nreps, result

    time_elapsed_seq, out_states = _benchmark(run_model, initial_state, drivers)
    print(f"sequential time: {time_elapsed_seq:.3e} s")
    ess_seq = az.ess(
        az.convert_to_inference_data(np.array(out_states))).x.values
    del out_states
    gc.collect()

    time_elapsed, outputs_deer = _benchmark(
        run_qdeer, initial_state, drivers, yinit_guess)
    print(f"quasi deer time: {time_elapsed:.3e} s")
    ess_par = az.ess(
        az.convert_to_inference_data(np.array(outputs_deer[0]))).x.values
    print("Mean ESS/s Par: ", np.mean(ess_par) / time_elapsed)

    return {
        'time_seq': time_elapsed_seq,
        'time_qdeer': time_elapsed,
        'iters_deer': outputs_deer[1],
        'abs_error_deer': abs_error,
        'damp_factor': damp_factor,
        'batch_size': B,
        'step_size': epsilon,
        'num_samples': L,
        'accept_ratio': accept_ratio,
        'max_iter': max_deer_iter,
        'ess_seq': ess_seq,
        'ess_par': ess_par,
    }

# Build the design matrix: whitened features plus a constant intercept column.
features, labels = load_dataset()
whitened = jnp.asarray(whiten(features))
intercept = jnp.ones((whitened.shape[0], 1))
X = jnp.hstack((whitened, intercept))
y = jnp.asarray(labels)

# Benchmark grid and repetition count.
batch_sizes = [4]
all_num_samples = [1000]
num_runs = 10
key = jr.PRNGKey(100822)
all_outputs = []

for num_samples in all_num_samples:
    for batch_size in batch_sizes:
        for run_idx in range(num_runs):
            # Every repetition gets its own subkey.
            key, skey = jr.split(key)
            print("Running L: " , num_samples, ", B: ", batch_size, ", Run: ", run_idx)
            all_outputs.append(
                run_timing(
                    X, y,
                    B=batch_size,
                    L=num_samples,
                    key=skey,
                    num_leap=4,
                    step_size=0.04,
                    damp_factor=0.4,
                )
            )
            # Release Python garbage and compiled programs between runs.
            gc.collect()
            jax.clear_caches()

"""
import pickle
# Save the dictionary to a file
filename = ""
with open(filename, 'wb') as file:
    pickle.dump(all_outputs, file)
"""
