import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, value_and_grad, vmap
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu
from jax.lax import scan
from jax.lib import xla_bridge
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm
import haiku as hk
from utils.bivariate_copula import norm_copula_logdistribution_logdensity
from utils.sampling import resample_is
import wandb
from .copula_ar_base import *
from .copula_ar_optim import slice_lengths
from jax.random import PRNGKey, permutation, split

# region Functions for sampling from p(y)
# Update p(y) for a single test point and observed datum

# ``update_copula`` batched over the permutation ensembles. In both maps the
# fifth argument (``logalpha``) is shared, and every other argument is mapped
# over its leading axis.
update_copula_nperm_map = jit(
    vmap(update_copula, in_axes=(0, 0, 0, 0, None, 0))
)  # inner map: one call per n-permutation (per the naming)
update_copula_dperm_map = jit(
    vmap(update_copula_nperm_map, in_axes=(0, 0, 0, 0, None, 0))
)  # outer map: one inner map per d-permutation (per the naming)


def compute_rho_d_perm(y_test, y, rho, lengths, d_perm_ind):
    """Call ``compute_rho_d`` under one ordering of the coordinates.

    The last axis of both ``y_test`` and ``y`` is reordered with ``d_perm_ind``.
    The reordered arrays are forwarded, together with ``rho`` and ``lengths``,
    to ``compute_rho_d``, whose final ``helper`` argument is always ``None``
    here.
    """
    y_test_reordered = y_test[..., d_perm_ind]
    y_reordered = y[..., d_perm_ind]
    return compute_rho_d(y_test_reordered, y_reordered, rho, lengths, None)


# Inner map: batch over axis 0 of ``y`` only (the n-permutation axis, per the
# naming).
compute_rho_d_nperm_map = jit(
    vmap(compute_rho_d_perm, in_axes=(None, 0, None, None, None))
)
# Outer map: additionally batch ``rho``, ``lengths`` and ``d_perm_ind`` over
# their leading axis (the d-permutation axis, per the naming). ``y_test`` and
# ``y`` are passed through unmapped at this level.
compute_rho_d_dperm_map = jit(
    vmap(compute_rho_d_nperm_map, in_axes=(None, None, 0, 0, 0))
)


@jit
def sample_p_single_loop_ytest_permuting(carry, i):
    """Run one ``lax.scan`` step of importance resampling with a permutation ensemble.

    The step:

    1. Updates the log CDF conditionals and log joint densities of every test
       point, for every (d-permutation, n-permutation) pair, with observation
       ``i`` (``update_copula_dperm_map``).
    2. Reweights the test points by the change in the permutation-averaged joint
       log density.
    3. Resamples the test points with ``resample_is``.

    Args:
        carry: Either the 12-tuple ``(vn, logcdf_conditionals_ytest,
            logpdf_joints_ytest, rho, lengths, y, y_test, logweights,
            d_perm_inds, n_perm_inds, key, n_resampl)`` or that tuple followed
            by ``(ess, helper)``. The 14-element form is what the module's scan
            drivers build. When it is used, ``ess[i]`` receives this step's ESS
            and ``helper`` is passed through unchanged.
        i: Index of the observation processed in this step. It also selects
            ``key[i]``.

    Returns:
        ``(new_carry, i)``. ``new_carry`` has the same layout (12 or 14
        elements) as ``carry``, as ``lax.scan`` requires.
    """
    (
        vn,
        logcdf_conditionals_ytest,
        logpdf_joints_ytest,
        rho,
        lengths,
        y,
        y_test,
        logweights,
        d_perm_inds,
        n_perm_inds,
        key,
        n_resampl,
    ) = carry[:12]
    # Optional trailing ``(ess, helper)``. A tuple carry has a static length
    # under jit, so this branch is resolved at trace time.
    extras = carry[12:]

    n_perm, d_perm = len(n_perm_inds), len(d_perm_inds)
    logpdf_joints_ytest_old = logpdf_joints_ytest

    # Compute alpha and bandwidths
    logalpha = jnp.log(2.0 - (1 / (i + 1))) - jnp.log(i + 2)
    rho_d = compute_rho_d_dperm_map(
        y_test, y[n_perm_inds[:, i]], rho, lengths, d_perm_inds,
    )

    # Compute u, v for updating p(y). v is precomputed and read from vn.
    u = jnp.exp(logcdf_conditionals_ytest)
    v = vn[:, :, i]

    # Update p(y) and P(y) under every permutation pair.
    logcdf_conditionals_ytest, logpdf_joints_ytest = update_copula_dperm_map(
        logcdf_conditionals_ytest, logpdf_joints_ytest, u, v, logalpha, rho_d
    )

    def _log_mean_over_perms(logpdf):
        # Log of the mean density: over axis 0 (d-permutations), then over the
        # next leading axis (n-permutations).
        ave_d = logsumexp(logpdf, axis=0) - jnp.log(d_perm)
        return logsumexp(ave_d, axis=0) - jnp.log(n_perm)

    logpdf_joints_ave = _log_mean_over_perms(logpdf_joints_ytest)
    logpdf_joints_ave_old = _log_mean_over_perms(logpdf_joints_ytest_old)

    # Reweight with the permutation-averaged density. All permutations share one
    # resampling of the test points (``ind_new`` below indexes the test axis of
    # every permutation), so there is one weight per test point. Using the
    # unaveraged densities broadcast ``logweights`` to the ensemble shape and
    # broke the scan carry type. (Assumes the two leading axes of
    # logpdf_joints_ytest are the permutation axes, as the ``[:, :, ind_new]``
    # indexing below implies.)
    logweights += logpdf_joints_ave[..., -1] - logpdf_joints_ave_old[..., -1]
    logweights, ind_new, ESS = resample_is(logweights, key[i])

    if extras:
        ess, helper = extras
        extras = (ess.at[i].set(ESS), helper)

    carry = (
        vn,
        logcdf_conditionals_ytest[:, :, ind_new],
        logpdf_joints_ytest[:, :, ind_new],
        rho,
        lengths,
        y,
        y_test[ind_new],
        logweights,
        d_perm_inds,
        n_perm_inds,
        key,
        n_resampl + 1,
    ) + extras

    return carry, i


@jit
def sample_p_single_loop_ytest(carry, i):
    """Run one ``lax.scan`` step of importance resampling for a single permutation.

    Observation ``y[n_perm_inds[i]]`` is reordered by ``d_perm_inds`` and used to
    update the test points' log CDF conditionals and log joint densities via
    ``update_copula``. Each test point's log weight then grows by the change in
    the last column of its log joint density. The points are resampled with
    ``resample_is`` using ``key[i]``, and the returned ESS is stored at
    ``ess[i]``.

    Args:
        carry: 14-tuple ``(vn, logcdf_conditionals_ytest, logpdf_joints_ytest,
            rho, lengths, y, y_test, logweights, d_perm_inds, n_perm_inds, key,
            n_resampl, ess, helper)``.
        i: Index of the observation processed in this step.

    Returns:
        ``(new_carry, i)``, with ``new_carry`` in the same 14-element layout.
    """
    (
        v_all,
        log_cdf,
        log_pdf,
        rho,
        lengths,
        y,
        y_test,
        log_w,
        d_perm_inds,
        n_perm_inds,
        keys,
        n_resampl,
        ess,
        helper,
    ) = carry
    log_pdf_prev = log_pdf

    # Step size and bandwidths for observation i under this permutation.
    log_alpha = jnp.log(2.0 - (1 / (i + 1))) - jnp.log(i + 2)
    obs_i = y[n_perm_inds[i]][d_perm_inds]
    rho_d = compute_rho_d(y_test, obs_i, rho, lengths, helper)

    # u comes from the current conditionals; v is precomputed in v_all.
    log_cdf, log_pdf = update_copula(
        log_cdf, log_pdf, jnp.exp(log_cdf), v_all[i], log_alpha, rho_d
    )

    log_w = log_w + (log_pdf[..., -1] - log_pdf_prev[..., -1])

    log_w, idx, step_ess = resample_is(log_w, keys[i])
    ess = ess.at[i].set(step_ess)

    next_carry = (
        v_all,
        log_cdf[idx],
        log_pdf[idx],
        rho,
        lengths,
        y,
        y_test[idx],
        log_w,
        d_perm_inds,
        n_perm_inds,
        keys,
        n_resampl + 1,
        ess,
        helper,
    )
    return next_carry, i


# Scan the permutation-ensemble step over the observation indices in ``rng``.
@jit
def sample_p_loop_ytest_scan(carry, rng):
    """Fold ``sample_p_single_loop_ytest_permuting`` over ``rng`` with ``lax.scan``.

    Returns the final carry and the stacked per-step outputs, which are the
    step indices.
    """
    step_fn = sample_p_single_loop_ytest_permuting
    final_carry, stacked_steps = scan(step_fn, carry, rng)
    return final_carry, stacked_steps


@jit
def sample_p_loop_perm_loop_ytest(
    vn_perm, rho, lengths, y_perm, y_test, d_perm_inds, n_perm_inds, seed, helper,
):
    """Sample test points with the permutation-ensemble importance-resampling scan.

    Args:
        vn_perm: Precomputed ``v`` values. The number of scan steps ``n`` is
            read from ``vn_perm.shape[-2]``.
        rho, lengths: Bandwidth parameters forwarded to the step function.
        y_perm: Observed data, indexed via ``n_perm_inds`` in each step.
        y_test: Initial test points to be reweighted and resampled.
        d_perm_inds, n_perm_inds: Dimension and data permutation indices.
        seed: Integer seed for the per-step resampling keys.
        helper: Passed through the scan carry unchanged.

    Returns:
        ``(y_test, logpdf_joints_ytest, n_resampl, ess)``:
        - the resampled test points;
        - their final log joint densities;
        - the resampling counter (incremented once per step);
        - the per-step ESS values (length ``n``).
    """
    n = jnp.shape(vn_perm)[-2]
    n_perm = len(n_perm_inds)

    # Exactly one PRNG key per scan step (the first split key is discarded).
    # Splitting into len(y_perm) left one key too few, so ``key[n - 1]`` was out
    # of range and JAX silently clamped it to the previous step's key.
    key_seq = split(PRNGKey(seed), n + 1)[1:]

    # Initialize p0, replicated across the n-permutations.
    # NOTE(review): relies on init_marginals_perm_ returning something that
    # tile + 2-way unpacking splits into (logcdf, logpdf) — confirm shapes.
    logcdf_conditionals_ytest, logpdf_joints_ytest = jnp.tile(
        init_marginals_perm_(y_test[:, d_perm_inds])[None], (n_perm, 1, 1)
    )
    logweights = jnp.zeros(len(y_test))  # log of uniform unnormalised weights

    # Carry out scan over y_{1:n}
    carry = (
        vn_perm,
        logcdf_conditionals_ytest,
        logpdf_joints_ytest,
        rho,
        lengths,
        y_perm,
        y_test,
        logweights,
        d_perm_inds,
        n_perm_inds,
        key_seq,
        0,
        # One ESS entry per scan step (written at index i). Sizing this by
        # len(y_test) silently dropped or padded entries whenever n != n_test.
        jnp.zeros(n),
        helper,
    )
    carry, _ = sample_p_loop_ytest_scan(carry, jnp.arange(n))
    (
        _,
        _,
        logpdf_joints_ytest,
        _,
        _,
        _,
        y_test,
        _,
        _,
        _,
        _,
        n_resampl,
        ess,
        _,
    ) = carry

    return y_test, logpdf_joints_ytest, n_resampl, ess


# * we have to average over test observations way earlier
# Batched wrappers around ``sample_p_loop_perm_loop_ytest``. The inner map
# batches vn_perm, y_test and n_perm_inds over axis 0. The outer map batches
# vn_perm, rho, lengths, y_test and d_perm_inds over axis 0.
sample_p_loop_perm_ = jit(
    vmap(sample_p_loop_perm_loop_ytest, (0, None, None, None, 0, None, 0, None, None,),)
)
# NOTE: this map used to be bound to ``sample_p_loop_perm``, but the
# ``def sample_p_loop_perm`` later in this module rebinds that name, so the
# vmapped version was unreachable. It now has its own name, and
# ``sample_p_loop_perm`` still resolves to that function definition as before.
sample_p_loop_perm_dperm_map = jit(
    vmap(sample_p_loop_perm_, (0, 0, 0, None, 0, 0, None, None, None))
)


@jit
def sample_p_loop_perm(
    vn_perm, rho, lengths, y_perm, y_test, d_perm_inds, n_perm_inds, seed, helper,
):
    """Sample test points from p(y) for a single permutation by importance resampling.

    Starts from ``init_marginals(y_test)`` and scans
    ``sample_p_single_loop_ytest`` over the ``n = vn_perm.shape[-2]``
    observations. Each step reweights and resamples the test points.

    Returns:
        ``(y_test, logpdf_joints_ytest, n_resampl, ess)``:
        - the resampled test points;
        - their final log joint densities;
        - the resampling counter (incremented once per step);
        - the per-step ESS values (length ``n``).
    """
    n = jnp.shape(vn_perm)[-2]

    # Exactly one PRNG key per scan step (the first split key is discarded).
    # Splitting into len(y_perm) left one key too few, so ``key[n - 1]`` was out
    # of range and JAX silently clamped it to the previous step's key.
    key_seq = split(PRNGKey(seed), n + 1)[1:]

    # Initialize p0
    logcdf_conditionals_ytest, logpdf_joints_ytest = init_marginals(y_test)
    logweights = jnp.zeros(len(y_test))  # log of uniform unnormalised weights

    # Carry out scan over y_{1:n}
    carry = (
        vn_perm,
        logcdf_conditionals_ytest,
        logpdf_joints_ytest,
        rho,
        lengths,
        y_perm,
        y_test,
        logweights,
        d_perm_inds,
        n_perm_inds,
        key_seq,
        0,
        # One ESS entry per scan step (written at index i).
        jnp.zeros(n),
        helper,
    )
    # This carry (14 elements, ``helper`` consumed by compute_rho_d) and the
    # non-permuted initial marginals match ``sample_p_single_loop_ytest``. The
    # ``sample_p_loop_ytest_scan`` wrapper scans the permutation-ensemble step
    # instead (``vn[:, :, i]``, ensemble-shaped state), which cannot run here.
    carry, _ = scan(sample_p_single_loop_ytest, carry, jnp.arange(n))
    (
        _,
        _,
        logpdf_joints_ytest,
        _,
        _,
        _,
        y_test,
        _,
        _,
        _,
        _,
        n_resampl,
        ess,
        _,
    ) = carry

    return y_test, logpdf_joints_ytest, n_resampl, ess


# endregion
