import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, value_and_grad, vmap
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu
from jax.lax import scan
from jax.lib import xla_bridge
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm
import haiku as hk
from utils.bivariate_copula import norm_copula_logdistribution_logdensity
from utils.sampling import resample_is
import wandb
from .copula_ar_base import *
from .copula_ar_test import update_ptest_loop_perm
from jax.random import PRNGKey, permutation, split
import jax.lax as lax
import optax
import math
from tqdm import tqdm
from utils import progress_bar_scan

# Compute bandwidths for each dimension
# if wandb.config.model["diff"] == "eucl":

### Functions to calculate overhead v_{1:n} ###
# Compute v_i for a single datum
@progress_bar_scan(
    wandb.config.data["n_data_points"] * (1 - wandb.config.data["test_size"])
)
@jit
def update_pn(carry, i):
    """Scan body: fold the ``i``-th datum into the running copula state.

    Args:
        carry: 9-tuple ``(vn, logcdf_conditionals_yn, logpdf_joints_yn,
            preq_loglik, rho, lengths, y, index, helper)``. The first four
            entries are updated by this step; ``rho``, ``lengths``, ``y``,
            ``index`` and ``helper`` are passed through unchanged.
            ``index`` is an offset added to ``i`` when computing the update
            weight.
        i: position of the current datum within ``y``.

    Returns:
        ``(carry, i)`` where ``carry`` has the same layout as the input.
    """
    (
        vn,
        logcdf_conditionals_yn,
        logpdf_joints_yn,
        preq_loglik,
        rho,
        lengths,
        y,
        index,
        helper,
    ) = carry

    # Compute alphas, eq. 4.6
    logalpha = jnp.log(2.0 - (1 / (index + i + 1))) - jnp.log(index + i + 2)

    # Compute independent bandwidths
    rho_d = compute_rho_d(y, y[i], rho, lengths, helper)

    # for each observation, we compute v_i
    # then, we could update u_{1:i} with v_i
    # code is faster if we update u_{1:n} with v_i

    # Compute u,v terms for copula update
    u = jnp.exp(logcdf_conditionals_yn)
    v = jnp.exp(logcdf_conditionals_yn[i])

    # remember history of vn for prediction
    vn = vn.at[i].set(v)

    # update prequential log likelihood
    # only needed when we optimize for rho
    # this is read before update_copula runs, i.e. it is the pre-update value
    preq_loglik = preq_loglik.at[i].set(logpdf_joints_yn[i, -2:])

    # update cdf and pdf
    logcdf_conditionals_yn, logpdf_joints_yn = update_copula(
        logcdf_conditionals_yn, logpdf_joints_yn, u, v, logalpha, rho_d
    )

    carry = (
        vn,
        logcdf_conditionals_yn,
        logpdf_joints_yn,
        preq_loglik,
        rho,
        lengths,
        y,
        index,
        helper,
    )
    return carry, i


# Scan over y_{1:n}; kept as a separate global function so that it is not re-compiled every time
@jit
def update_pn_scan(carry, rng):
    """Scan ``update_pn`` over the indices in ``rng``, starting from ``carry``.

    Returns the ``(final_carry, stacked_outputs)`` pair produced by ``scan``.
    """
    final_carry, stacked_outputs = scan(update_pn, init=carry, xs=rng)
    return final_carry, stacked_outputs


# Pick the implementation of comp_all_rhos once, at import time, from the
# "low_mem" setting: a sequential lax.map for values above 4, a vmap otherwise.
# NOTE(review): compute_rho_d is called with four arguments here, while
# update_pn passes a fifth ``helper`` argument -- confirm the fifth is optional.
if wandb.config.data["low_mem"] > 4:

    @jit
    def comp_all_rhos(a, b, c, d):
        """Apply ``compute_rho_d(a, x, c, d)`` to each ``x`` in ``b``, one at a time."""

        def _rho_for_entry(x):
            return compute_rho_d(a, x, c, d)

        return lax.map(_rho_for_entry, b)


else:
    # Same mapping over the leading axis of the second argument, but batched.
    comp_all_rhos = jit(vmap(compute_rho_d, in_axes=(None, 0, None, None)))

# Compute v_{1:n}
@jit
def update_pn_loop(rho, lengths, y, d_perm_inds, n_perm_inds, helper=None):
    """Run the full scan of ``update_pn`` over one ordering of the data.

    Args:
        rho, lengths: passed through to ``update_pn`` via the scan carry.
        y: data array; rows are reordered by ``n_perm_inds`` and the last axis
            by ``d_perm_inds`` (not used when the "diff" mode is "knn").
        d_perm_inds: index array applied to the last axis of ``y``.
        n_perm_inds: index array applied to the first axis of ``y``
            (or of ``helper`` in "knn" mode).
        helper: extra array forwarded in the carry; in "knn" mode its
            reordered rows are used in place of ``y``.

    Returns:
        ``(vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik)`` taken
        from the final scan carry.
    """
    if wandb.config.model["diff"] == "knn":
        y = jnp.take(helper, n_perm_inds, axis=0)
    else:
        reordered_rows = jnp.take(y, n_perm_inds, axis=0)
        y = jnp.take(reordered_rows, d_perm_inds, axis=-1)

    n_obs = jnp.shape(n_perm_inds)[-1]
    n_dims = jnp.shape(d_perm_inds)[-1]

    # initialize cdf/pdf
    logcdf_init, logpdf_init = init_marginals(y)

    initial_carry = (
        # conditional cdf history of xn, no need to differentiate wrt
        jnp.zeros((n_obs, n_dims)),
        logcdf_init,
        logpdf_init,
        # prequential joint loglik for each d,d-1 (density estimation and regression)
        jnp.zeros((n_obs, 2)),
        rho,
        lengths,
        y,
        0,  # index offset: this scan starts at the first observation
        helper,
    )

    # carry out scan over y_{1:n}; the stacked per-step outputs are not needed
    final_carry, _ = update_pn_scan(initial_carry, jnp.arange(n_obs))

    vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik = final_carry[:4]
    return vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik


# Batch update_pn_loop over the leading axis of n_perm_inds; every other
# argument is shared across the batch.
update_pn_loop_perm__ = jit(
    vmap(update_pn_loop, in_axes=(None, None, None, None, 0, None))
)
# Additionally batch over the leading axis of rho, lengths and d_perm_inds;
# y, n_perm_inds and helper are shared.
update_pn_loop_perm_full = jit(
    vmap(update_pn_loop_perm__, in_axes=(0, 0, None, 0, None, None))
)


# @jit
def update_pn_loop_from_checkpoint(
    rho,
    lengths,
    y,
    vn,
    logcdf_conditionals_yn,
    logpdf_joints_yn,
    preq_loglik,
    d_perm_inds,
    n_perm_inds,
    index,
    helper=None,
):
    """Continue the ``update_pn`` scan on one chunk of data from a saved state.

    Args:
        rho, lengths: passed through to ``update_pn`` via the scan carry.
        y: data array; its rows are selected by ``n_perm_inds`` and its
            columns by ``d_perm_inds`` before scanning.
        vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik: state the
            scan starts from.
        d_perm_inds: column indices applied to ``y``.
        n_perm_inds: row indices of the chunk to process.
        index: offset added to the scan position inside ``update_pn``.
        helper: extra value forwarded in the carry.

    Returns:
        ``(vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik)`` taken
        from the final scan carry.
    """
    y = jnp.take(y, n_perm_inds, axis=0)
    y = y[:, d_perm_inds]

    # carry out scan over the chunk
    carry = (
        vn,
        logcdf_conditionals_yn,
        logpdf_joints_yn,
        preq_loglik,
        rho,
        lengths,
        y,
        index,
        helper,
    )

    # One scan step per datum actually present in this chunk. Deriving the
    # length from n_perm_inds (as update_pn_loop does) instead of the
    # configured batch size keeps a shorter final chunk from running extra
    # steps on out-of-range positions.
    rng = jnp.arange(jnp.shape(n_perm_inds)[-1])
    carry, rng = update_pn_scan(carry, rng)

    vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik, *_ = carry

    return vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik


# Inner batching: map over the leading axis of vn, logcdf_conditionals_yn,
# logpdf_joints_yn, preq_loglik and n_perm_inds; rho, lengths, y, d_perm_inds,
# index and helper are shared.
_INNER_CHECKPOINT_AXES = (None, None, None, 0, 0, 0, 0, None, 0, None, None)
update_pn_loop_from_checkpoint_perm_ = jit(
    vmap(update_pn_loop_from_checkpoint, in_axes=_INNER_CHECKPOINT_AXES)
)

# Outer batching: map over the leading axis of rho, lengths, the four state
# arrays and d_perm_inds; y, n_perm_inds, index and helper are shared.
_OUTER_CHECKPOINT_AXES = (0, 0, None, 0, 0, 0, 0, 0, None, None, None)
update_pn_loop_from_checkpoint_perm = jit(
    vmap(update_pn_loop_from_checkpoint_perm_, in_axes=_OUTER_CHECKPOINT_AXES)
)


def update_pn_from_checkpoint(
    rho,
    lengths,
    y_perm,
    vn_,
    preq_loglik_,
    i,
    vn_perm,
    preq_loglik_perm,
    batch_size,
    vn,
    preq_loglik_save,
    d_perm_inds,
    n_perm_inds,
    helper=None,
):
    """Store the results of batch ``i - 1`` and process batch ``i``.

    Args:
        rho, lengths, helper: forwarded to the downstream update routines.
        y_perm: 2-D data array.
        vn_, preq_loglik_: results of the previous batch; written into slot
            ``i - 1`` of ``vn_perm`` / ``preq_loglik_perm``.
        i: index of the batch to process.
        vn_perm, preq_loglik_perm: per-batch result buffers, indexed by batch
            along their third axis.
        batch_size: batch length; ``i * batch_size`` is passed on as the
            scan's index offset.
        vn, preq_loglik_save: initial state buffers handed to the scan.
        d_perm_inds: column orderings; its length is the first output axis.
        n_perm_inds: row indices arranged per batch (second axis is the batch).

    Returns:
        ``(vn_, preq_loglik_, vn_perm, preq_loglik_perm)`` where the first two
        are the results for batch ``i`` and the last two are the updated
        buffers.
    """
    n_d_orderings = len(d_perm_inds)
    n_n_orderings = len(n_perm_inds)
    _, n_dims = y_perm.shape

    # Save the previous batch's results in their slot.
    previous = i - 1
    vn_perm = vn_perm.at[:, :, previous].set(vn_)
    preq_loglik_perm = preq_loglik_perm.at[:, :, previous].set(preq_loglik_)

    # Flatten everything from batches 0..i-1 and evaluate it at the rows of
    # batch i; the result is the starting state for the scan below.
    history_vn = vn_perm[:, :, :i].reshape(
        (n_d_orderings, n_n_orderings, -1, n_dims)
    )
    batch_rows = vmap(jnp.take, (None, 0, None))(y_perm, n_perm_inds[:, i], 0)
    history_inds = n_perm_inds[:, :i].reshape((n_n_orderings, -1))
    logcdf_conditionals, logpdf_joints = update_ptest_loop_perm(
        history_vn,
        rho,
        lengths,
        y_perm,
        batch_rows,
        d_perm_inds,
        history_inds,
        helper,
    )

    # Scan over batch i, offsetting the step counter by the data already seen.
    run_batch = jit(update_pn_loop_from_checkpoint_perm)
    vn_, _, _, preq_loglik_ = run_batch(
        rho,
        lengths,
        y_perm,
        vn,
        logcdf_conditionals,
        logpdf_joints,
        preq_loglik_save,
        d_perm_inds,
        n_perm_inds[:, i],
        i * batch_size,
        helper,
    )

    return vn_, preq_loglik_, vn_perm, preq_loglik_perm


def update_pn_loop_perm_batches_tested(
    rho, lengths, y_perm_ori, d_perm_inds, n_perm_inds, helper
):
    """Run the scan over all orderings in batches of the configured size.

    The first batch is processed with ``update_pn_loop_perm_full``; each later
    full batch goes through ``update_pn_from_checkpoint``. Rows left over when
    the number of rows is not a multiple of the batch size are processed in a
    final, shorter chunk.

    Args:
        rho, lengths, helper: forwarded to the downstream update routines.
        y_perm_ori: 2-D data array of shape ``(c, d)``.
        d_perm_inds: column orderings; ``len(d_perm_inds)`` is the first
            output axis.
        n_perm_inds: row orderings, indexed as ``[ordering, position]``;
            ``len(n_perm_inds)`` is the second output axis.

    Returns:
        ``(vn_perm, None, None, preq_loglik_perm)`` where ``vn_perm`` has
        shape ``(a, b, c, d)`` and ``preq_loglik_perm`` has shape
        ``(a, b, c, 2)``. The two ``None`` entries keep the 4-tuple layout of
        the non-batched routines.
    """
    batch_size = wandb.config.data["batch_size"]

    a, b, (c, d) = len(d_perm_inds), len(n_perm_inds), y_perm_ori.shape
    n_full_batches = c // batch_size
    n_batched = int(n_full_batches * batch_size)
    rem_batched = c % batch_size

    # Row indices of the full batches, arranged as (ordering, batch, position).
    n_perm_inds_res = jnp.reshape(
        n_perm_inds[:, :n_batched], (b, n_full_batches, batch_size),
    )
    # One slot per full batch, plus one extra slot when there is a remainder.
    vn_perm = jnp.zeros(
        (a, b, n_full_batches + (rem_batched > 0), batch_size, d),
    )
    preq_loglik_perm = jnp.zeros_like(vn_perm)[..., :2]

    # First batch: scan from scratch.
    vn_, _, _, preq_loglik_ = update_pn_loop_perm_full(
        rho, lengths, y_perm_ori, d_perm_inds, n_perm_inds_res[:, 0], helper,
    )
    # Zero-filled state buffers handed to every checkpointed scan.
    vn = jnp.zeros_like(vn_)
    preq_loglik_save = vn[..., :2]

    for i in tqdm(range(1, n_full_batches)):

        vn_, preq_loglik_, vn_perm, preq_loglik_perm = update_pn_from_checkpoint(
            rho,
            lengths,
            y_perm_ori,
            vn_,
            preq_loglik_,
            i,
            vn_perm,
            preq_loglik_perm,
            batch_size,
            vn,
            preq_loglik_save,
            d_perm_inds,
            n_perm_inds_res,
            helper,
        )

    # The loop stores batch i-1 while computing batch i, so the last full
    # batch still has to be written; it sits before the remainder slot, if any.
    last_full_slot = -1 - (rem_batched > 0)
    vn_perm = vn_perm.at[:, :, last_full_slot].set(vn_)
    preq_loglik_perm = preq_loglik_perm.at[:, :, last_full_slot].set(preq_loglik_)

    # Flatten (batch, position) into one row axis and drop the padding.
    vn_perm = vn_perm.reshape((a, b, -1, d))[:, :, :c]
    preq_loglik_perm = preq_loglik_perm.reshape((a, b, -1, 2))[:, :, :c]

    if rem_batched:
        logcdf_conditionals, logpdf_joints = update_ptest_loop_perm(
            vn_perm[:, :, :n_batched],
            rho,
            lengths,
            y_perm_ori,
            # jnp.take is called with three positional arguments, so in_axes
            # needs three entries (same call shape as in
            # update_pn_from_checkpoint).
            vmap(jnp.take, (None, 0, None))(
                y_perm_ori, n_perm_inds[:, n_batched:], 0
            ),
            d_perm_inds,
            n_perm_inds[:, :n_batched],
            helper,
        )
        # Argument order follows update_pn_loop_from_checkpoint:
        # (..., d_perm_inds, n_perm_inds, index, helper).
        vn_, _, _, preq_loglik_ = update_pn_loop_from_checkpoint_perm(
            rho,
            lengths,
            y_perm_ori,
            vn[:, :, :rem_batched],
            logcdf_conditionals,
            logpdf_joints,
            preq_loglik_save[:, :, :rem_batched],
            d_perm_inds,
            n_perm_inds[:, n_batched:],
            n_batched,
            helper,
        )
        vn_perm = vn_perm.at[:, :, n_batched:].set(vn_)
        preq_loglik_perm = preq_loglik_perm.at[:, :, n_batched:].set(preq_loglik_)

    return (
        vn_perm,
        None,
        None,
        preq_loglik_perm,
    )


# For "low_mem" above 1, replace the vmapped update_pn_loop_perm__ with a
# version that walks the row orderings one after another via lax.map.
if wandb.config.data["low_mem"] > 1:

    @jit
    def update_pn_loop_perm__(rho, lengths, y_perm, d_perm_inds, n_perm_inds, helper):
        """Apply ``update_pn_loop`` to each entry of ``n_perm_inds`` in turn."""

        def _run_one_ordering(row_inds):
            return update_pn_loop(
                rho, lengths, y_perm, d_perm_inds, row_inds, helper
            )

        return lax.map(_run_one_ordering, n_perm_inds)


# Choose update_pn_loop_perm: for a positive "low_mem" the outer batch over
# (rho, lengths, d_perm_inds) is walked sequentially with lax.map; otherwise
# the fully vmapped routine is used.
if wandb.config.data["low_mem"] > 0:

    @jit
    def update_pn_loop_perm(rho, lengths, y_perm, d_perm_inds, n_perm_inds, helper):
        """Apply ``update_pn_loop_perm__`` to each (rho, lengths, d_perm_inds) triple."""

        def _run_one_setting(setting):
            rho_k, lengths_k, d_inds_k = setting
            return update_pn_loop_perm__(
                rho_k, lengths_k, y_perm, d_inds_k, n_perm_inds, helper
            )

        return lax.map(_run_one_setting, (rho, lengths, d_perm_inds))


else:
    update_pn_loop_perm = update_pn_loop_perm_full

    # region GPUS
    # if (
    #     len(jax.devices()) == 1
    #     or xla_bridge.get_backend().platform == "cpu"
    #     or len(jax.devices()) == 4
    # ):
    #     update_pn_loop_perm_full = update_pn_loop_perm_
    # else:
    #     # update_pn_loop_perm = update_pn_loop_perm_

    #     @jit
    #     def update_pn_loop_perm(rho, lengths, y_perm):
    #         vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik = jax.pmap(
    #             update_pn_loop_perm_, axis_name="i", in_axes=(None, None, 0)
    #         )(
    #             rho,
    #             lengths,
    #             y_perm.reshape(
    #                 (len(jax.devices()), y_perm.shape[0] // len(jax.devices()), -1)
    #             ),
    #         )
    #         # for k in vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik:
    #         #     k = k.reshape(y_perm.shape)

    #         return vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik
    # endregion
