import numpy as np
import scipy as sp
from functools import partial

# import jax functions
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap, jacfwd, jacrev, random
from jax.scipy.stats import norm
from jax.lax import fori_loop

# from jax.ops import index_update

# import package functions
# from . import copula_density_functions as mvcd
from utils.BFGS import minimize_BFGS
from utils.bivariate_copula import ndtri_

from time import time
from tqdm import tqdm

#  https://github.com/edfong/MP/blob/6348a25e16ca38b4f84bbdc7e4c94f45ae485c3b/pr_copula/sample_copula_density_functions.py

from .copula_ar_test import update_ptest_loop_perm_av


@jit
def calc_pn_av_err2(
    y0, vn_perm, rho, lengths, y_perm, quantile, d_perm_inds, n_perm_inds, index, helper
):
    """Return the squared error between the predictive CDF at ``y0`` and ``quantile``.

    The conditional log-CDFs at ``y0`` are obtained from
    ``update_ptest_loop_perm_av``, exponentiated, and compared elementwise
    with the target ``quantile``; the squared differences are summed into a
    single scalar. This is the objective minimised when inverting the CDF.

    Args:
        y0: Candidate point; reshaped to ``(1, d)`` where ``d`` is the size
            of the last axis of ``vn_perm``.
        vn_perm, rho, lengths, y_perm, d_perm_inds, n_perm_inds, index, helper:
            Forwarded unchanged to ``update_ptest_loop_perm_av``.
            NOTE(review): their shapes and meaning are defined by that
            helper, which is not visible here -- confirm against it.
        quantile: Target CDF values; reshaped to ``(1, d)``.

    Returns:
        Scalar sum of squared differences between ``exp(logcdf)`` and
        ``quantile``.
    """
    d = jnp.shape(vn_perm)[-1]

    # Both the target and the evaluation point are treated as a single row.
    quantile = quantile.reshape(1, d)
    y_test = y0.reshape(1, d)

    # The second output (named as a joint log-pdf in the original code) is
    # not needed for the CDF-matching objective.
    logcdf_conditionals_ytest, _ = update_ptest_loop_perm_av(
        vn_perm, rho, lengths, y_perm, y_test, d_perm_inds, n_perm_inds, index, helper,
    )

    residual = jnp.exp(logcdf_conditionals_ytest) - quantile
    return jnp.sum(residual ** 2)


# JIT-compiled gradient of the squared quantile error with respect to its
# first argument (y0); all other arguments are treated as constants.
grad_pn_av_err2 = jit(grad(calc_pn_av_err2, argnums=0))

# Find quantile P_n^{-1}(u), which can be used for sampling
@jit
def compute_quantile_pn_av(
    vn_perm, rho, lengths, y_perm, quantile, d_perm_inds, n_perm_inds, index, helper,
):
    """Numerically invert the predictive CDF at ``quantile`` using BFGS.

    Minimises ``calc_pn_av_err2`` over ``y0`` so that the predictive CDF
    evaluated at the returned point matches ``quantile`` as closely as
    possible.

    Args:
        vn_perm, rho, lengths, y_perm, d_perm_inds, n_perm_inds, index, helper:
            Forwarded unchanged to ``calc_pn_av_err2``.
        quantile: Target CDF values (the ``u`` in ``P_n^{-1}(u)``).

    Returns:
        Tuple ``(y_samp, err2, n_iter)`` taken from the first three outputs
        of ``minimize_BFGS``. NOTE(review): presumably the minimiser, the
        objective value there, and the iteration count -- confirm against
        ``utils.BFGS.minimize_BFGS``.
    """
    # Initial guess from the target probabilities.
    # NOTE(review): ndtri_ is presumably the inverse standard-normal CDF --
    # confirm against utils.bivariate_copula.
    y0_init = ndtri_(quantile)

    def objective(y0):
        # Objective value only; every argument except y0 is closed over.
        return calc_pn_av_err2(
            y0,
            vn_perm,
            rho,
            lengths,
            y_perm,
            quantile,
            d_perm_inds,
            n_perm_inds,
            index,
            helper,
        )

    # delta_B_init = 0.5 was noted by the original author to work well.
    y_samp, err2, n_iter, _ = minimize_BFGS(objective, y0_init, delta_B_init=0.5)

    return y_samp, err2, n_iter

