import time
from collections import namedtuple
import pickle as pkl
import os

import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
from jax.random import PRNGKey, permutation, split
from scipy.optimize import minimize
from sklearn.decomposition import non_negative_factorization
from tqdm import tqdm
import haiku as hk
import optax
from typing import Iterator, Mapping, Tuple
import wandb
import math
import utils
from utils.helper import EarlyStopping
from sklearn.cluster import DBSCAN
from sklearn.neighbors import KDTree

# Type alias for a named batch of arrays (field name -> array).
Batch = Mapping[str, np.ndarray]

# TODO: conditional imputation: SMC with posterior correction
# TODO: permute over feature orderings
# TODO: octane
# TODO: Bernoulli density estimation

# Debugging switches (left disabled):
# import os

# os.environ["JAX_CHECK_TRACER_LEAKS"] = "True"
# jax.config.update("jax_debug_nans", True)

# NOTE: the statements below run at import time, so ``wandb.config`` must
# already contain the ``base`` section before this module is imported.
# Classification and regression runs switch JAX to 64-bit floats.
if wandb.config.base["class"] or wandb.config.base["regress"]:
    jax.config.update("jax_enable_x64", True)
# Optionally disable JIT compilation globally (e.g. for step-through debugging).
if wandb.config.base["disable_jit"]:
    jax.config.update("jax_disable_jit", True)


# Compute overhead v_{1:n}, return fit copula object for prediction
def fit_copula_density(
    y,
    y_val,
    n_perm=10,
    d_perm=5,
    n_perm_optim=None,
    n_optim=None,
    maxiter=3,
    bern=False,
    init_rho=0.9,
    init_length=1.0,
    seed=20,
    opt_path=None,
    bounds=None,
):
    """Fit the copula hyperparameters on ``y`` and compute ``vn_perm``.

    The backend module (``cop_AR``) and the update / optimisation routines
    are chosen from ``wandb.config`` (``base["class"]``, ``base["regress"]``,
    ``data["batching"]``, ``data["batching_optim"]``). The hyperparameter
    family is chosen by ``wandb.config.model["diff"]``:

    * network families (``net``, ``joint_net``, ``joint_net_zeroed``,
      ``arnet``, ``net-dim``, ``arnet-dim``) are optimised with AdamW;
    * kernel families (``eucl``, ``none``, ``extreme``, ``dbscan``, ``knn``,
      ``dim``, ``eucl-dim``) are optimised with AdamW or, when
      ``model["scipy_opt"]`` is set, with SciPy's SLSQP.

    Args:
        y: Training data; rows are permuted ``n_perm`` times and the last
            axis (features) ``d_perm`` times.
        y_val: Validation data handed to ``EarlyStopping``.
        n_perm: Number of random row permutations.
        d_perm: Number of random feature-order permutations. For
            classification / regression the last column is kept in place.
        n_perm_optim: Number of row permutations used during optimisation;
            ``None`` uses all of them.
        n_optim: Number of entries of each row permutation used during
            optimisation; ``None`` uses all of them.
        maxiter: Iteration budget handed to the optimiser. For the kernel
            families a falsy value skips optimisation and keeps the initial
            hyperparameters.
        bern: Unused; kept for interface compatibility.
        init_rho: Initial value of the bandwidth parameter(s).
        init_length: Initial value of the length scales.
        seed: Seed for NumPy's global RNG and for the JAX keys.
        opt_path: Pickle file used to cache the optimised hyperparameters.
            When ``None`` the cache is neither read nor written.
        bounds: Passed through to ``scipy.optimize.minimize``.

    Returns:
        A ``copula_density_obj`` namedtuple with fields ``vn_perm``,
        ``rho_lengths_opt``, ``preq_loglik``, ``y_perm``, ``d_perm_inds`` and
        ``n_perm_inds``.

    Raises:
        NotImplementedError: For ``"nice_net"`` or an unknown
            ``model["diff"]``.
        ValueError: If the low-memory path is requested
            (``data["low_mem"] < 0``) without an ``opt_path``.
    """

    def _cached_fit_available():
        # Evaluated lazily so the config / filesystem are only touched in the
        # branches that actually use the cache.
        return (
            not wandb.config.base["reload"]
            and opt_path is not None
            and os.path.exists(opt_path)
        )

    def _load_cached_fit():
        # NOTE: pickle must only ever be pointed at trusted, locally written
        # cache files.
        with open(opt_path, "rb") as f:
            return pkl.load(f)

    def _save_fit(fitted):
        if opt_path is None:
            return
        with open(opt_path, "wb") as f:
            pkl.dump(fitted, f)

    def _params_for_perm(all_params, perm_slice):
        # Restrict "rho" and every network leaf to the given slice of the
        # leading (feature-permutation) axis.
        return {
            "rho": all_params["rho"][perm_slice],
            **{
                name: {leaf: value[perm_slice] for leaf, value in layer.items()}
                for name, layer in all_params.items()
                if name != "rho"
            },
        }

    print(jax.devices())
    hyperparam_d = jnp.shape(y)[-1]
    helper = None
    # Placeholder objective; it is what gets reported (negated) as
    # ``preq_loglik`` when the hyperparameters come from the cache.
    opt_fun = 999

    # Set seed for scipy
    np.random.seed(seed)

    # region imports

    if wandb.config.base["class"]:
        import models.copula_AR_functions.copula_classification_functions_old as cop_AR

        if wandb.config.data["batching"]:
            from models.copula_AR_functions.copula_classification_functions_old import (
                update_pn_loop_perm_batches_tested as update_pn_loop_perm,
            )
        else:
            from models.copula_AR_functions.copula_classification_functions_old import (
                update_pn_loop_perm,
            )

        if wandb.config.data["batching_optim"]:
            from models.copula_AR_functions.copula_classification_functions_old import (
                fit_batches as optimize_params,
            )
        else:
            from models.copula_AR_functions.copula_classification_functions_old import (
                fit as optimize_params,
            )
    elif wandb.config.base["regress"]:
        import models.copula_AR_functions.copula_regression_functions_old as cop_AR

        if wandb.config.data["batching"]:
            from models.copula_AR_functions.copula_regression_functions_old import (
                update_pn_loop_perm_batches_tested as update_pn_loop_perm,
            )
        else:
            from models.copula_AR_functions.copula_regression_functions_old import (
                update_pn_loop_perm,
            )

        if wandb.config.data["batching_optim"]:
            from models.copula_AR_functions.copula_regression_functions_old import (
                fit_batches as optimize_params,
            )
        else:
            from models.copula_AR_functions.copula_regression_functions_old import (
                fit as optimize_params,
            )
    else:
        import models.copula_AR_functions as cop_AR

        if wandb.config.data["batching"]:
            from models.copula_AR_functions import (
                update_pn_loop_perm_batches_tested as update_pn_loop_perm,
            )
        else:
            from models.copula_AR_functions import update_pn_loop_perm

        if wandb.config.data["batching_optim"]:
            from models.copula_AR_functions import fit_batches as optimize_params
        else:
            from models.copula_AR_functions import fit as optimize_params

    # endregion

    # region generate random permutations
    key = PRNGKey(seed)

    # Row (sample-order) permutations, one per subkey.
    key, *subkey = split(key, n_perm + 1)
    subkey = jnp.array(subkey)
    n_perm_inds = vmap(permutation, (0, None))(subkey, jnp.arange(len(y)))

    # Feature-order permutations.
    key, *subkey = split(key, d_perm + 1)
    subkey = jnp.array(subkey)
    if d_perm > 1:
        if wandb.config.base["class"] or wandb.config.base["regress"]:
            # Permute all columns but the last one, then append the index of
            # the last column so that it always stays in final position.
            d_perm_inds = vmap(permutation, (0, None))(
                subkey, jnp.arange(y.shape[1] - 1)
            )
            d_perm_inds = jnp.concatenate(
                [d_perm_inds, jnp.ones((d_perm, 1)) * (y.shape[1] - 1)], 1
            ).astype(int)
        else:
            d_perm_inds = vmap(permutation, (0, None))(subkey, jnp.arange(y.shape[1]))
    else:
        d_perm_inds = jnp.arange(y.shape[1])[None]

    key_seq = hk.PRNGSequence(seed + 1)
    # endregion

    y_opt = y
    diff = wandb.config.model["diff"]

    # Initialize parameter and put on correct scale to lie in [0,1]/[0,\infty]
    if diff in [
        "net",
        "joint_net",
        "joint_net_zeroed",
        "arnet",
        "net-dim",
        "arnet-dim",
    ]:
        from models.copula_AR_functions import network

        # Compiling
        start = time.time()

        # region init params
        key, *subkey = split(key, d_perm + 1)
        # The joint networks take an input twice as wide as the data.
        input_width = (y.shape[-1]) * (
            1 + (diff in ["joint_net", "joint_net_zeroed"])
        )
        net_params = vmap(network.init, (0, None),)(
            jnp.array(subkey), jnp.zeros([1, input_width]),
        )
        if "dim" not in diff:
            rho_lengths_init = jnp.ones((d_perm)) * init_rho
        else:
            rho_lengths_init = jnp.ones((d_perm, hyperparam_d)) * init_rho
        # "rho" is stored on the unconstrained scale log(1 / rho - 1).
        params = {"rho": jnp.log(1 / rho_lengths_init - 1), **net_params}
        if diff == "arnet":
            helper = utils.generate_ar_masks(net_params)
        elif diff == "net":
            helper = (None, jnp.tri(wandb.config.data["d"]))
        # endregion

        # region compiling
        if wandb.config.base["compile"]:
            # prequential likelihood
            print("Compiling fun...")
            cop_AR.negpreq_jointloglik_perm(params, y_opt, d_perm_inds, n_perm_inds)

            # gradient of that
            print("Compiling grad...")
            cop_AR.grad_jll_perm(
                _params_for_perm(params, slice(None, 1)),
                y_opt,
                d_perm_inds[:1],
                n_perm_inds[:n_perm_optim, :n_optim],
                helper,
            )

            # fit the actual copula (sth with precompiling)
            print("Compiling update...")
            update_pn_loop_perm(
                params["rho"],
                {k: v for k, v in params.items() if k != "rho"},
                y_opt,
                d_perm_inds,
                n_perm_inds,
                helper,
            )[0].block_until_ready()

            end = time.time()
            print("Compilation time: {}s".format(round(end - start, 3)))
            wandb.log({"compilation_time": round(end - start, 3)})
        # endregion

        # Minimise the loss
        if _cached_fit_available():
            rho_lengths_opt = _load_cached_fit()
        else:
            # Optimize with Adam, one feature permutation at a time
            print("Optimizing...")
            start = time.time()

            for i in range(d_perm):
                optimizer = optax.adamw(learning_rate=1e-2)
                print("######### Permutation {} #########".format(i))
                # The data handed to the optimiser is already reordered by
                # d_perm_inds[i], hence the identity orderings below.
                early_stopping_wrapper = EarlyStopping(
                    copula_test_nll,
                    patience=wandb.config.model["patience"],
                    y_val=y_val[..., d_perm_inds[i]],
                    miniter=wandb.config.model["miniter"],
                    d_perm_i=i,
                    helper=helper,
                    d_perm_inds=jnp.arange(hyperparam_d)[None],
                    n_perm_inds=jnp.arange(len(y_opt))[None],
                )

                params_i, opt_fun = optimize_params(
                    y_opt[..., d_perm_inds[i]],
                    _params_for_perm(params, slice(i, i + 1)),
                    optimizer,
                    key_seq,
                    n_optim,
                    maxiter,
                    early_stopping_wrapper,
                    helper=helper,
                )
                # Write the optimised slice back into the stacked parameters.
                params = {
                    "rho": params["rho"].at[i : i + 1].set(params_i["rho"]),
                    **{
                        k: {
                            k_inner: p_inner.at[i : i + 1].set(params_i[k][k_inner])
                            for k_inner, p_inner in p.items()
                        }
                        for k, p in params.items()
                        if k != "rho"
                    },
                }

            rho_lengths_opt = params.copy()
            # unscale hyperparameter
            rho_lengths_opt["rho"] = 1 / (1 + jnp.exp(rho_lengths_opt["rho"]))
            if diff == "arnet":
                # Apply the autoregressive masks to the weights once, up front.
                rho_lengths_opt = {
                    k: {"w": v["w"] * helper[k], "b": v["b"]} if k != "rho" else v
                    for k, v in rho_lengths_opt.items()
                }

            end = time.time()
            wandb.log({"rho": rho_lengths_opt["rho"]})
            print("Optimization time: {}s".format(round(end - start, 3)))
            wandb.log({"optimization_time": round(end - start, 3)})

            _save_fit(rho_lengths_opt)

    elif diff in [
        "eucl",
        "none",
        "extreme",
        "dbscan",
        "knn",
        "dim",
        "eucl-dim",
    ]:

        # init hyperparam
        if diff == "dbscan":
            db = DBSCAN(eps=0.01, min_samples=10).fit(y)
            labels = db.labels_

            # Number of clusters in labels, ignoring noise if present.
            n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
            n_noise_ = list(labels).count(-1)

            print("Estimated number of clusters: %d" % n_clusters_)
            print("Estimated number of noise points: %d" % n_noise_)

            # a single bandwidth per feature permutation
            rho_lengths_init = jnp.ones((d_perm, 1)) * init_rho

            hyperparam_init = jnp.log(1 / rho_lengths_init - 1)
        else:
            if diff == "none" and wandb.config.base["class"]:
                # two bandwidths per feature permutation
                rho_lengths_init = jnp.ones((d_perm, 2)) * init_rho

                hyperparam_init = jnp.log(1 / rho_lengths_init - 1)
            elif diff == "dim":
                # one bandwidth per dimension
                rho_lengths_init = jnp.ones((d_perm, hyperparam_d)) * init_rho

                hyperparam_init = jnp.log(1 / rho_lengths_init - 1)
            elif diff == "eucl-dim":
                # first d columns are bandwidths, the remaining d-1 are
                # length scales
                rho_lengths_init = jnp.ones((d_perm, 2 * hyperparam_d - 1))
                rho_lengths_init = rho_lengths_init.at[:, :hyperparam_d].set(init_rho)
                lengthscales = jnp.log(rho_lengths_init[:, hyperparam_d:] * init_length)
                hyperparam_init = rho_lengths_init.at[:, :hyperparam_d].set(
                    jnp.log(1 / rho_lengths_init[:, :hyperparam_d] - 1)
                )
                hyperparam_init = hyperparam_init.at[:, hyperparam_d:].set(lengthscales)
            else:
                # first column is the bandwidth, rest is d-1 length scales
                rho_lengths_init = jnp.ones((d_perm, hyperparam_d))
                rho_lengths_init = rho_lengths_init.at[:, 0].set(init_rho)
                lengthscales = jnp.log(rho_lengths_init[:, 1:] * init_length)
                hyperparam_init = rho_lengths_init.at[:, 0].set(
                    jnp.log(1 / rho_lengths_init[:, 0] - 1)
                )
                hyperparam_init = hyperparam_init.at[:, 1:].set(lengthscales)

        if diff == "knn":
            # Indices of the nearest training points of every training point.
            tree = KDTree(y)
            helper = tree.query(
                y,
                k=wandb.config.model["n_knn"],
                return_distance=False,
                sort_results=False,
            )
        else:
            helper = None

        if wandb.config.base["compile"]:
            print("Compiling...")
            start = time.time()

            # prequential likelihood
            print("Compiling fun...")
            cop_AR.negpreq_jointloglik_perm(
                hyperparam_init[:1],
                y_opt,
                d_perm_inds[:1],
                n_perm_inds[:n_perm_optim, :n_optim],
                helper,
            )

            # gradient of that
            print("Compiling grad...")
            cop_AR.grad_jll_perm(
                hyperparam_init[:1],
                y_opt,
                d_perm_inds[:1],
                n_perm_inds[:n_perm_optim, :n_optim],
                helper,
            )

            # fit the actual copula (sth with precompiling)
            print("Compiling update...")
            rho, lengths = cop_AR.get_rho_params_wo_transform(rho_lengths_init)
            update_pn_loop_perm(
                rho, lengths, y, d_perm_inds, n_perm_inds, helper,
            )[0].block_until_ready()

            end = time.time()
            print("Compilation time: {}s".format(round(end - start, 3)))
            wandb.log({"compilation_time": round(end - start, 3)})

        print("Optimizing...")
        start = time.time()

        if not maxiter:
            # No optimisation requested: keep the initial hyperparameters.
            rho_lengths_opt = hyperparam_init
            opt_fun = -999
            early_stopping_wrapper = EarlyStopping(
                None, patience=None, y_val=None, miniter=wandb.config.model["miniter"],
            )

        elif not wandb.config.model["scipy_opt"]:
            # Optimize with Adam, one feature permutation at a time
            if _cached_fit_available():
                rho_lengths_opt = _load_cached_fit()
            else:
                start = time.time()

                params = jnp.zeros_like(hyperparam_init)
                for i in range(d_perm):
                    print("######### Permutation {} #########".format(i))
                    optimizer = optax.adamw(learning_rate=1e-2)
                    early_stopping_wrapper = EarlyStopping(
                        copula_test_nll,
                        patience=wandb.config.model["patience"],
                        y_val=y_val[..., d_perm_inds[i]],
                        miniter=wandb.config.model["miniter"],
                        d_perm_i=i,
                        helper=helper,
                        d_perm_inds=jnp.arange(hyperparam_d)[None],
                        n_perm_inds=jnp.arange(len(y_opt))[None],
                    )
                    params_i, opt_fun = optimize_params(
                        y_opt[..., d_perm_inds[i]],
                        hyperparam_init[i : i + 1],
                        optimizer,
                        key_seq,
                        n_optim,
                        maxiter,
                        early_stopping_wrapper,
                    )
                    params = params.at[i : i + 1].set(params_i)
                # still on the unconstrained scale; unscaled further below
                rho_lengths_opt = jnp.array(params)

                _save_fit(rho_lengths_opt)

        else:
            # Optimize with SLSQP (Quasi-Newton)
            if _cached_fit_available():
                rho_lengths_opt = _load_cached_fit()
            else:
                # ``None`` means "use every row permutation", matching the
                # ``[:n_perm_optim]`` slices used elsewhere.
                if n_perm_optim is None:
                    n_perm_used = len(n_perm_inds)
                else:
                    n_perm_used = min(len(n_perm_inds), n_perm_optim)

                opt_funs = []
                params = jnp.zeros_like(hyperparam_init)
                for i in range(d_perm):
                    print("######### Permutation {} #########".format(i))
                    early_stopping_wrapper = EarlyStopping(
                        cop_AR.negpreq_jointloglik_perm,
                        patience=wandb.config.model["patience"],
                        y_val=y_val[..., d_perm_inds[i]],
                        miniter=wandb.config.model["miniter"],
                        d_perm_i=i,
                        helper=helper,
                        d_perm_inds=jnp.arange(hyperparam_d)[None],
                        n_perm_inds=jnp.arange(len(y_opt))[None],
                    )

                    # Random subset of the row permutations for this run.
                    perm_subset = n_perm_inds[
                        np.random.choice(
                            len(n_perm_inds), replace=False, size=n_perm_used,
                        )
                    ][:, :n_optim]

                    opt = minimize(
                        fun=cop_AR.fun_jll_perm_sp,
                        x0=hyperparam_init[i : i + 1],
                        args=(
                            y_opt[..., d_perm_inds[i]],
                            jnp.arange(hyperparam_d)[None],
                            perm_subset,
                            n_optim,
                        ),
                        jac=cop_AR.grad_jll_perm_sp,
                        method="SLSQP",
                        options={"maxiter": maxiter, "ftol": 1e-4},
                        callback=early_stopping_wrapper.callback,
                        bounds=bounds,
                    )

                    # check optimization succeeded
                    if not opt.success:
                        print("Optimization failed")

                    # Prefer the early-stopping checkpoint when it triggered.
                    if early_stopping_wrapper.early_stop:
                        params = params.at[i].set(early_stopping_wrapper.best_params)
                    else:
                        params = params.at[i : i + 1].set(opt.x)
                    opt_funs.append(opt.fun)
                opt_fun = np.mean(opt_funs)
                rho_lengths_opt = jnp.array(params)

                os.makedirs(
                    os.path.join(utils.get_project_root(), "checkpoints"), exist_ok=True
                )
                _save_fit(rho_lengths_opt)

        # unscale hyperparameter
        if diff in ["none", "dim"]:
            rho_lengths_opt = 1 / (1 + jnp.exp(rho_lengths_opt))
        elif diff == "eucl-dim":
            rho_lengths_opt = (
                jnp.array(rho_lengths_opt)
                .at[:, :hyperparam_d]
                .set(1 / (1 + jnp.exp(rho_lengths_opt[:, :hyperparam_d])))
            )
            rho_lengths_opt = rho_lengths_opt.at[:, hyperparam_d:].set(
                jnp.exp(rho_lengths_opt[:, hyperparam_d:])
            )
        else:
            rho_lengths_opt = (
                jnp.array(rho_lengths_opt)
                .at[:, 0]
                .set(1 / (1 + jnp.exp(rho_lengths_opt[:, 0])))
            )
            rho_lengths_opt = rho_lengths_opt.at[:, 1:].set(
                jnp.exp(rho_lengths_opt[:, 1:])
            )

            end = time.time()
            print("Optimization time: {}s".format(round(end - start, 3)))
            wandb.log({"optimization_time": round(end - start, 3)})

            n_cols = rho_lengths_opt.shape[-1]
            # Number of leading bandwidth columns. Same as the former
            # ``: -hyperparam_d + 1`` slice, but still correct when
            # hyperparam_d == 1 (where that slice would be empty).
            n_rho = n_cols - hyperparam_d + 1

            print({"rho": rho_lengths_opt[:, :n_rho]})
            # Only iterate over columns that exist; out-of-range indices are
            # clamped by JAX and would repeat the last column.
            print({f"lengthscale_{i}": rho_lengths_opt[:, i] for i in range(1, n_cols)})
            wandb.log({"rho": rho_lengths_opt[:, :n_rho]})
            wandb.log(
                {
                    f"lengthscale_{i}": rho_lengths_opt[:, i]
                    for i in range(n_cols - hyperparam_d + 1, n_cols)
                }
            )

    elif diff == "nice_net":
        # The training loop that used to live here could not run: it called
        # an optimiser that was set to None and never produced
        # ``rho_lengths_opt``. Fail explicitly instead.
        raise NotImplementedError(
            'model["diff"] == "nice_net" is not supported: no optimizer is '
            "configured for it"
        )

    else:
        raise NotImplementedError

    print("Fitting...")
    start = time.time()

    rho, lengths = cop_AR.get_rho_params_wo_transform(rho_lengths_opt)

    if diff == "net":
        helper = (jnp.arange(wandb.config.data["d"]), helper[1])

    if wandb.config.data["low_mem"] < 0:
        # Low-memory path: write vn_perm to a disk-backed array, one feature
        # permutation at a time.
        if opt_path is None:
            raise ValueError(
                "opt_path is required when data['low_mem'] < 0: it names the "
                "memory-mapped vn_perm file"
            )
        filename = f"{utils.get_project_root()}/tmp/{opt_path.split('.')[0].split('/')[-1]}_vn.npy"
        vn_shape = (
            d_perm,
            n_perm,
            y.shape[0],
            y.shape[1] if not wandb.config.base["class"] else 1,
        )
        vn_perm = np.memmap(filename, dtype="float32", mode="w+", shape=vn_shape)
        for d_idx, d_permi in enumerate(d_perm_inds):
            vn_perm[d_idx : d_idx + 1] = update_pn_loop_perm(
                rho[d_idx : d_idx + 1],
                cop_AR.slice_lengths(cop_AR.slice_lengths(lengths, d_idx), None),
                y,
                d_permi[None],
                n_perm_inds,
                helper,
            )[0].block_until_ready()
            vn_perm.flush()
        # Re-open read-only for the consumers of the returned object.
        vn_perm = np.memmap(filename, dtype="float32", mode="r", shape=vn_shape)
    else:
        vn_perm = update_pn_loop_perm(
            rho, lengths, y, d_perm_inds, n_perm_inds, helper,
        )[0].block_until_ready()

    end = time.time()
    print("Fit time: {}s".format(round(end - start, 3)))
    wandb.log({"fit_time": round(end - start, 3)})

    copula_density_obj = namedtuple(
        "copula_density_obj",
        [
            "vn_perm",
            "rho_lengths_opt",
            "preq_loglik",
            "y_perm",
            "d_perm_inds",
            "n_perm_inds",
        ],
    )
    return copula_density_obj(
        vn_perm, rho_lengths_opt, -opt_fun, y, d_perm_inds, n_perm_inds
    )


@jit
def copula_test_nll(
    rho_lengths_opt,
    y_test,
    vn_perm,
    y_perm,
    d_perm_inds,
    n_perm_inds,
    helper=None,
    bern=False,
):
    """Return the negative mean joint log-density of ``y_test``.

    The backend module is picked from ``wandb.config.base`` (classification,
    regression, or the default). ``rho_lengths_opt`` is converted with the
    backend's ``get_rho_params`` before the test-point update
    ``update_ptest_loop_perm_av`` is evaluated. Its second output is averaged
    over the last axis and then over everything that remains, and the result
    is negated.

    For ``model["diff"] == "net"`` a default ``helper`` of
    ``(arange(d), tri(d))`` is built when none is supplied. ``bern`` is
    unused.
    """
    base_cfg = wandb.config.base
    if base_cfg["class"]:
        import models.copula_AR_functions.copula_classification_functions_old as cop_AR
    elif base_cfg["regress"]:
        import models.copula_AR_functions.copula_regression_functions_old as cop_AR
    else:
        import models.copula_AR_functions as cop_AR

    if helper is None:
        if wandb.config.model["diff"] == "net":
            n_dims = wandb.config.data["d"]
            helper = (jnp.arange(n_dims), jnp.tri(n_dims))

    rho, lengths = cop_AR.get_rho_params(rho_lengths_opt)
    outputs = cop_AR.update_ptest_loop_perm_av(
        vn_perm, rho, lengths, y_perm, y_test, d_perm_inds, n_perm_inds, 0, helper
    )
    _, joint_logpdf = outputs
    mean_over_last_axis = joint_logpdf.mean(-1)
    return -mean_over_last_axis.mean()


# Predict on test data using copula object
def predict_copula_density(copula_density_obj, y_test, bern=False):
    """Evaluate the fitted copula on ``y_test``, averaged over permutations.

    Args:
        copula_density_obj: Result of ``fit_copula_density`` (uses its
            ``vn_perm``, ``rho_lengths_opt``, ``y_perm``, ``d_perm_inds`` and
            ``n_perm_inds`` fields).
        y_test: Test points.
        bern: Unused; kept for interface compatibility.

    Returns:
        ``(logcdf_conditionals, logpdf_joints)`` as produced by the backend's
        ``update_ptest_loop_perm_av``. In the low-memory classification path
        ``logcdf_conditionals`` is the placeholder ``1``.
    """
    if wandb.config.base["class"]:
        import models.copula_AR_functions.copula_classification_functions_old as cop_AR
    elif wandb.config.base["regress"]:
        import models.copula_AR_functions.copula_regression_functions_old as cop_AR
    else:
        import models.copula_AR_functions as cop_AR
    y_test_perm = y_test

    if wandb.config.model["diff"] == "net":
        helper = (
            jnp.arange(wandb.config.data["d"]),
            jnp.tri(wandb.config.data["d"]),
        )
    else:
        helper = None
    rho, lengths = cop_AR.get_rho_params_wo_transform(
        copula_density_obj.rho_lengths_opt
    )
    if wandb.config.model["diff"] == "knn":
        # Neighbour indices for the training points and for the test points.
        tree = KDTree(copula_density_obj.y_perm)
        helper = tree.query(
            copula_density_obj.y_perm,
            k=wandb.config.model["n_knn"],
            return_distance=False,
            sort_results=False,
        )
        helper_test = tree.query(
            y_test,
            k=wandb.config.model["n_knn"],
            return_distance=False,
            sort_results=False,
        )
        logcdf_conditionals, logpdf_joints = cop_AR.update_ptest_loop_perm_av(
            copula_density_obj.vn_perm,
            rho,
            lengths,
            helper,
            y_test_perm,
            copula_density_obj.d_perm_inds,
            copula_density_obj.n_perm_inds,
            helper_test,
        )
    elif wandb.config.data["low_mem"] < 0:
        # Low-memory path: evaluate one feature permutation at a time and
        # combine the results with a running log-mean-exp.
        # NOTE(review): this assumes the backend's ``_av`` routine averages
        # over permutations in density space - confirm against the backend.
        is_class = wandb.config.base["class"]
        n_d_perm = len(copula_density_obj.d_perm_inds)

        # -inf is the identity element of logaddexp.
        logcdf_conditionals = (
            jnp.full(y_test_perm.shape, -jnp.inf) if not is_class else 1
        )
        logpdf_joints = jnp.full(
            (
                y_test_perm.shape[0],
                y_test_perm.shape[1] if not is_class else 1,
            ),
            -jnp.inf,
        )
        for d_idx, d_perm in tqdm(enumerate(copula_density_obj.d_perm_inds)):
            # Slice the hyperparameters to the current feature permutation,
            # exactly as the low-memory fitting loop does.
            logcdf_conditionals_, logpdf_joints_ = cop_AR.update_ptest_loop_perm_av(
                copula_density_obj.vn_perm[d_idx : d_idx + 1],
                rho[d_idx : d_idx + 1],
                cop_AR.slice_lengths(cop_AR.slice_lengths(lengths, d_idx), None),
                copula_density_obj.y_perm,
                y_test_perm,
                d_perm[None],
                copula_density_obj.n_perm_inds,
                0,
                helper,
            )
            if not is_class:
                logcdf_conditionals = jnp.logaddexp(
                    logcdf_conditionals, logcdf_conditionals_
                )
            logpdf_joints = jnp.logaddexp(logpdf_joints, logpdf_joints_)
        # log(mean) = logsumexp - log(count)
        log_count = np.log(n_d_perm)
        if not is_class:
            logcdf_conditionals = logcdf_conditionals - log_count
        logpdf_joints = logpdf_joints - log_count
    else:
        logcdf_conditionals, logpdf_joints = cop_AR.update_ptest_loop_perm_av(
            copula_density_obj.vn_perm,
            rho,
            lengths,
            copula_density_obj.y_perm,
            y_test_perm,
            copula_density_obj.d_perm_inds,
            copula_density_obj.n_perm_inds,
            0,
            helper,
        )
    logpdf_joints = logpdf_joints.block_until_ready()  # for accurate timing

    return logcdf_conditionals, logpdf_joints


# Predict on test data using copula object
def predict_copula_density_per_perm(copula_density_obj, y_test, bern=False):
    """Evaluate the fitted copula on ``y_test`` per feature permutation.

    Outside the low-memory path the backend's
    ``update_ptest_loop_perm_per_perm`` is called on all permutations at
    once. In the low-memory path (``data["low_mem"] < 0``) the permutations
    are evaluated one by one and only the outputs of the permutation with
    the highest mean of ``logpdf_joints[:, -1]`` are returned.

    Args:
        copula_density_obj: Result of ``fit_copula_density``.
        y_test: Test points.
        bern: Unused; kept for interface compatibility.

    Returns:
        ``(logcdf_conditionals, logpdf_joints)``.
    """
    if wandb.config.base["class"]:
        import models.copula_AR_functions.copula_classification_functions_old as cop_AR
    elif wandb.config.base["regress"]:
        import models.copula_AR_functions.copula_regression_functions_old as cop_AR
    else:
        import models.copula_AR_functions as cop_AR
    y_test_perm = y_test

    if wandb.config.model["diff"] == "net":
        helper = (
            jnp.arange(wandb.config.data["d"]),
            jnp.tri(wandb.config.data["d"]),
        )
    else:
        helper = None
    rho, lengths = cop_AR.get_rho_params_wo_transform(
        copula_density_obj.rho_lengths_opt
    )
    if wandb.config.model["diff"] == "knn":
        # Neighbour indices for the training points and for the test points.
        tree = KDTree(copula_density_obj.y_perm)
        helper = tree.query(
            copula_density_obj.y_perm,
            k=wandb.config.model["n_knn"],
            return_distance=False,
            sort_results=False,
        )
        helper_test = tree.query(
            y_test,
            k=wandb.config.model["n_knn"],
            return_distance=False,
            sort_results=False,
        )
        logcdf_conditionals, logpdf_joints = cop_AR.update_ptest_loop_perm_per_perm(
            copula_density_obj.vn_perm,
            rho,
            lengths,
            helper,
            y_test_perm,
            copula_density_obj.d_perm_inds,
            copula_density_obj.n_perm_inds,
            helper_test,
        )
    elif wandb.config.data["low_mem"] < 0:
        # Placeholders returned only if no permutation beats -inf (e.g. all
        # scores are NaN or there are no permutations).
        logcdf_conditionals, logpdf_joints = (
            jnp.zeros_like(y_test_perm),
            jnp.zeros_like(y_test_perm),
        )
        best_ll = float("-inf")
        for d_idx, d_perm in tqdm(enumerate(copula_density_obj.d_perm_inds)):
            # Keep the leading permutation axis on vn_perm (it has to line up
            # with ``d_perm[None]``) and slice the hyperparameters to the
            # current permutation, as the low-memory fitting loop does.
            logcdf_conditionals_, logpdf_joints_ = cop_AR.update_ptest_loop_perm_av(
                copula_density_obj.vn_perm[d_idx : d_idx + 1],
                rho[d_idx : d_idx + 1],
                cop_AR.slice_lengths(cop_AR.slice_lengths(lengths, d_idx), None),
                copula_density_obj.y_perm,
                y_test_perm,
                d_perm[None],
                copula_density_obj.n_perm_inds,
                0,
                helper,
            )
            # Score a permutation by the mean of its last joint column.
            ll = logpdf_joints_[:, -1].mean()
            if ll > best_ll:
                logcdf_conditionals = logcdf_conditionals_
                logpdf_joints = logpdf_joints_
                best_ll = ll
    else:
        logcdf_conditionals, logpdf_joints = cop_AR.update_ptest_loop_perm_per_perm(
            copula_density_obj.vn_perm,
            rho,
            lengths,
            copula_density_obj.y_perm,
            y_test_perm,
            copula_density_obj.d_perm_inds,
            copula_density_obj.n_perm_inds,
            0,
            helper,
        )
    logcdf_conditionals = logcdf_conditionals.block_until_ready()  # for accurate timing

    return logcdf_conditionals, logpdf_joints


# Predict on test data using copula object
def smc_sample_from_copula(
    copula_density_obj,
    init_samples,
    bern=False,
    helper=None,
    best_d=None,
    perm_while_sampling=True,
):
    """Draw samples from a fitted copula object.

    The backend module ``cop_AR`` is chosen from ``wandb.config.base``
    (classification, regression, or the default package). Exactly one
    sampler call is made, selected by ``best_d`` and
    ``perm_while_sampling``.

    Args:
        copula_density_obj: Fitted object; its ``rho_lengths_opt``,
            ``vn_perm``, ``y_perm``, ``d_perm_inds`` and ``n_perm_inds``
            attributes are read.
        init_samples: Forwarded unchanged to the sampler.
        bern: Not used by this function.
        helper: Not used -- the value is always rebuilt from the wandb
            config. NOTE(review): confirm that discarding the caller's
            value is intended.
        best_d: ``None`` to call ``sample_p_loop_perm`` with the full
            ``vn_perm``/``rho``/``lengths``/``d_perm_inds``. Otherwise an
            index that is only consulted when ``perm_while_sampling`` is
            false, in which case those inputs are restricted to the slice
            ``best_d : best_d + 1``.
        perm_while_sampling: When ``best_d`` is given and this is true,
            ``sample_p_loop_perm_loop_ytest`` is called with the full
            inputs instead.

    Returns:
        Tuple ``(y_test, logpdf_joints, n_resampl, ess)`` as produced by
        the sampler, with ``block_until_ready()`` applied to ``y_test``.
    """
    base_cfg = wandb.config.base
    if base_cfg["class"]:
        import models.copula_AR_functions.copula_classification_functions_old as cop_AR
    elif base_cfg["regress"]:
        import models.copula_AR_functions.copula_regression_functions_old as cop_AR
    else:
        import models.copula_AR_functions as cop_AR

    rho, lengths = cop_AR.get_rho_params_wo_transform(
        copula_density_obj.rho_lengths_opt
    )

    # The incoming ``helper`` is discarded here and rebuilt from the config.
    if wandb.config.model["diff"] != "net":
        helper = None
    else:
        n_dim = wandb.config.data["d"]
        helper = (jnp.arange(n_dim), jnp.tri(n_dim))

    if best_d is not None and not perm_while_sampling:
        # Restrict every per-permutation input to the single entry `best_d`,
        # keeping the leading axis (length-1 slice).
        pick = slice(best_d, best_d + 1)
        sampler = cop_AR.sample_p_loop_perm
        vn_used = copula_density_obj.vn_perm[pick]
        rho_used = rho[pick]
        lengths_used = cop_AR.slice_lengths(
            cop_AR.slice_lengths(lengths, best_d), None
        )
        d_inds_used = copula_density_obj.d_perm_inds[pick]
    else:
        # Full inputs; only the sampler differs between the two cases.
        sampler = (
            cop_AR.sample_p_loop_perm
            if best_d is None
            else cop_AR.sample_p_loop_perm_loop_ytest
        )
        vn_used = copula_density_obj.vn_perm
        rho_used = rho
        lengths_used = lengths
        d_inds_used = copula_density_obj.d_perm_inds

    y_test, logpdf_joints, n_resampl, ess = sampler(
        vn_used,
        rho_used,
        lengths_used,
        copula_density_obj.y_perm,
        init_samples,
        d_inds_used,
        copula_density_obj.n_perm_inds,
        base_cfg["seed"],
        helper,
    )
    y_test = y_test.block_until_ready()  # for accurate timing

    return y_test, logpdf_joints, n_resampl, ess


# Impute by sampling from the fitted copula object
def impute_from_copula(copula_density_obj, init_samples, bern=False):
    """Sample from a fitted copula object and log the elapsed time.

    The backend module ``cop_AR`` is chosen from ``wandb.config.base``
    (classification, regression, or the default package). The sampler
    ``cop_AR.sample_p_loop_perm`` is called once with the full set of
    permutations stored on ``copula_density_obj``.

    Args:
        copula_density_obj: Fitted object; its ``rho_lengths_opt``,
            ``vn_perm``, ``y_perm``, ``d_perm_inds`` and ``n_perm_inds``
            attributes are read.
        init_samples: Forwarded unchanged to the sampler.
        bern: Not used by this function.

    Returns:
        Tuple ``(y_test, logpdf_joints, n_resampl)`` -- the first three
        values produced by the sampler, with ``block_until_ready()``
        applied to ``y_test``.

    Side effects:
        Prints progress/timing and logs ``"Sampling time"`` to wandb.
    """
    if wandb.config.base["class"]:
        import models.copula_AR_functions.copula_classification_functions_old as cop_AR
    elif wandb.config.base["regress"]:
        import models.copula_AR_functions.copula_regression_functions_old as cop_AR
    else:
        import models.copula_AR_functions as cop_AR

    print("Sampling...")
    start = time.time()
    rho, lengths = cop_AR.get_rho_params_wo_transform(
        copula_density_obj.rho_lengths_opt
    )
    if wandb.config.model["diff"] == "net":
        helper = (
            jnp.arange(wandb.config.data["d"]),
            jnp.tri(wandb.config.data["d"]),
        )
    else:
        helper = None

    # Other callers in this file unpack FOUR values (..., n_resampl, ess)
    # from the same sampler, so a fixed 3-name unpack would raise
    # ValueError. The starred target accepts three or more values and drops
    # the extras, keeping this function's 3-tuple return contract.
    y_test, logpdf_joints, n_resampl, *_ = cop_AR.sample_p_loop_perm(
        copula_density_obj.vn_perm,
        rho,
        lengths,
        copula_density_obj.y_perm,
        init_samples,
        copula_density_obj.d_perm_inds,
        copula_density_obj.n_perm_inds,
        wandb.config.base["seed"],
        helper,
    )
    y_test = y_test.block_until_ready()  # for accurate timing
    elapsed = round(time.time() - start, 3)
    print("Sampling time: {}s".format(elapsed))
    wandb.log({"Sampling time": elapsed})
    return y_test, logpdf_joints, n_resampl


# Sample from predictive density p_n
def sample_copula_density(
    copula_density_obj,
    B_samples,
    seed=wandb.config.base["seed"],
    bern=False,
    best_d=None,
):
    """Draw ``B_samples`` samples by calling the backend quantile routine.

    For each sample a vector of uniform random numbers is drawn and passed
    to ``cop_AR.compute_quantile_pn_av``, one call per sample. The backend
    module ``cop_AR`` is chosen from ``wandb.config.base``.

    Args:
        copula_density_obj: Fitted object; its ``rho_lengths_opt``,
            ``vn_perm``, ``y_perm``, ``d_perm_inds`` and ``n_perm_inds``
            attributes are read. The last axis of ``vn_perm`` sets the
            sample dimension ``d``.
        B_samples: Number of samples to draw. ``0`` yields empty outputs.
        seed: Seed passed to ``np.random.seed``. The default is read from
            the wandb config once, when this function is defined, not on
            each call.
        bern: Not used by this function.
        best_d: Not used by this function.

    Returns:
        Tuple ``(y_samp, err, n_iter)`` of NumPy arrays with shapes
        ``(B_samples, d)``, ``(B_samples,)`` and ``(B_samples,)``, filled
        from the three values returned by each quantile call.

    Side effects:
        Reseeds NumPy's global RNG and prints progress/timing.
    """
    if wandb.config.base["class"]:
        import models.copula_AR_functions.copula_classification_functions_old as cop_AR
    elif wandb.config.base["regress"]:
        import models.copula_AR_functions.copula_regression_functions_old as cop_AR
    else:
        import models.copula_AR_functions as cop_AR
    if wandb.config.model["diff"] == "net":
        helper = (
            jnp.arange(wandb.config.data["d"]),
            jnp.tri(wandb.config.data["d"]),
        )
    else:
        helper = None
    d = np.shape(copula_density_obj.vn_perm)[-1]

    rho, lengths = cop_AR.get_rho_params_wo_transform(
        copula_density_obj.rho_lengths_opt
    )

    # Initialize
    y_samp = np.zeros((B_samples, d))
    err = np.zeros(B_samples)
    n_iter = np.zeros(B_samples)

    # Simulate uniform random variables
    np.random.seed(seed)
    un = np.random.rand(B_samples, d)

    # Sampling
    print("Sampling...")
    start = time.time()
    for i in tqdm(range(B_samples)):
        y_samp[i], err[i], n_iter[i] = cop_AR.compute_quantile_pn_av(
            copula_density_obj.vn_perm,
            rho,
            lengths,
            copula_density_obj.y_perm,
            un[i],
            copula_density_obj.d_perm_inds,
            copula_density_obj.n_perm_inds,
            0,
            helper,
        )
    end = time.time()
    print("Sampling time: {}s".format(round(end - start, 3)))
    # np.max raises ValueError on an empty array, so only report the error
    # summary when at least one sample was drawn.
    if err.size:
        print(f"Max abs error in cdf: {np.sqrt(np.max(err)):.2e}")
    return y_samp, err, n_iter

