import jax
jax.config.update("jax_default_matmul_precision", "highest")
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import pandas as pd
from functools import partial
import time, sys, os, pickle
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
import gc

# core DEER + QDEER algorithms
import src
from src.algs.deer import seq1d

# Simulation settings for this run. Each CLI argument is an index into a grid
# listed from largest to smallest, so index 0 selects the largest value.
B = [16, 8, 4, 2, 1][int(sys.argv[1])]  # number of chains, run as one batch
L = [64_000, 32_000, 16_000, 8_000, 4_000, 2_000, 1_000][int(sys.argv[2])]  # MALA steps per chain
seed = int(sys.argv[3])  # PRNG seed for this trial

######

# Load the whitened German Credit data. Column 0 holds the labels y, the
# remaining columns hold the features X, and D is the feature dimension.
data = pd.read_csv("data/german-credit_d=24_whiten=True_camera-ready.csv").to_numpy()
X = jnp.asarray(data[:, 1:])
y = jnp.asarray(data[:, 0])
D = X.shape[1]

######

def accept(x):
    """Hard Metropolis accept/reject switch: True exactly where ``x > 0``.

    This is a plain comparison, so it provides no useful gradient.
    """
    return x > 0


def sigmoid_accept(x):
    """Differentiable stand-in for ``accept``.

    Forward pass: the value of ``accept(x)`` promoted to float (1.0 or 0.0).
    Backward pass: the gradient of ``jax.nn.sigmoid(x)``. The hard switch is
    wrapped in ``stop_gradient``, so it contributes no gradient.
    """
    soft = jax.nn.sigmoid(x)
    # soft - stop_gradient(soft) is 0.0 on the forward pass (for non-NaN x),
    # but it carries sigmoid's derivative on the backward pass.
    surrogate = soft - jax.lax.stop_gradient(soft)
    return surrogate + jax.lax.stop_gradient(accept(x))

# GLOBAL: quasi-DEER iteration budget. It grows linearly with the chain length L
# (50 iterations plus one more per 2,000 steps). Full iteration traces are NOT kept.
max_iter = int(50 + 5e-4 * L)
full_trace = False

# Timing protocol: untimed warm-up calls, timed repetitions to average over,
# and the number of MALA burn-in steps applied to the initial states.
nwarmups = 3
nreps = 5
num_burnin = 5


#### A. SETTING UP OUR DISTRIBUTION, LOG-PROB, and GRAD-LOG-PROB.

# log-prob, also gradient of log-prob.
def log_prob(X, y, beta):
    """Unnormalized log-posterior of Bayesian logistic regression.

    The prior is standard normal on every coefficient. The likelihood is
    Bernoulli with logits ``X @ beta``.

    Args:
        X: design matrix; must be matmul-compatible with ``beta``.
        y: observed labels scored by ``Bernoulli.log_prob``.
        beta: regression coefficients.

    Returns:
        Scalar: summed log prior plus summed log likelihood.
    """
    log_prior = tfd.Normal(0., 1.).log_prob(beta).sum()
    log_likelihood = tfd.Bernoulli(logits=X @ beta).log_prob(y).sum()
    return log_prior + log_likelihood
# Bind the data so the target depends on beta alone, then build one function
# that returns (log-prob, grad log-prob) for use inside the MALA step.
logp = partial(log_prob, X, y)
target_log_prob_and_grad = jax.value_and_grad(logp)

# The latent state and the per-step Gaussian noise both have the feature
# dimension D. The MALA step size is stored in a params dict; per the original
# note, it was chosen to give roughly 80% acceptance.
D_latent = D
D_in = D
epsilon = 0.0015
params = {"epsilon": epsilon}

#### B. DEFINING THE MALA STEP FUNCTION PASSED TO Q-DEER (seq1d)

# Langevin dynamics with accept / reject
def fxn_for_deer(state, driver, params):
    """Take one MALA (Metropolis-adjusted Langevin) step from ``state``.

    All randomness is supplied through ``driver``, so the step is a
    deterministic function of (state, driver, params).

    Args:
        state: current position of the chain (presumably a length-D vector;
            TODO confirm).
        driver: external noise for this step. ``driver[:-1]`` is the N(0, I)
            proposal noise and ``driver[-1]`` is the Unif(0, 1) draw for the
            accept/reject test.
        params: dict; only ``params["epsilon"]`` (the Langevin step size) is read.

    Returns:
        The proposal if accepted, otherwise ``state``. The choice is blended
        through ``sigmoid_accept``, whose forward value is 0 or 1 and whose
        gradient is a sigmoid surrogate.
    """
    step = params["epsilon"]
    noise, u = driver[:-1], driver[-1]
    prop_scale = jnp.sqrt(2.0 * step)

    # Log-density and gradient at the current position.
    lp_cur, grad_cur = target_log_prob_and_grad(state)

    # Langevin proposal: gradient drift, then Gaussian noise with variance 2 * step.
    mean_fwd = state + (step * grad_cur)
    proposal = mean_fwd + prop_scale * noise

    # Log-density and gradient at the proposal, needed for the reverse move.
    lp_prop, grad_prop = target_log_prob_and_grad(proposal)
    mean_rev = proposal + (step * grad_prop)

    def log_kernel(point, mean):
        # log N(point; mean, (2 * step) I): the Langevin proposal density.
        return tfd.MultivariateNormalDiag(
            loc=mean,
            scale_diag=prop_scale * jnp.ones_like(state)).log_prob(point)

    # Metropolis-Hastings terms: [log pi(x') + log q(x | x')] vs [log pi(x) + log q(x' | x)].
    log_reverse = lp_prop + log_kernel(state, mean_rev)
    log_forward = lp_cur + log_kernel(proposal, mean_fwd)

    # Accept iff log u < log_reverse - log_forward, where u is the uniform driver.
    gate = sigmoid_accept(log_reverse - log_forward - jnp.log(u))
    return gate * proposal + (1.0 - gate) * state


# Announce which (B, L, seed) combination this run covers.
print(f"Starting Q-DEER (Ablation) combination of B={B}, L={L}, seed={seed}")

#### C. GENERATING OUR EXTERNAL SOURCES OF NOISE + INITIAL CONDITIONS

# Initial conditions plus external noise sources: Unif(0, 1) accept/reject draws
# and N(0, I) MALA proposal noise. JAX randomness requires explicitly threading
# PRNG keys.
key = jr.PRNGKey(seed)

# 1. Initial conditions s0, used both by the sequential sampler and as the
#    (Q-)DEER initial state. They are drawn from N(0, I) and then improved with
#    `num_burnin` MALA steps, which gives both methods a better starting point.
# NOTE(review): `key` is consumed here and then split inside the burn-in loop.
# JAX's PRNG guidance is never to reuse a key, because draws from a reused key
# can be correlated with draws from keys split off it. This is left as-is so a
# given seed keeps producing the same random streams (sibling scripts may rely
# on them). TODO: confirm before switching to `key, skey = jr.split(key)` first.
initial_states_batched = jr.normal(key=key, shape=(B, D))
for _ in range(num_burnin):
    # Fresh per-step drivers: N(0, I) proposal noise plus one Unif(0, 1) draw
    # per chain, packed as [gaussian..., uniform] for fxn_for_deer.
    key, skey = jr.split(key)
    in_drivers = jr.normal(skey, (B, D_in))
    key, skey = jr.split(key)
    in_unif = jr.uniform(skey, (B, 1))
    in_driver = jnp.concatenate((in_drivers, in_unif), axis=-1)

    # One MALA step for every chain; params is shared across chains (in_axes None).
    initial_states_batched = jax.vmap(fxn_for_deer, in_axes=(0, 0, None))(
        initial_states_batched, in_driver, params)

# 2. MVN(0, I_D) drivers for the MALA noisy-gradient proposals.
key, skey = jr.split(key)
mala_gaussian_drivers = 1.0 * jr.normal(skey, (B, L, D_in))

# 3. Unif(0, 1) drivers for the MALA accept/reject decisions.
key, skey = jr.split(key)
mala_uniform_drivers = jr.uniform(skey, (B, L, 1))

# 4. Pack drivers as [gaussian..., uniform] along the last axis: shape (B, L, D_in + 1).
drivers_batched = jnp.concatenate((mala_gaussian_drivers, mala_uniform_drivers), axis=-1)

#### D. DO OUR QUASI-DEER + ABLATION PROCEDURE (MEMORY EFFICIENT = FALSE).

# Attach a PRNG key to params (per the original note, to support Hutchinson's
# estimator). fxn_for_deer itself reads only params["epsilon"]; the key is
# presumably consumed inside seq1d (TODO confirm).
key, skey = jr.split(key)
params["key"] = skey

# Initial guesses for s_1 ... s_{L-1}: each chain's initial state repeated along time.
yinit_guess_batched = jnp.broadcast_to(
    initial_states_batched[:, None, :], (B, L - 1, D_latent))

# Quasi-DEER for a single chain; it is vmapped over chains and jitted right after.
def unbatched_qdeer(initial_state, drivers, yinit_guess):
    """Run quasi-DEER (memory-efficient mode off) on one chain of MALA steps.

    Args:
        initial_state: starting state s0 of this chain.
        drivers: per-step noise rows for this chain. The first row is dropped
            before calling seq1d, so with L rows the remaining L-1 rows line up
            with the L-1 rows of ``yinit_guess``.
        yinit_guess: initial guess for the states after s0.

    Returns:
        Whatever ``seq1d`` returns for these settings (defined in src.algs.deer,
        not visible here).
    """
    step_drivers = drivers[1:]
    return seq1d(
        fxn_for_deer,
        initial_state,
        step_drivers,
        params,
        yinit_guess=yinit_guess,
        max_iter=max_iter,
        quasi=True,
        qmem_efficient=False,
        full_trace=full_trace,
        clip_val=1.0,
    )
# Batch over chains with vmap, then JIT-compile. The first call triggers compilation.
batched_qdeer = jax.jit(jax.vmap(unbatched_qdeer))

# Timing trials. Untimed warm-up calls absorb JIT compilation. block_until_ready
# waits for JAX's asynchronous dispatch to finish before the clock is read.
for _ in range(nwarmups):
    outputs_qdeer = batched_qdeer(initial_states_batched, drivers_batched, yinit_guess_batched)
    outputs_qdeer = jax.block_until_ready(outputs_qdeer)
# perf_counter is monotonic and high-resolution. time.time() is not suited to
# benchmarking because it can jump if the system clock is adjusted.
t0 = time.perf_counter()
for _ in range(nreps):
    outputs_qdeer = batched_qdeer(initial_states_batched, drivers_batched, yinit_guess_batched)
    outputs_qdeer = jax.block_until_ready(outputs_qdeer)
t1 = time.perf_counter()
qdeer_time = (t1 - t0) / nreps  # mean seconds per batched quasi-DEER call

# Append one CSV row (B, L, seed, method, mean seconds) to the timing log.
row = [B, L, seed, "qdeer_ablation", qdeer_time]
with open("logs_parallel_ablation.csv", "a", encoding="utf-8") as file:
    file.write(",".join(map(str, row)) + "\n")

# Status update, then free the outputs and JAX caches. The samples themselves
# are not saved by this script; only the timing row above is persisted.
print(f"- Finished running Q-DEER (Ablation) + generating, saving, and timing samples: {qdeer_time} seconds.")
del outputs_qdeer
gc.collect()
jax.clear_caches()