import jax
jax.config.update('jax_default_matmul_precision', 'highest')

import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp

from functools import partial

import argparse
import time
import os
import numpy as np

from src.elk.algs.deer import seq1d
from src.elk.algs.elk import elk_alg, quasi_elk

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def run_gibbs(B, L, in_key):
    """Time sequential, quasi-DEER and DEER evaluation of a Gibbs sampler.

    The sampler targets a hierarchical normal model on a modified
    eight-schools data set. Each Gibbs sweep is written as a deterministic
    function of the current state and a PRNG key ("driver"), so the chain
    can be evaluated either with ``jax.lax.scan`` or with ``seq1d``.

    Args:
        B: Number of chains evaluated in parallel (leading ``vmap`` axis).
        L: Number of drivers generated per chain. The first driver is used
            for one sweep applied to the initial state; the timed
            evaluations run over the remaining ``L - 1`` drivers.
        in_key: PRNG key used to draw the initial state of every chain.

    Returns:
        dict with keys ``'B'``, ``'L'``, ``'in_key'``, ``'time_seq'``,
        ``'time_quasi'``, ``'time_deer'`` (mean seconds per call),
        ``'iters_qdeer'``, ``'iters_deer'`` (last element of the
        corresponding ``seq1d`` output) and ``'nreps'``.

    Note:
        The data set (seed 13) and the drivers (seed 1313) use fixed seeds,
        so only the initial states depend on ``in_key``.
    """
    print("Batch: ", B, "L: ", L)

    ## Sample Data (modified 8 schools)
    S = 8                   # number of schools
    N_s = 20                # number of students per school
    mu_0 = 0.0              # prior mean of the global effect
    kappa_0 = 0.1           # prior concentration in the NIX prior
    nu_0 = 0.1              # degrees of freedom for the prior on \tau^2
    tausq_0 = 100.0         # prior mean of the variance \tau^2
    alpha_0 = 0.1           # degrees of freedom for the prior on \sigma_s^2
    sigmasq_0 = 10.0        # scale of the prior on \sigma_s^2

    # Degrees of freedom used by both the initial draw and the Gibbs updates.
    alpha_N = alpha_0 + N_s
    nu_N = nu_0 + S + 1

    # Sample data
    data_key = jr.PRNGKey(13)
    x_bars = jnp.array([28., 8., -3., 7., -1., 1., 18., 12.])
    sigma_bars = jnp.array([15., 10., 16., 11., 9., 11., 10., 18.])
    xs = tfd.Normal(x_bars, jnp.sqrt(N_s) * sigma_bars).sample((N_s,), seed=data_key)

    # z-score the samples
    z_scores = (xs - xs.mean(axis=0)) / xs.std(axis=0)

    # Rescale so they have the desired variance
    xs = x_bars + jnp.sqrt(N_s) * sigma_bars * z_scores

    def sample_prior(key):
        """Draw one initial state, laid out as (thetas, sigmasq, mu, tausq)."""
        # The first of the five subkeys is intentionally unused.
        _, k_tausq, k_mu, k_thetas, k_sigmasq = jr.split(key, 5)
        tausq = tfd.InverseGamma(0.5 * nu_N, nu_N * tausq_0 * 0.5).sample(seed=k_tausq)
        mu = tfd.Normal(mu_0, jnp.sqrt(tausq / kappa_0)).sample(seed=k_mu)
        thetas = tfd.Normal(mu, jnp.sqrt(tausq)).sample((S,), seed=k_thetas)
        sigmasq = tfd.InverseGamma(0.5 * alpha_N, 0.5 * alpha_N * sigmasq_0).sample((S,), seed=k_sigmasq)
        return jnp.concatenate((thetas, sigmasq, jnp.array([mu]), jnp.array([tausq])))

    ## Initialize the Gibbs sampler: one independent draw per chain
    _, init_key = jr.split(in_key)
    initial_state = jax.vmap(sample_prior)(jr.split(init_key, (B,)))

    # define gibbs steps
    def gibbs_sample_thetas(tausq, mu, sigmasqs, xs, zs):
        """Update the per-school effects; ``zs`` are standard normal draws."""
        v_theta = 1. / ((N_s / sigmasqs) + (1 / tausq))
        theta_hat = v_theta * ((xs.sum(axis=0) / sigmasqs) + mu / tausq)
        # Reparameterized draw from Normal(theta_hat, sqrt(v_theta)).
        return theta_hat + jnp.sqrt(v_theta) * zs

    def gibbs_sample_sigmasq(alpha_0, sigmasq_0, thetas, xs, zs):
        """Update the per-school variances; ``zs`` are InverseGamma(0.5 * alpha_N, 1) draws."""
        sigmasq_N = 1. / alpha_N * (alpha_0 * sigmasq_0
                                    + jnp.sum((xs - thetas)**2, axis=0))
        # Rescaling zs stands in for a ScaledInvChiSq(alpha_N, sigmasq_N) draw.
        return 0.5 * alpha_N * sigmasq_N * zs

    def gibbs_sample_mu(mu_0, kappa_0, tausq, thetas, zs):
        """Update the global mean; ``zs`` are standard normal draws."""
        v_mu = tausq / (kappa_0 + S)
        mu_hat = (mu_0 * kappa_0 + thetas.sum()) / (kappa_0 + S)
        # Reparameterized draw from Normal(mu_hat, sqrt(v_mu)).
        return mu_hat + jnp.sqrt(v_mu) * zs

    def gibbs_sample_tausq(nu_0, tausq_0, mu_0, kappa_0, mu, thetas, zs):
        """Update the global variance; ``zs`` are InverseGamma(0.5 * nu_N, 1) draws."""
        tausq_N = 1. / nu_N * (nu_0 * tausq_0 + kappa_0 * (mu - mu_0)**2
                               + jnp.sum((thetas - mu)**2))
        # Rescaling zs stands in for a ScaledInvChiSq(nu_N, tausq_N) draw.
        return 0.5 * nu_N * tausq_N * zs

    def fxn_for_deer(state, driver, params):
        """Apply one full Gibbs sweep to ``state`` using the PRNG key ``driver``.

        ``params`` is accepted because it is forwarded by ``seq1d``; it is
        not read here.
        """
        thetas, sigmasq, mu, tausq = jnp.split(state, (S, 2*S, 2*S+1))

        # The first of the five subkeys is intentionally unused.
        _, k_thetas, k_sigmasq, k_mu, k_tausq = jr.split(driver, 5)
        zs_thetas = jr.normal(k_thetas, (S,))
        zs_sigmasq = tfd.InverseGamma(0.5 * alpha_N, 1).sample((S,), seed=k_sigmasq)
        zs_mu = jr.normal(k_mu, (1,))
        zs_tausq = tfd.InverseGamma(0.5 * nu_N, 1).sample((1,), seed=k_tausq)

        # Update order matters: each step conditions on the values just drawn.
        tausq = gibbs_sample_tausq(nu_0, tausq_0, mu_0, kappa_0, mu, thetas, zs_tausq)
        mu = gibbs_sample_mu(mu_0, kappa_0, tausq, thetas, zs_mu)
        thetas = gibbs_sample_thetas(tausq, mu, sigmasq, xs, zs_thetas)
        sigmasq = gibbs_sample_sigmasq(alpha_0, sigmasq_0, thetas, xs, zs_sigmasq)

        return jnp.concatenate((thetas, sigmasq, mu, tausq))

    # sample: one driver key per (chain, step), from a fixed seed
    driver_key, _ = jr.split(jr.PRNGKey(1313))
    drivers = jr.split(driver_key, (B, L))
    params = {'key':jr.PRNGKey(155)}

    # Advance every chain by one sweep using its first driver.
    initial_state = jax.vmap(fxn_for_deer, in_axes=(0,0,None))(initial_state, drivers[:,0], params)

    # Sequential baseline: scan the sweep over the remaining drivers.
    def fxn_for_scan(state, driver):
        state = fxn_for_deer(state, driver, params)
        return state, state

    @jax.jit
    def _run_model(initial_state, drivers):
        _, out_states = jax.lax.scan(fxn_for_scan, initial_state, drivers[1:])
        return out_states

    run_model = jax.jit(jax.vmap(_run_model))

    # Initial trajectory guess: the initial state repeated along the time axis.
    D_dim = initial_state.shape[-1]
    yinit_guess = initial_state[:, None, :] * jnp.ones((B, L-1, D_dim))
    max_deer_iter = 100
    max_qdeer_iter = 200

    def _make_solver(**seq1d_kwargs):
        """Build a jitted, batched ``seq1d`` solver with the given options."""
        def solve_chain(init_state, chain_drivers, chain_guess):
            return seq1d(
                fxn_for_deer, init_state, chain_drivers[1:], params,
                yinit_guess=chain_guess, **seq1d_kwargs)
        return jax.jit(jax.vmap(solve_chain))

    batch_deer = _make_solver(
        max_iter=max_deer_iter, quasi=False, qmem_efficient=False,
        full_trace=False, damp_factor=0.95)
    outputs_deer = batch_deer(initial_state, drivers, yinit_guess)

    # Per-coordinate scaling for (thetas, sigmasq, [mu, tausq]).
    preconditioner = jnp.concatenate(
        (10.0*jnp.ones((S,)), 5000*jnp.ones((S,)), 10*jnp.ones((2,))))
    batch_qdeer = _make_solver(
        max_iter=max_qdeer_iter, quasi=True, qmem_efficient=True,
        full_trace=False, preconditioner=preconditioner, clip_val=1.0)
    outputs_qdeer = batch_qdeer(initial_state, drivers, yinit_guess)

    print("qDEER iters: ", outputs_qdeer[-1])
    print("DEER iters: ", outputs_deer[-1])

    # Free the trajectories before the timed runs allocate new ones.
    del outputs_deer
    del outputs_qdeer

    # compare timing
    nwarmups=2
    nreps=3

    def _time_fn(fn, *args):
        """Return (mean seconds per call over ``nreps``, last result)."""
        # Warm-up phase: results are dropped as soon as they are ready.
        for _ in range(nwarmups):
            jax.block_until_ready(fn(*args))

        # Benchmark phase: perf_counter is monotonic and high resolution.
        t0 = time.perf_counter()
        for _ in range(nreps):
            result = fn(*args)
            jax.block_until_ready(result)
        t1 = time.perf_counter()
        return (t1 - t0) / nreps, result

    time_elapsed_seq, out_states = _time_fn(run_model, initial_state, drivers)
    print(f"sequential time: {time_elapsed_seq:.3e} s")
    del out_states

    time_elapsed_quasi, outputs_qdeer = _time_fn(
        batch_qdeer, initial_state, drivers, yinit_guess)
    iters_qdeer = outputs_qdeer[-1]
    print(f"quasi deer time: {time_elapsed_quasi:.3e} s")
    del outputs_qdeer

    time_elapsed_deer, outputs_deer = _time_fn(
        batch_deer, initial_state, drivers, yinit_guess)
    iters_deer = outputs_deer[-1]
    print(f"deer time: {time_elapsed_deer:.3e} s")
    del outputs_deer

    outputs = {'B': B, 'L': L, 'in_key':in_key,
        'time_seq':time_elapsed_seq,
        'time_quasi':time_elapsed_quasi,
        'time_deer':time_elapsed_deer,
        'iters_qdeer':iters_qdeer,
        'iters_deer':iters_deer,
        'nreps':nreps}

    return outputs


import gc

# Master key for the whole sweep; each run consumes one split of it.
key = jr.PRNGKey(1022)
all_outputs = []
num_runs = 5

# Each entry pairs a list of chain lengths with the chain counts to try at
# those lengths. The longer chains are run with fewer parallel chains.
sweeps = (
    ([8000, 16000, 32000, 64000, 128000], [1,2,4,8,16,32,64]),
    ([256000, 512000, 512000*2], [1,2,4,8,16,32]),
)

for num_samples, num_chains in sweeps:
    for L in num_samples:
        for B in num_chains:
            for i in range(num_runs):
                key, in_key = jr.split(key)
                all_outputs.append(run_gibbs(B, L, in_key))
                # Drop Python garbage and JAX caches between runs.
                gc.collect()
                jax.clear_caches()


