import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap, jacfwd, jacrev, random, remat, value_and_grad
from jax.scipy.special import ndtri, erfc, logsumexp
from jax.scipy.stats import norm
from jax import random
from jax.lax import fori_loop, scan
import jax
import optax
from jax.random import PRNGKey, permutation, split


# from jax.ops import index, index_add, index_update
import scipy as osp
from functools import partial
from scipy.optimize import minimize, root
import wandb
from tqdm import tqdm_notebook

# from . import copula_density_functions as mvcd

from . import (
    compute_rho_d,
    compute_rho_d_single,
    get_rho_params,
    get_rho_params_wo_transform,
    slice_lengths,
)

# from .copula_ar_optim import get_rho_params, get_rho_params_wo_transform

## Conditional method classification ##

### Utility functions ###


# Initialize marginal p_0
def init_marginals_single(y):
    """Return the initial log predictive log p_0(y = 1) for one datum.

    The prior is a fixed Bernoulli(0.5) and does not depend on the values (or
    shape) of ``y``; the argument exists so the function can be vmapped over
    the data axis.

    Args:
        y: a single datum; not read.

    Returns:
        Array of shape (1,) holding log p_0(y = 1), clipped to
        [log(eps), log(1 - eps)] with eps = 1e-6.
    """
    p_init = 0.5  # discrete prior p_0(y = 1)
    eps = 1e-6
    logpmf_init_marginals = jnp.array([jnp.log(p_init)])
    # Clip away from 0 and 1 (a no-op for the 0.5 prior, kept as a safeguard
    # so downstream log1p(-p) terms stay finite if the prior is changed).
    return jnp.clip(logpmf_init_marginals, jnp.log(eps), jnp.log(1 - eps))


# Batched over the leading (data) axis: returns shape (n, 1).
init_marginals = jit(vmap(init_marginals_single, in_axes=0))

# Bernoulli "copula" update. `logpmf1` is log p(y = 1 | x), so the update acts
# directly on the probability of the positive class.
def update_copula_single(logpmf1, log_v, y_new, logalpha, rho):  #!
    """Apply one Bernoulli copula update to log p(y = 1 | x).

    Write u = exp(logpmf1), the current predictive at the evaluation point, and
    v = exp(log_v), the predictive at the newly observed point. The kernel
    between the two binary outcomes is

        k = 1 - rho + rho * frac
        frac = min(u, v) / (u * v)                if y_new == 1
        frac = (u - min(u, v)) / (u * (1 - v))    if y_new == 0

    and the updated predictive is u * ((1 - alpha) + alpha * k).

    Args:
        logpmf1: log u; clipped to [log eps, log(1 - eps)].
        log_v: log v; clipped likewise. Broadcasts against ``logpmf1``.
        y_new: observed label, used as a 0/1 weight.
        logalpha: log of the update weight alpha.
        rho: mixing weight of the kernel.

    Returns:
        log of the updated p(y = 1 | x), shaped like the broadcast of
        ``logpmf1`` and ``log_v``.
    """
    eps = 5e-5
    log_lo, log_hi = jnp.log(eps), jnp.log(1 - eps)
    # Clip u and v away from 0 and 1 before forming the kernel.
    logpmf1 = jnp.clip(logpmf1, log_lo, log_hi)
    log_v = jnp.clip(log_v, log_lo, log_hi)

    # log(1 - alpha) and log(1 - v); the clip keeps log1p's argument > -1.
    log1alpha = jnp.log1p(jnp.clip(-jnp.exp(logalpha), -1 + eps, jnp.inf))
    log1_v = jnp.log1p(jnp.clip(-jnp.exp(log_v), -1 + eps, jnp.inf))

    # Elementwise log min(u, v) (a global reduction would mix entries).
    log_min_uv = jnp.minimum(logpmf1, log_v)

    frac_y1 = jnp.exp(log_min_uv - logpmf1 - log_v)
    # (1 - min(u, v) / u) / (1 - v); expm1 avoids cancellation when min(u, v) ~ u.
    frac_y0 = -jnp.expm1(log_min_uv - logpmf1) * jnp.exp(-log1_v)
    frac = y_new * frac_y1 + (1 - y_new) * frac_y0

    kyy = jnp.clip(1 - rho + rho * frac, eps, jnp.inf)
    return jnp.logaddexp(log1alpha, logalpha + jnp.log(kyy)) + logpmf1


# Vmapped over evaluation points: logpmf1, logalpha and rho vary per point
# (they depend on x), while log_v and y_new describe the single new datum.
update_copula = jit(vmap(update_copula_single, (0, None, None, 0, 0)))

# Log Gaussian-copula kernel log k_xx between one datum x and a new point
# x_new, summed over feature dimensions.
@jit
def calc_logkxx_single(x, x_new, rho_x):  #!
    """Sum over dimensions of the bivariate Gaussian copula log-density.

    Per dimension the term is
    -0.5 * log(1 - rho^2) - (rho^2 (x^2 + x_new^2) - 2 rho x x_new) / (2 (1 - rho^2)).

    Args:
        x: features of one datum.
        x_new: features of the new point (broadcast against ``x``).
        rho_x: per-dimension correlations.

    Returns:
        Scalar log k_xx.
    """
    rho_sq = rho_x ** 2
    one_minus_rho_sq = 1 - rho_sq
    log_norm = -0.5 * jnp.sum(jnp.log(one_minus_rho_sq))
    quad_form = rho_sq * (x ** 2 + x_new ** 2) - 2 * rho_x * x * x_new
    return log_norm - jnp.sum((0.5 / one_minus_rho_sq) * quad_form)


# Batched over data points (x and rho_x on axis 0, shared x_new): shape (n,).
calc_logkxx = jit(vmap(calc_logkxx_single, in_axes=(0, None, 0)))  #!
# Additionally batched over new points x_new (axis 0): shape (n_new, n).
calc_logkxx_test = jit(vmap(calc_logkxx, in_axes=(None, 0, None)))  #!

### ###

### Functions to calculate overhead v_{1:n} ###

# Compute v_i for a single datum.
# One lax.scan step: record the prequential log-likelihood of datum i, store
# v_i = p(y=1 | x_i) under the current predictive, then absorb datum i into the
# predictive p(y=1 | x_j) held for every training row j.
@jit
def update_pn(carry, i):  #!
    """Scan body that absorbs observation ``i``.

    Args:
        carry: ``(log_vn, logpmf1_yn, preq_loglik, y, x, rho, lengths)``.
            ``log_vn``, ``logpmf1_yn`` and ``preq_loglik`` are per-row buffers
            (``update_pn_loop`` builds them with n rows and one column).
            ``logpmf1_yn`` holds log p(y=1 | x_j) after the data absorbed so far.
            ``y``/``x`` are the (already permuted) labels/features.
            ``rho``/``lengths`` are forwarded to ``compute_rho_d``;
            ``update_pn_loop`` supplies ``rho_x`` in the ``lengths`` slot.
        i: index of the datum to absorb.

    Returns:
        ``(carry, i)``, the updated carry and the index, as ``lax.scan``
        expects.
    """

    log_vn, logpmf1_yn, preq_loglik, y, x, rho, lengths = carry

    # Fetch the new datum (x_i, y_i)
    y_new = y[i]
    x_new = x[i]
    # Step size alpha_i = (2 - 1/(i+1)) / (i+2), in log space.
    logalpha = jnp.log(2.0 - (1 / (i + 1))) - jnp.log(i + 2)
    # compute_rho_d (defined elsewhere) receives every (x_j, y_j) and the new
    # (x_i, y_i). NOTE(review): assumes column 0 of the result is the label
    # correlation and columns 1: are per-feature correlations; confirm
    # against compute_rho_d.
    rhos = compute_rho_d(
        jnp.concatenate((x, y), -1),
        jnp.concatenate((x_new, y_new)),
        rho,
        lengths,
        None,
    )
    rho_d, rho_x = rhos[..., 0], rhos[..., 1:]

    # compute x rhos/alphas: per-row weights alpha_x from the feature kernel
    eps = 5e-5
    logk_xx = calc_logkxx(x, x_new, rho_x)  ##!!!!!!!
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(jnp.clip(-jnp.exp(logalpha), -1 + eps, jnp.inf))
    logalpha_x = (logalphak_xx) - (
        jnp.logaddexp(log1alpha, logalphak_xx)
    )  # alpha*k_xx /(1-alpha + alpha*k_xx)

    # clip for numerical stability to prevent NaNs
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    # Prequential term for datum i, using the predictive before absorbing it:
    # y*log p1 + (1-y)*log(1-p1), with 1-p1 guarded against log(0).
    temp = y_new * logpmf1_yn[i, -1] + (1 - y_new) * jnp.log1p(
        jnp.clip(-jnp.exp(logpmf1_yn[i, -1]), -1 + eps, jnp.inf)
    )
    preq_loglik = preq_loglik.at[i].set(temp)

    # v_i: predictive at x_i before the update; update_ptest_single later
    # replays these stored values for test points.
    log_v = logpmf1_yn[i]
    log_vn = log_vn.at[i].set(log_v)

    # Absorb (x_i, y_i) into p(y=1 | x_j) for all rows j.
    logpmf1_yn = update_copula(logpmf1_yn, log_v, y_new, logalpha_x, rho_d)  ##!!!!!!!
    carry = log_vn, logpmf1_yn, preq_loglik, y, x, rho, lengths
    return carry, i


# Scan update_pn over the index sequence `rng` (update_pn_loop passes 0..n-1).
@jit
def update_pn_scan(carry, rng):
    """Run ``update_pn`` once per index in ``rng``; returns ``(carry, indices)``."""
    return scan(f=update_pn, init=carry, xs=rng)


# Compute v_{1:n} and the prequential log-likelihood for one pair of
# (dimension, data) permutations.
@jit
def update_pn_loop(rho, rho_x, y_cat, d_perm_inds, n_perm_inds, helper=None):  #!
    """Run the sequential copula updates over one ordering of the data.

    Args:
        rho, rho_x: copula hyperparameters forwarded to the scan body.
        y_cat: data with features in all but the last column and the label last.
        d_perm_inds: column permutation. All but its last entry reorder the
            feature columns (presumably the last entry refers to the label
            column; confirm).
        n_perm_inds: row ordering in which the data are absorbed.
        helper: unused.

    Returns:
        ``(log_vn, logpmf1_yn, preq_loglik)``, each with one row per entry of
        ``n_perm_inds`` and one column. ``preq_loglik`` holds
        log p(y_{i} | x_{i}) under the predictive before datum i.
    """
    n_obs = len(n_perm_inds)

    features = y_cat[..., :-1]
    labels = y_cat[..., -1:]
    # Rows first (data permutation), then feature columns (dimension permutation).
    features = jnp.take(features, n_perm_inds, axis=0)
    features = jnp.take(features, d_perm_inds[..., :-1], axis=-1)
    labels = jnp.take(labels, n_perm_inds, axis=0)

    init_carry = (
        jnp.zeros((n_obs, 1)),  # log_vn
        init_marginals(labels),  # logpmf1_yn
        jnp.zeros((n_obs, 1)),  # preq_loglik
        labels,
        features,
        rho,
        rho_x,
    )
    final_carry, _ = update_pn_scan(init_carry, jnp.arange(n_obs))
    return final_carry[0], final_carry[1], final_carry[2]


# Inner batch: over data permutations (n_perm_inds on axis 0).
update_pn_loop_perm_ = jit(
    vmap(update_pn_loop, in_axes=(None, None, None, None, 0, None))
)
# Outer batch: over dimension permutations (rho, rho_x, d_perm_inds on axis 0).
update_pn_loop_perm = jit(
    vmap(update_pn_loop_perm_, in_axes=(0, 0, None, 0, None, None))
)


### Functions for Optimizing prequential log likelihood ###

# Compute permutation-averaged conditional prequential log-likelihood.
@jit
def negpreq_jointloglik_perm(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper=None):
    """Negative mean prequential log-likelihood, averaged over permutations.

    ``wandb.config`` is read in Python, i.e. while tracing. Under ``jit`` the
    branches below are therefore fixed by the config seen when a given input
    signature is first traced.

    Args:
        hyperparam: hyperparameters consumed by ``get_rho_params``. For the
            ``diff`` modes listed below it must be an array, which is reshaped
            to one row per dimension permutation.
        y_perm: data with features in all but the last column, label last.
        d_perm_inds: dimension permutations (leading axis = permutation).
        n_perm_inds: data permutations (leading axis = permutation).
        helper: forwarded to ``update_pn_loop_perm``.

    Returns:
        ``-mean_i log p_{i-1}(y_i | x_i)`` as a scalar when
        ``wandb.config.model["scipy_opt"]`` is truthy. Otherwise returns the
        tuple ``(loss, log_vn)``, suitable for ``value_and_grad(has_aux=True)``.
    """
    if wandb.config.model["diff"] in ["eucl", "none", "extreme", "dim", "eucl-dim"]:  #!
        hyperparam = hyperparam.reshape((len(d_perm_inds), -1))

    rho, rho_x = get_rho_params(hyperparam)

    # Compute prequential loglik for every (dimension, data) permutation pair.
    log_vn, _, preq_loglik = update_pn_loop_perm(
        rho, rho_x, y_perm, d_perm_inds, n_perm_inds, helper
    )
    # preq_loglik is (d_perm, n_perm, n, 1): average over both permutation axes.
    preq_loglik = jnp.mean(jnp.mean(preq_loglik, axis=0), axis=0)

    # Mean over the n prequential terms of the single output column.
    loss = -jnp.sum(preq_loglik[:, -1]) / preq_loglik.shape[0]
    if wandb.config.model["scipy_opt"]:
        return loss
    return loss, log_vn


# Derivatives of the objective w.r.t. its first argument (hyperparam). Plain
# grad/value_and_grad need a scalar objective, i.e. these are usable only when
# wandb.config.model["scipy_opt"] is truthy.
fun_grad_jll_perm = jit(value_and_grad(negpreq_jointloglik_perm, argnums=0))
grad_jll_perm = jit(grad(negpreq_jointloglik_perm, argnums=0))

# SciPy adapters: evaluate with JAX, hand results back as NumPy arrays.
def fun_jll_perm_sp(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper=None):
    """Objective value as a NumPy array, for SciPy optimizers."""
    objective = negpreq_jointloglik_perm(
        hyperparam, y_perm, d_perm_inds, n_perm_inds, helper
    )
    return np.array(objective)


def grad_jll_perm_sp(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper=None):
    """Gradient w.r.t. ``hyperparam`` as a NumPy array, for SciPy optimizers."""
    gradient = grad_jll_perm(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper)
    return np.array(gradient)


def fun_grad_jll_perm_sp(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper=None):
    """Objective value and gradient together, both as NumPy arrays."""
    objective, gradient = fun_grad_jll_perm(
        hyperparam, y_perm, d_perm_inds, n_perm_inds, helper
    )
    return np.array(objective), np.array(gradient)


### ###

### Functions for computing p(y|x) on test points ###

# Update p(y=1 | x_test) for a single test point with one observed datum.
# Same structure as update_pn, but the update is evaluated at x_test and the
# stored v_i = log_vn[i] is reused instead of being recomputed.
@jit
def update_ptest_single(carry, i):  #!
    """Scan body absorbing training datum ``i`` into the test predictive.

    Args:
        carry: ``(log_vn, logpmf_ytest, y, x, x_test, rho, lengths)``.
            ``log_vn`` holds the stored log v_i per training row.
            ``logpmf_ytest`` is the current log p(y=1 | x_test).
            ``y``/``x`` are the (permuted) training labels/features.
            ``x_test`` is one test point.
            ``rho``/``lengths`` are forwarded to ``compute_rho_d_single``.
        i: index of the training datum to absorb.

    Returns:
        ``(carry, i)`` as ``lax.scan`` expects.
    """
    log_vn, logpmf_ytest, y, x, x_test, rho, lengths = carry

    y_new = y[i]
    x_new = x[i]
    # Same step size as update_pn: alpha_i = (2 - 1/(i+1)) / (i+2).
    logalpha = jnp.log(2.0 - (1 / (i + 1))) - jnp.log(i + 2)
    # The test label is unknown, so a zero shaped like y_new fills its slot.
    # NOTE(review): assumes compute_rho_d_single returns the label correlation
    # in column 0 and per-feature correlations in columns 1:; confirm.
    rhos = compute_rho_d_single(
        jnp.concatenate((x_test, jnp.zeros_like(y_new))),
        jnp.concatenate((x_new, y_new)),
        rho,
        lengths,
        None,
    )
    rho_d, rho_x = rhos[..., 0], rhos[..., 1:]

    # compute x rhos/alphas: alpha_x = alpha*k_xx / (1 - alpha + alpha*k_xx)
    eps = 5e-5  # 1e-6 causes optimization to fail
    logk_xx = calc_logkxx_single(x_test, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(jnp.clip(-jnp.exp(logalpha), -1 + eps, jnp.inf))
    logalpha_x = (logalphak_xx) - (jnp.logaddexp(log1alpha, logalphak_xx))

    # clip for numerical stability to prevent NaNs
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    # Absorb (x_i, y_i) into p(y=1 | x_test) using the stored v_i.
    logpmf_ytest = update_copula_single(
        logpmf_ytest, log_vn[i], y_new, logalpha_x, rho_d
    )

    carry = log_vn, logpmf_ytest, y, x, x_test, rho, lengths
    return carry, i


# Scan update_ptest_single over the observed-data indices in `rng`.
@jit
def update_ptest_single_scan(carry, rng):
    """Run ``update_ptest_single`` once per index in ``rng``; returns ``(carry, indices)``."""
    return scan(f=update_ptest_single, init=carry, xs=rng)


# Compute log p(y=1 | x_test) for a single test point given the permuted y_{1:n}.
@jit
def update_ptest_single_loop(  #!
    log_vn, rho, rho_x, y_cat, d_perm_inds, n_perm_inds, x_test
):
    """Predict log p(y = 1 | x_test) by replaying the sequential copula updates.

    Args:
        log_vn: stored log v_i from ``update_pn_loop`` for the same permutations
            (one row per entry of ``n_perm_inds``).
        rho, rho_x: copula hyperparameters for this dimension permutation.
        y_cat: training data with features in all but the last column, label last.
        d_perm_inds: column permutation; all but its last entry reorder features.
        n_perm_inds: row ordering of the training data.
        x_test: features of a single test point, in the original column order.

    Returns:
        Array of shape (1,) with log p(y = 1 | x_test).
    """
    x, y = y_cat[..., :-1], y_cat[..., -1:]
    # Feature count of the test point; only used to shape the placeholder below.
    n_feat = jnp.shape(x_test)[0]

    feat_perm = d_perm_inds[..., :-1]
    x_test = x_test[..., feat_perm]
    x = jnp.take(x[..., feat_perm], n_perm_inds, axis=0)
    y = jnp.take(y, n_perm_inds, axis=0)
    # Count rows *after* the row permutation, as update_pn_loop does when it
    # builds log_vn. The pre-permutation count makes the scan index past the
    # end of y/x/log_vn whenever n_perm_inds selects a subset, and JAX silently
    # clamps such out-of-bounds indices.
    n = jnp.shape(y)[0]

    # init_marginals_single does not read the values of its argument; the zero
    # placeholder only supplies a shape.
    logpmf1_ytest = init_marginals_single(np.zeros((n_feat, 1)))

    carry = log_vn, logpmf1_ytest, y, x, x_test, rho, rho_x
    carry, _ = update_ptest_single_scan(carry, jnp.arange(n))
    return carry[1]


# Inner batch: over data permutations (log_vn and n_perm_inds on axis 0).
update_ptest_single_loop_perm_ = jit(
    vmap(update_ptest_single_loop, (0, None, None, None, None, 0, None))
)

# Outer batch: over dimension permutations (log_vn, rho, rho_x, d_perm_inds on axis 0).
update_ptest_single_loop_perm = jit(
    vmap(update_ptest_single_loop_perm_, (0, 0, 0, None, 0, None, None))
)

# Average p(y=1 | x_test) over dimension and data permutations. The average is
# taken in probability space via logsumexp.
@jit
def update_ptest_single_loop_perm_av(
    log_vn_perm, rho, rho_x, y_perm, x_test, d_perm_inds, n_perm_inds, index, helper,
):
    """Permutation-averaged log p(y = 1 | x_test) for one test point.

    ``index`` and ``helper`` are accepted but unused.

    Returns:
        ``(None, log_p)``, where ``log_p`` is the log of the mean predictive
        over all (dimension, data) permutation pairs.
    """
    n_data_perms = len(n_perm_inds)
    n_dim_perms = len(d_perm_inds)
    per_perm = update_ptest_single_loop_perm(
        log_vn_perm, rho, rho_x, y_perm, d_perm_inds, n_perm_inds, x_test
    )
    # Leading axes are (dimension permutation, data permutation): take the
    # log of the arithmetic mean over each in turn.
    avg_over_dims = logsumexp(per_perm, axis=0) - jnp.log(n_dim_perms)
    avg_over_all = logsumexp(avg_over_dims, axis=0) - jnp.log(n_data_perms)
    return None, avg_over_all


# Batched over multiple test points (x_test on axis 0); all else is shared.
update_ptest_loop_perm_av = jit(
    vmap(
        update_ptest_single_loop_perm_av,
        in_axes=(None, None, None, None, 0, None, None, None, None),
    )
)


def fit(
    y,
    params: optax.Params,
    optimizer: optax.GradientTransformation,
    key_seq,
    n_optim,
    maxiter,
    early_stopping_wrapper,
    helper=None,
    d_perm_iter=0,
) -> "tuple[optax.Params, jax.Array | int]":
    """Optimise the copula hyperparameters with an optax optimiser.

    Each iteration:
      1. Shuffles the rows of ``y`` with a key from ``key_seq`` and keeps the
         first ``n_optim`` rows.
      2. Takes one gradient step on ``negpreq_jointloglik_perm`` using the
         identity column/row permutation.
      3. Hands the latest state to ``early_stopping_wrapper``.

    Requires ``wandb.config.model["scipy_opt"]`` to be falsy, because ``step``
    expects the objective to return ``(loss, aux)``.

    Args:
        y: training data; rows are shuffled and truncated to ``n_optim``.
        params: initial hyperparameter pytree. For ``diff == "arnet"`` it is a
            dict whose non-``"rho"`` entries have ``"w"`` and ``"b"`` keys.
        optimizer: optax gradient transformation.
        key_seq: iterator yielding PRNG keys.
        n_optim: number of rows used per step.
        maxiter: maximum number of optimisation steps.
        early_stopping_wrapper: object providing ``y_val``,
            ``callback(params)`` and ``best_params``. Its ``vn_perm``,
            ``y_perm`` and ``helper`` attributes are overwritten every step.
            In the ``"net"`` + ``perm_while_training`` mode its ``y_val`` is
            also replaced and is left column-permuted on return.
        helper: model-specific auxiliary data. It is multiplied into the
            ``"w"`` weights for ``"arnet"`` and indexed as
            ``helper[1][:, perm]`` for ``"net"``.
        d_perm_iter: unused.

    Returns:
        ``(params, loss_value)``. ``params`` are the final parameters, or
        ``best_params`` if early stopping fired. ``loss_value`` is the loss of
        the last step taken (``-999`` if ``maxiter == 0``).
    """
    opt_state = optimizer.init(params)
    # Unpermuted validation data, so each step's column permutation is applied
    # to the original rather than compounded.
    y_val_ori = early_stopping_wrapper.y_val

    # One optimiser step on a batch; vn is the aux output (log_vn).
    @jax.jit
    def step(params, opt_state, y_batch, helper_batch):
        (loss_value, (vn)), grads = jax.value_and_grad(
            negpreq_jointloglik_perm, has_aux=True,
        )(
            params,
            y_batch,
            jnp.arange(y_batch.shape[-1])[None],  # identity column permutation
            jnp.arange(y_batch.shape[-2])[None],  # identity row permutation
            helper_batch,
        )
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value, vn

    loss_value = -999
    for i in range(maxiter):
        y_batch = permutation(next(key_seq), y)[0:n_optim]
        if wandb.config.model["diff"] == "arnet":
            # Re-apply helper[k] to every non-"rho" weight matrix
            # (presumably a 0/1 connectivity mask; confirm).
            params = {
                k: {"w": v["w"] * helper[k], "b": v["b"]} if k != "rho" else v
                for k, v in params.items()
            }
            helper_batch = None
        elif (
            wandb.config.model["diff"] == "net"
            and wandb.config.model["perm_while_training"]
        ):
            # NOTE(review): y_batch[:, perm] keeps only the columns in
            # perm = permutation(range(data["d"])). If "d" excludes the label
            # column, the label (last column) is dropped. If it includes it,
            # the label can move away from the last column that the objective
            # reads. Confirm against the meaning of wandb.config.data["d"].
            perm = jax.random.permutation(next(key_seq), wandb.config.data["d"])
            helper_batch = (perm, helper[1][:, perm])
            y_batch = y_batch[:, perm]
            early_stopping_wrapper.y_val = y_val_ori[:, perm]
        else:
            helper_batch = helper
        # dperm_subkey = next(key_seq)[None]
        # y_batch = vmap(permutation, (0, None, None))(dperm_subkey, y_batch, -1)
        params, opt_state, loss_value, vn = step(
            params, opt_state, y_batch, helper_batch
        )

        # Expose the state of this step to the early-stopping callback.
        early_stopping_wrapper.vn_perm = vn
        early_stopping_wrapper.y_perm = y_batch
        early_stopping_wrapper.helper = helper_batch
        wandb.log({"loss": loss_value})
        if i % 10 == 0:
            print(f"step {i}, loss: {loss_value}")
        if early_stopping_wrapper.callback(params):
            params = early_stopping_wrapper.best_params
            break

    # Apply the arnet weight multiplication once more to the returned params.
    if wandb.config.model["diff"] == "arnet":
        params = {
            k: {"w": v["w"] * helper[k], "b": v["b"]} if k != "rho" else v
            for k, v in params.items()
        }

    return params, loss_value


### ###

## ##
