import jax
jax.config.update("jax_default_matmul_precision", "highest")
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import pandas as pd
from functools import partial
import time, sys, os, pickle
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
import gc

# core DEER + QDEER algorithms
import src
from src.algs.deer import seq1d

# Simulation grid for this run. The first three CLI arguments are indices into
# the option tuples below (ordered largest-first); the fourth is the RNG seed.
B = (32, 16, 8, 4, 2, 1)[int(sys.argv[1])]  # batch size (number of chains)
L = (64_000, 32_000, 16_000, 8_000, 4_000, 2_000, 1_000)[int(sys.argv[2])]  # sequence length
sequential = (True, False)[int(sys.argv[3])]  # index 0 -> sequential scan, 1 -> Q-DEER
seed = int(sys.argv[4])

######

# Load the whitened German Credit data. Column 0 holds the labels y (modelled
# as Bernoulli below); the remaining columns are the D features X.
data = pd.read_csv("data/german-credit_d=24_whiten=True_camera-ready.csv").to_numpy()
y = jnp.asarray(data[:, 0])
X = jnp.asarray(data[:, 1:])
D = X.shape[1]

######

def accept(x):
    """Hard Metropolis accept/reject switch: True wherever ``x > 0``.

    A comparison provides no useful gradient, which is why ``sigmoid_accept``
    pairs it with a smooth surrogate for differentiation.
    """
    return x > 0


def sigmoid_accept(x):
    """Accept indicator whose value is ``accept(x)`` but whose gradient is sigmoid'(x).

    Straight-through estimator: the forward pass returns exactly the hard
    indicator ``accept(x)`` as a float (1.0 / 0.0), while the backward pass
    differentiates through ``jax.nn.sigmoid(x)``. Gradients therefore ARE
    tracked -- through the sigmoid surrogate, not through the step function.
    """
    soft = jax.nn.sigmoid(x)
    # soft - stop_gradient(soft) is exactly zero in value, but its derivative
    # w.r.t. x is d(sigmoid)/dx, so it injects the surrogate gradient.
    surrogate_zero = soft - jax.lax.stop_gradient(soft)
    # The hard indicator supplies the forward value and contributes no gradient.
    return surrogate_zero + jax.lax.stop_gradient(accept(x))

# Q-DEER iteration budget: 50 plus 5e-4 per time step (e.g. L=64_000 -> 82).
# Full per-iteration traces are deliberately not kept.
max_iter = int(50 + 5e-4 * L)
full_trace = False

# Timing protocol: untimed warm-up calls, timed repetitions, and the number of
# MALA burn-in steps applied to the initial states.
nwarmups = 3
nreps = 5
num_burnin = 5


#### A. SETTING UP OUR DISTRIBUTION, LOG-PROB, and GRAD-LOG-PROB.

def log_prob(X, y, beta):
    """Unnormalised log posterior for logistic regression.

    Adds an independent standard-normal log prior over the coefficients
    ``beta`` to a Bernoulli log likelihood of the labels ``y`` whose logits
    are ``X @ beta``.
    """
    log_prior = tfd.Normal(0., 1.).log_prob(beta).sum()
    log_likelihood = tfd.Bernoulli(logits=X @ beta).log_prob(y).sum()
    return log_prior + log_likelihood


# Target density over beta with the dataset bound in, plus a function that
# returns (log_prob, d log_prob / d beta) in a single pass.
logp = partial(log_prob, X, y)
target_log_prob_and_grad = jax.value_and_grad(logp)

# MALA step size, chosen to give roughly 80% acceptance. The latent state and
# the Gaussian part of each driver both have the feature dimension D.
D_latent = D
D_in = D
epsilon = 0.0015
params = dict(epsilon=epsilon)  # hyper-parameters threaded through the step function

#### B. DEFINING OUR Q-DEER FUNCTION

# Langevin dynamics with a Metropolis accept / reject correction (MALA).
def fxn_for_deer(state, driver, params):
    """Take one Metropolis-adjusted Langevin (MALA) transition.

    Args:
        state: current position (a vector accepted by ``target_log_prob_and_grad``).
        driver: pre-drawn randomness for this step. Every entry except the
            last is the Gaussian proposal noise; the last entry is the
            Unif(0, 1) draw used in the accept/reject test.
        params: dict holding the step size under ``"epsilon"``.

    Returns:
        ``g * proposal + (1 - g) * state``, where ``g = sigmoid_accept(...)``
        is 1.0 on accept and 0.0 on reject in the forward pass. The blend keeps
        the step differentiable through the surrogate gradient of ``g``.
    """
    step = params["epsilon"]
    noise_scale = jnp.sqrt(2.0 * step)

    def log_transition(x_to, x_from, grad_from):
        # log q(x_to | x_from) for the Langevin proposal N(x_from + step * grad, 2 * step * I).
        return tfd.MultivariateNormalDiag(
            loc=x_from + (step * grad_from),
            scale_diag=noise_scale * jnp.ones_like(state)).log_prob(x_to)

    # Target density and score at the current point.
    logp_here, grad_here = target_log_prob_and_grad(state)

    # Langevin proposal: gradient drift plus scaled Gaussian noise.
    proposal = state + (step * grad_here)
    proposal = proposal + noise_scale * driver[:-1]

    # Target density and score at the proposal.
    logp_there, grad_there = target_log_prob_and_grad(proposal)

    # Metropolis-Hastings log ratio pieces: forward (denominator) and reverse (numerator).
    log_forward = logp_here + log_transition(proposal, state, grad_here)
    log_backward = logp_there + log_transition(state, proposal, grad_there)

    # Accept iff log u < log_backward - log_forward.
    uniform = driver[-1]
    gate = sigmoid_accept(log_backward - log_forward - jnp.log(uniform))
    return gate * proposal + (1.0 - gate) * state


# Announce which configuration this process is about to run.
method = "Sequential" if sequential else "Q-DEER"
print(f"Starting {method} combination of B={B}, L={L}, seed={seed}")

#### C. GENERATING OUR EXTERNAL SOURCES OF NOISE + INITIAL CONDITIONS

# gen. init conditions + external noise sources: uniform acc/rej & MALA noisy-gradient step noise
key = jr.PRNGKey(seed)  # JAX randomization requires explicitly specifying the keys.

# To get a better initialization for both sequential and Q-DEER, start from a
# standard-normal draw and run a few MALA (burn-in) steps.
# 1. initial conditions s0 for running the sampler sequentially (also the DEER init).
#    Split before sampling: JAX's PRNG guidance is never to split or reuse a key
#    that has already been passed to a sampler, since the streams can correlate.
#    NOTE: this changes the random stream relative to runs made before this fix.
key, skey = jr.split(key)
initial_states_batched = jr.normal(skey, (B, D))
burnin_step = jax.vmap(fxn_for_deer, in_axes=(0, 0, None))  # one MALA step per chain
for _ in range(num_burnin):
    # fresh Gaussian proposal noise + Unif(0, 1) accept/reject draws, packed as one driver
    key, skey = jr.split(key)
    in_drivers = jr.normal(skey, (B, D_in))
    key, skey = jr.split(key)
    in_unif = jr.uniform(skey, (B, 1))
    in_driver = jnp.concatenate((in_drivers, in_unif), axis=-1)
    initial_states_batched = burnin_step(initial_states_batched, in_driver, params)

# 2. N(0, I) drivers for the MALA noisy-gradient proposals, shape (B, L, D_in).
key, skey = jr.split(key)
mala_gaussian_drivers = jr.normal(skey, (B, L, D_in))

# 3. Unif(0, 1) drivers for the MALA accept/reject tests, shape (B, L, 1).
key, skey = jr.split(key)
mala_uniform_drivers = jr.uniform(skey, (B, L, 1))

# 4. concatenate on the last axis: [..., :D_in] is Gaussian noise, [..., -1] is uniform.
drivers_batched = jnp.concatenate((mala_gaussian_drivers, mala_uniform_drivers), axis=-1)

#### Helpers shared by sections D and E.

def _ensure_dir(path):
    """Create the parent directory of ``path`` if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _time_batched(fn, *args):
    """Return the mean wall-clock seconds per call of ``fn(*args)``.

    Runs ``nwarmups`` untimed calls first (which also recompile after any
    ``jax.clear_caches()``), then averages over ``nreps`` timed calls. Each
    result is blocked on with ``jax.block_until_ready`` so asynchronous
    dispatch is included, and is discarded immediately so only one output is
    alive at a time.
    """
    for _ in range(nwarmups):
        jax.block_until_ready(fn(*args))
    t0 = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
    for _ in range(nreps):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - t0) / nreps


def _append_log_row(path, row):
    """Append ``row`` as one comma-separated line to the log file at ``path``."""
    with open(path, "a") as file:
        file.write(",".join(str(entry) for entry in row) + "\n")


def _record_convg_iters(path, run_key, iters):
    """Store ``iters`` under ``run_key`` in the pickled dict at ``path``.

    Starts from an empty dict when the file does not exist yet. The updated
    dict is written to a temporary file and moved into place with
    ``os.replace``, so an interrupted write cannot leave a truncated pickle.
    The file is produced by this script itself, so ``pickle.load`` here only
    reads trusted data.
    NOTE(review): concurrent jobs doing read-modify-write on the same file
    can still lose each other's updates; a per-run file or a lock would fix it.
    """
    try:
        with open(path, "rb") as file:
            convg_iters = pickle.load(file)
    except FileNotFoundError:
        convg_iters = {}
    convg_iters[run_key] = iters
    tmp_path = f"{path}.{os.getpid()}.tmp"
    with open(tmp_path, "wb") as file:
        pickle.dump(convg_iters, file)
    os.replace(tmp_path, path)


#### D. DO OUR SEQUENTIAL SAMPLING PROCEDURE!
if sequential:

    # 1. create compatible function for jax.lax.scan
    def fxn_for_scan(state, driver):
        state = fxn_for_deer(state, driver, params)
        return state, state

    # 2. function for getting sequential scan output (no batching yet!) + then batch it up.
    #    Scanning over drivers[1:] yields L-1 states, matching the (L-1)-step
    #    initial guess used by Q-DEER below.
    def unbatched_sequential(initial_state, drivers):
        _, out_states = jax.lax.scan(fxn_for_scan, initial_state, drivers[1:])
        return out_states
    batched_sequential = jax.jit(jax.vmap(unbatched_sequential))  # in_axes=(0, 0)
    outputs_sequential = jax.block_until_ready(
        batched_sequential(initial_states_batched, drivers_batched))

    # 3. save our sequential results
    samples_path = f"samples/SEQUENTIAL_B={B}_L={L}_seed={seed}.npz"
    _ensure_dir(samples_path)
    np.savez_compressed(samples_path, samples=outputs_sequential)
    del outputs_sequential
    gc.collect()
    jax.clear_caches()

    # 4. timing trials: warmup + real trials.
    sequential_time = _time_batched(batched_sequential, initial_states_batched, drivers_batched)

    # 5. status update
    print(f"- Finished sequentially generating, saving, and timing samples: {sequential_time} seconds.")
    gc.collect()
    jax.clear_caches()

    # 6. update our logs
    _append_log_row("logs_sequential.csv", [B, L, seed, "sequential", sequential_time])

#### E. DO OUR QUASI-DEER + HUTCHINSON'S PROCEDURE
else:

    # 1. update params with a key to support Hutchinson's
    key, skey = jr.split(key)
    params["key"] = skey

    # 2. initial guesses for s_1 ... s_T: every step starts at its chain's initial state.
    yinit_guess_batched = initial_states_batched[:, None, :] * jnp.ones((B, L - 1, D_latent))

    # 3. quasi-DEER solve of the same (L-1)-step recursion as the sequential scan.
    def unbatched_qdeer(initial_state, drivers, yinit_guess):
        return seq1d(
            fxn_for_deer, initial_state, drivers[1:],
            params, yinit_guess=yinit_guess, max_iter=max_iter,
            quasi=True, qmem_efficient=True, full_trace=full_trace, clip_val=1.0)
    batched_qdeer = jax.jit(jax.vmap(unbatched_qdeer))
    qdeer_args = (initial_states_batched, drivers_batched, yinit_guess_batched)

    # 4. get our outputs for Q-DEER only and save the samples.
    outputs_qdeer = jax.block_until_ready(batched_qdeer(*qdeer_args))
    samples_path = f"samples/QDEER_B={B}_L={L}_seed={seed}.npz"
    _ensure_dir(samples_path)
    np.savez_compressed(samples_path, samples=outputs_qdeer[0])

    # 5. update our iteration-counter logs
    _record_convg_iters("convg_iters.pickle", (B, L, seed), np.array(outputs_qdeer[1]))

    # 6. clean house.
    del outputs_qdeer
    gc.collect()
    jax.clear_caches()

    # 7. timing trials: warmup + real trials.
    qdeer_time = _time_batched(batched_qdeer, *qdeer_args)

    # 8. update our timing logs
    _append_log_row("logs_parallel.csv", [B, L, seed, "qdeer", qdeer_time])

    # 9. status update
    print(f"- Finished running Q-DEER + generating, saving, and timing samples: {qdeer_time} seconds.")
    gc.collect()
    jax.clear_caches()