from unittest.mock import NonCallableMagicMock
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, value_and_grad, vmap
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu
from jax.lax import scan
from jax.lib import xla_bridge
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm
import haiku as hk
from utils.bivariate_copula import norm_copula_logdistribution_logdensity
from utils.sampling import resample_is
import wandb
from .copula_ar_base import *
from utils import progress_bar_scan

# region Functions for computing p(y) on test points $$$
# Update p(y) for a single test point and observed datum
@progress_bar_scan(wandb.config.data["n_data_points"] * wandb.config.data["test_size"])
@jit
def update_ptest_single(carry, i):
    """Apply the copula update for observed datum ``i`` to the test marginals.

    Intended as the body of ``jax.lax.scan``. ``carry`` is the 9-tuple
    ``(vn, logcdf_conditionals_ytest, logpdf_joints_ytest, rho, lengths, y,
    y_test, index, helper)``; only the two log marginals are replaced, all
    other entries are passed through unchanged.

    Returns:
        ``(new_carry, i)``.
    """
    vn, logcdf_ytest, logpdf_ytest, rho, lengths, y, y_test, index, helper = carry

    # Step counter used by the alpha schedule; `index` shifts it (callers in
    # this file pass 0 for a fresh run, or their own `index` when resuming).
    step = i + index
    log_alpha = jnp.log(2.0 - (1 / (step + 1))) - jnp.log(step + 2)
    rho_d = compute_rho_d_single(y_test, y[i], rho, lengths, helper)

    # u: current conditional CDFs at the test points (before this update);
    # v: the precomputed value supplied for datum i.
    u_test = jnp.exp(logcdf_ytest)
    v_i = vn[i]

    logcdf_ytest, logpdf_ytest = update_copula_single(
        logcdf_ytest, logpdf_ytest, u_test, v_i, log_alpha, rho_d
    )

    new_carry = (vn, logcdf_ytest, logpdf_ytest, rho, lengths, y, y_test, index, helper)
    return new_carry, i


# Scan over n observed data
@jit
def update_ptest_single_scan(carry, rng):
    """Run ``update_ptest_single`` once per entry of ``rng`` via ``lax.scan``.

    Despite its name, ``rng`` is not a PRNG key: callers in this file pass
    ``jnp.arange(n)``, i.e. the sequence of scan steps.

    Returns:
        ``(final_carry, stacked_steps)`` as produced by ``jax.lax.scan``.
    """
    final_carry, stacked_steps = scan(update_ptest_single, carry, xs=rng)
    return final_carry, stacked_steps


# Compute p(y) for a single test point and y_{1:n}, resuming from saved marginals
@jit
def update_ptest_single_loop_from_checkpoint(
    vn,
    rho,
    lengths,
    y,
    y_test,
    logcdf_conditionals_ytest,
    logpdf_joints_ytest,
    index=0,
    helper=None,
):
    """Continue updating test-point log marginals from a checkpoint.

    No permutation or initialisation is performed. The supplied
    ``logcdf_conditionals_ytest`` / ``logpdf_joints_ytest`` are updated once for
    each of the ``jnp.shape(vn)[0]`` entries of ``vn``. ``index`` is placed in
    the scan carry, where it offsets the step counter of the alpha schedule.

    Returns:
        ``(logcdf_conditionals_ytest, logpdf_joints_ytest)`` after the updates.
    """
    initial_state = (
        vn,
        logcdf_conditionals_ytest,
        logpdf_joints_ytest,
        rho,
        lengths,
        y,
        y_test,
        index,
        helper,
    )
    steps = jnp.arange(jnp.shape(vn)[0])
    final_state, _ = update_ptest_single_scan(initial_state, steps)

    # Carry layout: (vn, logcdf, logpdf, ...); only the marginals are returned.
    _, new_logcdf, new_logpdf = final_state[:3]
    return new_logcdf, new_logpdf


# Compute p(y) for a single test point and y_{1:n}, starting from p0
@jit
def update_ptest_single_loop(
    vn, rho, lengths, y, y_test, d_perm_inds, n_perm_inds, index=0, helper=None,
):
    """Compute the test-point log marginals from scratch for one permutation.

    In the default mode, the last axis of ``y_test`` and ``y`` is reordered by
    ``d_perm_inds``. In all modes, the rows (axis 0) of ``y`` are reordered by
    ``n_perm_inds``. The marginals are initialised with ``init_marginals_single``
    and then updated once per entry of ``vn`` by ``update_ptest_single_scan``,
    with a step offset of 0.

    In "knn" mode ``y_test`` is not used: ``index`` takes its place (for
    initialisation and in the scan carry), and ``y`` is not reordered along
    its last axis.

    Args:
        vn: values consumed in order by the scan (``vn[i]`` at step ``i``); its
            leading dimension sets the number of updates.
        rho, lengths: forwarded to ``compute_rho_d_single`` at every step.
        y: observed data; row ``i`` (after reordering) is used at step ``i``.
        y_test: test point(s); ignored in "knn" mode.
        d_perm_inds: index array applied to the last axis (non-knn mode only).
        n_perm_inds: row order applied to ``y``.
        index: in "knn" mode, used in place of ``y_test``.
        helper: forwarded unchanged to ``compute_rho_d_single``.

    Returns:
        ``(logcdf_conditionals_ytest, logpdf_joints_ytest)`` after all updates.
    """
    if wandb.config.model["diff"] == "knn":
        # Only `index` identifies the test point here, so there is nothing to
        # permute on the test side. Skipping the permutation also means knn
        # callers need not pass an indexable y_test.
        y_test = index
    else:
        y_test = y_test[..., d_perm_inds]
        y = y[..., d_perm_inds]
    y = jnp.take(y, n_perm_inds, axis=0)

    n = jnp.shape(vn)[0]

    # Initialize p0
    logcdf_conditionals_ytest, logpdf_joints_ytest = init_marginals_single(y_test)

    # Scan over y_{1:n}. The step offset is 0 because we start from p0 (the
    # checkpoint variant passes its `index` in this slot instead).
    carry = (
        vn,
        logcdf_conditionals_ytest,
        logpdf_joints_ytest,
        rho,
        lengths,
        y,
        y_test,
        0,
        helper,
    )
    carry, _ = update_ptest_single_scan(carry, jnp.arange(n))
    _, logcdf_conditionals_ytest, logpdf_joints_ytest, *_ = carry

    return logcdf_conditionals_ytest, logpdf_joints_ytest


# Inner vmap: map over n-permutations. vn and n_perm_inds share axis 0; all
# other arguments (including index and helper) are broadcast.
update_ptest_single_loop_perm_ = jit(
    vmap(
        update_ptest_single_loop,
        in_axes=(0, None, None, None, None, None, 0, None, None),
    )
)
# Outer vmap: map over d-permutations. vn, rho, lengths and d_perm_inds share
# axis 0, so vn is expected with leading axes (d-perm, n-perm, ...) and the
# outputs carry the same two leading axes.
update_ptest_single_loop_perm = jit(
    vmap(
        update_ptest_single_loop_perm_,
        in_axes=(0, 0, 0, None, None, 0, None, None, None),
    )
)


# Average p(y) over permutations for single test point
@jit
def update_ptest_single_loop_perm_av(
    vn_perm,
    rho,
    lengths,
    y_perm,
    y_test,
    d_perm_inds,
    n_perm_inds,
    index=0,
    helper=None,
):
    """Return log-mean (over all permutations) CDFs and densities at ``y_test``.

    Calls ``update_ptest_single_loop_perm``, whose outputs have a leading
    d-permutation axis (0) and an n-permutation axis (1). It then takes
    ``log(mean(exp(.)))`` first over axis 1 and then over axis 0.

    Returns:
        ``(logcdf_conditionals, logpdf_joints)`` averaged over permutations.
    """
    log_num_d = jnp.log(len(d_perm_inds))
    log_num_n = jnp.log(len(n_perm_inds))

    per_perm_logcdf, per_perm_logpdf = update_ptest_single_loop_perm(
        vn_perm, rho, lengths, y_perm, y_test, d_perm_inds, n_perm_inds, index, helper,
    )

    def log_mean_over_perms(log_values):
        # Average over n-permutations (axis 1), then over d-permutations (axis 0).
        log_mean_n = logsumexp(log_values, axis=1) - log_num_n
        return logsumexp(log_mean_n, axis=0) - log_num_d

    return log_mean_over_perms(per_perm_logcdf), log_mean_over_perms(per_perm_logpdf)


# Vmap over multiple test points: y_test is mapped along axis -2 and all other
# arguments are broadcast. In "knn" mode the per-test-point `index` (8th
# positional argument) is mapped along axis 0 as well.
if wandb.config.model["diff"] == "knn":
    update_ptest_loop_perm_av = jit(
        vmap(
            update_ptest_single_loop_perm_av,
            (None, None, None, None, -2, None, None, 0, None),
        )
    )
    # Per-permutation (non-averaged) results, mirroring the non-knn branch so
    # the name also exists in knn mode. update_ptest_single_loop_perm already
    # broadcasts `index` across permutations, so only the test-point axis is
    # added here.
    update_ptest_loop_perm_per_perm = jit(
        vmap(
            update_ptest_single_loop_perm,
            (None, None, None, None, -2, None, None, 0, None),
        )
    )
else:
    update_ptest_loop_perm_av = jit(
        vmap(
            update_ptest_single_loop_perm_av,
            (None, None, None, None, -2, None, None, None, None),
        )
    )
    # Per-permutation (non-averaged) results for every test point.
    update_ptest_loop_perm_per_perm = jit(
        vmap(
            update_ptest_single_loop_perm,
            (None, None, None, None, -2, None, None, None, None),
        )
    )
# endregion


# region Functions for updating, i.e. also averaging over y_test permutations
# Vmap over multiple test points: y_test is mapped along axis -2. In "knn" mode
# an 8th positional argument (`index`) is mapped along axis 0 as well; in the
# other modes only 7 positional arguments are accepted (index/helper default).
update_ptest_loop = jit(
    vmap(
        update_ptest_single_loop,
        in_axes=(
            (None, None, None, None, -2, None, None, 0)
            if wandb.config.model["diff"] == "knn"
            else (None, None, None, None, -2, None, None)
        ),
    )
)

# NOTE(review): in "knn" mode `update_ptest_loop` expects 8 positional
# arguments, but the two vmaps below forward only 7. JAX therefore raises an
# in_axes/argument-count mismatch if `update_ptest_loop_perm(_)` is called in
# that mode. Decide how `index` should be mapped across permutations (y_test is
# mapped on axis 0 at the n-permutation level) before enabling it.

# Map over n-permutations: vn, y_test and n_perm_inds share axis 0.
update_ptest_loop_perm_ = jit(
    vmap(update_ptest_loop, in_axes=(0, None, None, None, 0, None, 0))
)
# Map over d-permutations: vn, rho, lengths and d_perm_inds share axis 0.
update_ptest_loop_perm = jit(
    vmap(update_ptest_loop_perm_, in_axes=(0, 0, 0, None, None, 0, None))
)
# endregion
