from scipy.optimize import minimize
from collections import namedtuple
import time
import numpy as np

# import jax
import jax.numpy as jnp
from jax import vmap
from jax.random import permutation, PRNGKey, split

import wandb

# import package functions

# Select the copula backend bound to `mvcc` from the active wandb run config.
# This runs at import time, so `wandb.config` must already expose `base`.
# If neither flag is truthy, `mvcc` is never bound and the functions below
# fail with NameError on first use.
if wandb.config.base["class"]:
    import models.old_copula.ori_copula_ar_class as mvcc

    # Alternative classification backend (disabled):
    # import models.copula_AR_functions.copula_classification_functions_old as mvcc
elif wandb.config.base["regress"]:
    # Alternative regression backend (disabled):
    # import models.old_copula.ori_regress as mvcc

    import models.copula_AR_functions.copula_regression_functions_old as mvcc

import wandb

# from . import sample_copula_classification_functions as samp_mvcc

### Fitting ###
# Result container for a fit. Defined once at module level so that every call
# returns the same type; the typename is kept so the repr is unchanged.
copula_classification_obj = namedtuple(
    "copula_classification_obj",
    [
        "log_vn_perm",
        "logpmf_yn_perm",
        "rho_opt",
        "rho_x_opt",
        "preq_loglik",
        "y_perm",
        "x_perm",
        "d_perm_inds",
    ],
)


# Compute overhead v_{1:n}, return fit copula object for prediction
def fit_copula_classification(
    y, x, n_perm=10, seed=20, n_perm_optim=None, single_x_bandwidth=True
):
    """Fit the copula bandwidths and compute the per-permutation overhead.

    The data are shuffled along the first axis ``n_perm`` times and the
    feature columns are shuffled once. Bandwidths ``rho`` are optimised on
    the logit scale (``h = log(1 / rho - 1)``, ``rho = 1 / (1 + exp(h))``)
    with SLSQP, then ``mvcc.update_pn_loop_perm`` is run on all permutations.

    Reads from ``wandb.config.model``: ``d_perm`` (must equal 1),
    ``init_rho`` (initial bandwidth) and ``maxiter`` (a falsy value skips
    the optimisation and keeps the initial bandwidths).

    Args:
        y: Targets; ``len(y)`` gives the number of observations.
        x: Features; ``x.shape[1]`` gives the number of feature columns.
        n_perm: Number of random orderings of the observations.
        seed: Seed for both ``np.random`` and the JAX PRNG key.
        n_perm_optim: If given, only the first ``n_perm_optim`` permutations
            are used to optimise the bandwidths; all are used for the fit.
        single_x_bandwidth: If truthy, one bandwidth is shared by all
            features (two parameters in total); otherwise one bandwidth per
            feature plus one (``d + 1`` parameters).

    Returns:
        A ``copula_classification_obj`` namedtuple with fields
        ``log_vn_perm``, ``logpmf_yn_perm``, ``rho_opt``, ``rho_x_opt``,
        ``preq_loglik``, ``y_perm``, ``x_perm`` and ``d_perm_inds``.
        ``preq_loglik`` is the negated optimiser objective, or ``-999`` when
        the optimisation was skipped.
    """
    # Set seed for scipy
    np.random.seed(seed)

    model_cfg = wandb.config.model

    # One PRNG subkey per permutation of the observations.
    keys = split(PRNGKey(seed), n_perm + 1)
    key, row_keys = keys[0], keys[1:]
    n_perm_inds = vmap(permutation, (0, None))(row_keys, jnp.arange(len(y)))

    # Only the first feature permutation is applied below.
    assert model_cfg["d_perm"] == 1, "only d_perm == 1 is supported"
    col_keys = split(key, model_cfg["d_perm"] + 1)[1:]
    d_perm_inds = vmap(permutation, (0, None))(col_keys, jnp.arange(x.shape[1]))

    y_perm = jnp.take(y, n_perm_inds, axis=0)[..., None]
    x_perm = jnp.take(x, n_perm_inds, axis=0)[..., d_perm_inds[0]]

    # Initialize parameter and put on correct scale to lie in [0,1]
    d = jnp.shape(x)[1]
    n_rho = 2 if single_x_bandwidth else d + 1
    rho_init = model_cfg["init_rho"] * jnp.ones(n_rho)
    hyperparam_init = jnp.log(1 / rho_init - 1)

    # Either use all permutations or a selected number to fit the bandwidth.
    if n_perm_optim is None:
        y_perm_opt, x_perm_opt = y_perm, x_perm
    else:
        y_perm_opt = y_perm[:n_perm_optim]
        x_perm_opt = x_perm[:n_perm_optim]

    max_iter = model_cfg["maxiter"]
    if max_iter:
        print("Optimizing...")
        start = time.perf_counter()
        # Condit preq loglik
        opt = minimize(
            fun=mvcc.fun_jll_perm_sp,
            # scipy expects a 1-D start point; a 2-D (1, k) array is rejected
            # by recent releases and was merely flattened by older ones.
            x0=np.asarray(hyperparam_init, dtype=np.float64),
            args=(y_perm_opt, x_perm_opt),
            jac=mvcc.grad_jll_perm_sp,
            method="SLSQP",
            options={"maxiter": max_iter, "ftol": 1e-4},
        )

        # check optimization succeeded
        if not opt.success:
            print("Optimization failed")

        hyperparam_opt = opt.x
        opt_fun = opt.fun
        elapsed = time.perf_counter() - start
        print("Optimization time: {}s".format(round(elapsed, 3)))
    else:
        hyperparam_opt = hyperparam_init
        opt_fun = 999  # placeholder objective when no optimisation is run

    # unscale hyperparameters
    rho_opt = 1 / (1 + jnp.exp(hyperparam_opt[0]))
    rho_opt_x = 1 / (1 + jnp.exp(hyperparam_opt[1:]))

    print("Fitting...")
    start = time.perf_counter()
    log_vn, logpmf_yn_perm, *_ = mvcc.update_pn_loop_perm(
        rho_opt, rho_opt_x, y_perm, x_perm
    )
    log_vn = log_vn.block_until_ready()  # for accurate timing
    elapsed = time.perf_counter() - start
    print("Fit time: {}s".format(round(elapsed, 3)))

    return copula_classification_obj(
        log_vn,
        logpmf_yn_perm,
        rho_opt,
        rho_opt_x,
        -opt_fun,
        y_perm,
        x_perm,
        d_perm_inds,
    )


# Returns the log-pmf computed by the backend for the test points
def predict_copula_density(copula_classification_obj, x_test):
    """Evaluate a fitted copula object on test points.

    Calls ``mvcc.update_ptest_loop_perm_av`` with the fitted overhead and
    bandwidths, and with data arguments that depend on
    ``wandb.config.base["class"]``.

    Args:
        copula_classification_obj: Result of ``fit_copula_classification``.
        x_test: Test array. In classification mode its last axis is
            reordered with the feature permutation stored in the fit. In
            regression mode it is split into ``x_test[:, :-1]`` and
            ``x_test[..., -1:]``.

    Returns:
        The value returned by ``mvcc.update_ptest_loop_perm_av`` after
        ``block_until_ready()``.
    """
    fit = copula_classification_obj

    print("Predicting...")
    start = time.time()

    if wandb.config.base["class"]:
        # Apply the same feature permutation that was used during fitting.
        data_args = (
            fit.y_perm,
            fit.x_perm,
            x_test[..., fit.d_perm_inds[0]],
        )
    else:
        # NOTE(review): the regression path passes x_perm where the
        # classification path passes y_perm, and uses the last column of
        # x_test separately -- confirm against the backend signature.
        data_args = (
            fit.x_perm,
            x_test[..., -1:],
            x_test[:, :-1],
        )

    logpmf = mvcc.update_ptest_loop_perm_av(
        fit.log_vn_perm,
        fit.rho_opt,
        fit.rho_x_opt,
        *data_args,
    )
    logpmf = logpmf.block_until_ready()  # for accurate timing
    end = time.time()
    print("Prediction time: {}s".format(round(end - start, 3)))
    return logpmf

