import numpy as np
import numpyro
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC, init_to_mean
from jax import random, lax, vmap, jit
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import pyreadr
from sklearn.preprocessing import LabelEncoder
from numpyro.diagnostics import summary, print_summary
from distribution import MarginalizedMultivariateNormalGroupCoeff
import click
import os
from time import time
from jax.lax import scan
import jax
# Turn on JAX's 64-bit mode before any arrays are created below.
jax.config.update("jax_enable_x64", True)

def model(n_sub, n_item, n_obs, g1, g2, treatment, obs):
    """NumPyro model with subject-level random intercepts and slopes.

    Sample sites, in order: ``alpha``, ``beta``, ``sigma``, ``sigma_u``,
    ``tau_u``, ``sigma_w``, ``tau_w``, ``u`` and the observed site ``y``.

    Args:
        n_sub: Number of rows drawn for the subject effects ``u``.
        n_item: Number of rows of the (all-zero) item-effect placeholder.
        n_obs: Forwarded unchanged to the likelihood distribution.
        g1: Index array used to gather rows of ``u`` per observation.
        g2: Index array used to gather rows of the item placeholder; also
            forwarded to the likelihood distribution.
        treatment: Per-observation covariate multiplying the slope term.
        obs: Observed values attached to the ``y`` site.
    """
    # Population-level intercept, slope and residual scale.
    alpha = numpyro.sample('alpha', dist.Normal(0, 10))
    beta = numpyro.sample('beta', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.HalfNormal(50))

    # Cholesky correlation factors and per-dimension scales for both groups.
    sigma_u = numpyro.sample('sigma_u', dist.LKJCholesky(2))
    tau_u = numpyro.sample('tau_u', dist.HalfNormal(20), sample_shape=(2,))
    sigma_w = numpyro.sample('sigma_w', dist.LKJCholesky(2))
    tau_w = numpyro.sample('tau_w', dist.HalfNormal(20), sample_shape=(2,))

    # diag(tau) @ L: lower-triangular scale factors.
    s_u = jnp.diag(tau_u) @ sigma_u
    s_w = jnp.diag(tau_w) @ sigma_w

    u = numpyro.sample(
        'u',
        dist.MultivariateNormal(jnp.zeros((2,)), scale_tril=s_u),
        sample_shape=(n_sub,),
    )
    # Item effects are not sampled here; an all-zero placeholder keeps the
    # mean expression in its full form while contributing nothing to it.
    # NOTE(review): the likelihood class presumably integrates the item
    # effects out using s_w -- confirm against the `distribution` module.
    w = jnp.zeros((n_item, 2))

    u_obs = u[g1]
    w_obs = w[g2]
    intercept = alpha + u_obs[..., 0] + w_obs[..., 0]
    slope = beta + u_obs[..., 1] + w_obs[..., 1]
    mean = intercept + treatment * slope

    likelihood = MarginalizedMultivariateNormalGroupCoeff(
        mean, s_w, sigma, g2, treatment, n_item, n_obs)
    numpyro.sample('y', likelihood, obs=obs)

@click.command()
@click.option('--rng_key', default=0,)
@click.option('--warm_up_steps', default=10000, help = 'Number of warm up samples in HMC')
@click.option('--sample_steps', default=100000, help = 'Number of samples in HMC')
def main(rng_key, warm_up_steps, sample_steps):
    """Fit the model with NUTS, recover 'w', and report n_eff statistics.

    Reads ``data/rdata/df_eeg_complete.rda``, label-encodes the ``subj`` and
    ``item`` columns, runs NUTS on ``model``, draws ``w`` for every posterior
    sample via ``sample_x``, prints a summary, and writes the samples to
    ``result/eeg/<sample_steps>/M2/<rng_key>.npz`` together with a small text
    file ``result/eeg/<sample_steps>/M2/<rng_key>``.

    Args:
        rng_key: Integer seed for both NumPy and the JAX PRNG key.
        warm_up_steps: Number of warm-up iterations passed to ``MCMC``.
        sample_steps: Number of posterior samples passed to ``MCMC``.
    """
    result = pyreadr.read_r('data/rdata/df_eeg_complete.rda')
    data = result['df_eeg_complete']
    subj_encoder = LabelEncoder()
    data["subj"] = np.array(subj_encoder.fit_transform(data["subj"].values))
    item_encoder = LabelEncoder()
    data["item"] = np.array(item_encoder.fit_transform(data["item"].values))

    np.random.seed(rng_key)
    n_obs = len(data)
    n_sub = len(subj_encoder.classes_)
    n_item = len(item_encoder.classes_)
    g1 = data["subj"].values
    g2 = data["item"].values
    print(n_sub, n_item)

    obs = data["n400"].astype(float).values
    treatment = data["cloze"].astype(float).values

    # The timer covers both the MCMC run and the recovery of 'w' below.
    start_time = time()
    nuts_kernel = NUTS(model, init_strategy=init_to_mean)
    mcmc = MCMC(nuts_kernel, num_warmup=warm_up_steps, num_samples=sample_steps)
    sample_key, recover_key = random.split(random.PRNGKey(rng_key))
    mcmc.run(sample_key, n_sub, n_item, n_obs, g1, g2, treatment, obs)
    sample = mcmc.get_samples()

    site_names = ('alpha', 'beta', 'sigma', 'sigma_u', 'tau_u',
                  'sigma_w', 'tau_w', 'u')
    draws = tuple(sample[name] for name in site_names)

    def recover(_, tup):
        """Scan body: draw 'w' for a single posterior sample."""
        alpha, beta, sigma, sigma_u, tau_u, sigma_w, tau_w, u, key = tup
        w = jnp.zeros((n_item, 2))
        s_w = jnp.matmul(jnp.diag(tau_w), sigma_w)
        return None, MarginalizedMultivariateNormalGroupCoeff(
            alpha + u[g1][..., 0] + w[g2][..., 0] + treatment * (beta + u[g1][..., 1] + w[g2][..., 1]),
            s_w, sigma, g2, treatment, n_item, n_obs, w).sample_x(obs, key)

    # One PRNG key per posterior draw.
    keys = random.split(recover_key, len(sample['u']))
    _, sample['w'] = scan(jit(recover), None, draws + (keys,))
    end_time = time()
    all_time = end_time - start_time

    # Wrap each site in a list so every array gains a leading axis of size 1
    # before being handed to the diagnostics.
    to_eval = {name: np.array([val]) for name, val in sample.items()}
    # Named `stats` rather than `sum` so the builtin is not shadowed.
    stats = summary(to_eval, prob = 0.9)
    print_summary(to_eval, prob = 0.9)

    def n_eff_stats(name):
        """Return (mean, min) of the non-NaN n_eff entries for one site."""
        n_eff = np.atleast_1d(stats[name]['n_eff'])
        n_eff = n_eff[~np.isnan(n_eff)]
        if n_eff.size == 0:
            # np.min would raise on an empty array; report NaN instead.
            return float('nan'), float('nan')
        return np.mean(n_eff), np.min(n_eff)

    # Iterate the sorted names; the original called sorted() and threw the
    # result away, so the report was never actually ordered.
    for key in sorted(stats):
        mean_n_eff, min_n_eff = n_eff_stats(key)
        print(key, mean_n_eff, min_n_eff, mean_n_eff/all_time,
              min_n_eff/all_time)

    # Compute the divergence count once so the console report and the result
    # file agree, and the file write cannot fail on a missing field.
    extra_fields = mcmc.get_extra_fields()
    if "diverging" in extra_fields:
        n_divergences = int(jnp.sum(extra_fields["diverging"]))
        print(
            "Number of divergences: {}".format(n_divergences)
        )
    else:
        n_divergences = float('nan')

    PATH = f'result/eeg/{sample_steps}/M2'
    os.makedirs(PATH, exist_ok=True)
    output_file = f'{PATH}/{rng_key}'
    # The dict is stored under the single key 'sample' (as an object array).
    np.savez_compressed(output_file+'.npz', sample=sample)

    table = ['alpha','beta','sigma','sigma_u','tau_u','sigma_w','tau_w','u','w']
    with open(output_file,'w') as f:
        # First line: wall time and divergence count; then one mean n_eff per
        # site in `table`, space-separated.
        print(all_time, n_divergences, file=f)
        for key in table:
            print(n_eff_stats(key)[0], end=' ',file=f)


# Run the click command only when this file is executed as a script.
if __name__ == '__main__':
    main()