import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax.scipy.special import logsumexp
from jax.lax import scan
import jax
import optax
from jax.random import PRNGKey, permutation, split

from tqdm import tqdm_notebook

from . import (
    compute_rho_d,
    compute_rho_d_single,
    get_rho_params,
    get_rho_params_wo_transform,
)
import wandb

from .. import copula_AR_functions as mvcd


### Utility functions ###

# Also vmap over the alpha and rho arguments, since both depend on x
# Batched copula update: every argument is mapped along axis 0 except the
# 4th (``v``), which is shared by all points.
update_copula = jit(
    vmap(mvcd.update_copula_single, in_axes=(0, 0, 0, None, 0, 0))
)

# Compute log k_xx for a single data point
@jit
def calc_logkxx_single(x, x_new, rho_x):
    """Return the summed per-dimension Gaussian-copula log-kernel log k(x, x_new).

    Each dimension contributes
        -0.5 * log(1 - rho^2)
        - (rho^2 * (x^2 + x_new^2) - 2 * rho * x * x_new) / (2 * (1 - rho^2)),
    which is the log-density of a bivariate Gaussian copula when ``x`` and
    ``x_new`` are standard-normal scores (NOTE(review): the scaling of the
    inputs is decided by callers; confirm).

    Args:
        x, x_new: broadcast-compatible arrays, one entry per dimension.
        rho_x: per-dimension correlations; |rho_x| < 1 is required for a
            finite result.

    Returns:
        Scalar: the sum over all entries.
    """
    rho_sq = rho_x ** 2
    one_minus_rho_sq = 1 - rho_sq
    log_det_term = -0.5 * jnp.sum(jnp.log(one_minus_rho_sq))
    quad_form = rho_sq * (x ** 2 + x_new ** 2) - 2 * rho_x * x * x_new
    return log_det_term - jnp.sum((0.5 / one_minus_rho_sq) * quad_form)


# One log-kernel per data point: rows of ``x`` and ``rho_x`` are mapped,
# while the single ``x_new`` is shared.
calc_logkxx = jit(vmap(calc_logkxx_single, in_axes=(0, None, 0)))
# Additionally map over a batch of ``x_new`` values (the test points).
calc_logkxx_test = jit(vmap(calc_logkxx, in_axes=(None, 0, None)))
### ###

### Functions to calculate overhead v_{1:n} ###

# Compute v_i for a single datum
@jit
def update_pn(carry, i):
    """One ``jax.lax.scan`` step: condition every training point's state on datum ``i``.

    Carry layout (must match the tuple built in ``update_pn_loop``):
        vn                     -- history of v values; row ``i`` is written here
        logcdf_conditionals_yn -- per-point log conditional-cdf state
        logpdf_joints_yn       -- per-point log-pdf state; column -1 is the
                                  term recorded in ``preq_loglik``
        preq_loglik            -- prequential log-likelihood; row ``i`` is written here
        x                      -- covariates of all (permuted) training points
        rho, lengths           -- hyperparameters forwarded to ``compute_rho_d``

    Returns ``(carry, i)``, as ``scan`` requires.
    """
    vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik, x, rho, lengths = carry

    # Covariates of the datum conditioned on at this step
    x_new = x[i]
    # log alpha_i = log((2 - 1/(i+1)) / (i+2)); the step size decays roughly like 2/(i+2)
    logalpha = jnp.log(2.0 - (1 / (i + 1))) - jnp.log(i + 2)

    # A zero is appended to every covariate vector so that the returned rhos
    # have one extra trailing entry, which is used as the response-dimension
    # correlation (rho_d) below. NOTE(review): the meaning of the padded zero
    # is inferred from the slicing; confirm against compute_rho_d.
    rhos = compute_rho_d(
        jnp.concatenate((x, jnp.zeros((len(x), 1))), -1),
        jnp.concatenate((x_new, jnp.zeros((1))), 0),
        rho,
        lengths,
        None,
    )
    rho_d, rho_x = rhos[..., -1], rhos[..., :-1]
    # Per-point covariate kernel (calc_logkxx maps over rows of x / rho_x)
    logk_xx = calc_logkxx(x, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(-jnp.exp(logalpha))
    logalpha_x = (logalphak_xx) - (
        jnp.logaddexp(log1alpha, logalphak_xx)
    )  # log of alpha*k_xx / (1 - alpha + alpha*k_xx), computed in log space

    # clip for numerical stability to prevent NaNs
    eps = 1e-5  # 1e-6 causes optimization to fail
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    u = jnp.exp(logcdf_conditionals_yn)
    # Conditional cdf of datum i given the data already processed in this ordering
    v = jnp.exp(logcdf_conditionals_yn[i])

    vn = vn.at[i].set(v)  # remember history of vn (reused for test-point prediction)

    # Read before the update below, so the score for datum i does not include
    # datum i itself (prequential / one-step-ahead).
    preq_loglik = preq_loglik.at[i].set(logpdf_joints_yn[i, -1])
    logcdf_conditionals_yn, logpdf_joints_yn = update_copula(
        logcdf_conditionals_yn, logpdf_joints_yn, u, v, logalpha_x, rho_d
    )
    carry = vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik, x, rho, lengths
    return carry, i


# Scan over y_{1:n}
@jit
def update_pn_scan(carry, rng):
    """Fold ``update_pn`` over the step indices in ``rng``.

    Returns ``(final_carry, stacked_indices)``, as ``jax.lax.scan`` does.
    """
    final_carry, visited = scan(update_pn, carry, rng)
    return final_carry, visited


# Compute v_{1:n}
@jit
def update_pn_loop(rho, rho_x, y_cat, d_perm_inds, n_perm_inds, helper=None):
    """Run the prequential recursion over one data ordering and one dimension ordering.

    Args:
        rho, rho_x: hyperparameters; ``rho_x`` occupies the ``lengths`` slot
            of the ``update_pn`` carry.
        y_cat: data, with covariates in ``[..., :-1]`` and the response in the
            last column.
        d_perm_inds: column permutation; ``d_perm_inds[..., :-1]`` reorders
            the covariates.
        n_perm_inds: row ordering; its length sets the number of scan steps.
        helper: accepted for interface compatibility, but not used here.

    Returns:
        ``(vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik)`` taken
        from the final carry.
    """
    num_steps = len(n_perm_inds)
    covariates = y_cat[..., :-1]
    response = y_cat[..., -1:]
    # Reorder rows by the data ordering, then the covariate columns.
    covariates = jnp.take(covariates, n_perm_inds, axis=0)
    covariates = jnp.take(covariates, d_perm_inds[..., :-1], axis=-1)
    response = jnp.take(response, n_perm_inds, axis=0)

    init_cdf, init_pdf = mvcd.init_marginals(response)
    init_carry = (
        jnp.zeros((num_steps, 1)),  # vn: cdf history, not differentiated
        init_cdf,
        init_pdf,
        jnp.zeros((num_steps, 1)),  # prequential joint loglik for y | x
        covariates,
        rho,
        rho_x,
    )
    final_carry, _ = update_pn_scan(init_carry, jnp.arange(num_steps))
    vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik, *_ = final_carry
    return vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik


### ###
# Inner vmap: batch over data orderings (n_perm_inds, the 5th argument, on axis 0).
update_pn_loop_perm_ = jit(
    vmap(update_pn_loop, in_axes=(None, None, None, None, 0, None))
)
# Outer vmap: batch over dimension orderings (rho, rho_x and d_perm_inds on axis 0).
update_pn_loop_perm = jit(
    vmap(update_pn_loop_perm_, in_axes=(0, 0, None, 0, None, None))
)

# update_pn_loop_perm_ = jit(
#     vmap(update_pn_loop, (None, None, None, None, None, 0, None))
# )
# update_pn_loop_perm = jit(vmap(update_pn_loop_perm_, (0, 0, None, None, 0, None, None)))


### Functions for Optimizing prequential log likelihood ###

# Compute permutation-averaged conditional preq loglik
@jit
def negpreq_jointloglik_perm(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper=None):
    """Negative permutation-averaged prequential joint log-likelihood, divided by n.

    Args:
        hyperparam: raw hyperparameters, mapped through ``get_rho_params``.
            For the ``diff`` modes listed below, they are first reshaped to one
            row per dimension permutation.
        y_perm: training data (covariates in ``[..., :-1]``, response last).
        d_perm_inds: dimension permutations, one per row.
        n_perm_inds: data orderings, one per row; row 0's length is used as n.
        helper: forwarded to ``update_pn_loop_perm``.

    Returns:
        ``-loglik / n`` when ``wandb.config.model["scipy_opt"]`` is truthy;
        otherwise ``(-loglik / n, vn)``, for use with ``has_aux=True``.

    Note: the ``wandb.config`` reads happen at trace time under ``jit``, so
    changing the config after the first call does not affect compiled traces.
    """
    reshape_modes = ("eucl", "none", "extreme", "dim", "eucl-dim")
    if wandb.config.model["diff"] in reshape_modes:
        hyperparam = hyperparam.reshape((len(d_perm_inds), -1))

    rho, rho_x = get_rho_params(hyperparam)
    n = len(n_perm_inds[0])

    # vn holds conditional cdf values (exp of log-cdfs), not log values.
    vn, _, _, preq_loglik = update_pn_loop_perm(
        rho, rho_x, y_perm, d_perm_inds, n_perm_inds, helper
    )
    # Average over dimension permutations (leading axis), then data orderings.
    averaged = jnp.mean(jnp.mean(preq_loglik, axis=0), axis=0)

    # Only the last column (joint pdf term) contributes.
    loss = -jnp.sum(averaged[:, -1]) / n
    if wandb.config.model["scipy_opt"]:
        return loss
    return loss, vn


# Compute derivatives wrt hyperparameters
# Value-and-gradient and gradient-only transforms of the loss with respect to
# its first argument (hyperparam). Both require the loss to return a plain
# scalar, i.e. wandb.config.model["scipy_opt"] truthy; otherwise it returns
# (loss, aux), which would need has_aux=True.
fun_grad_jll_perm = jit(value_and_grad(negpreq_jointloglik_perm, argnums=0))
grad_jll_perm = jit(grad(negpreq_jointloglik_perm, argnums=0))

# Functions for scipy (convert to numpy array)
def fun_jll_perm_sp(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper=None):
    """Return the loss as a NumPy array, for SciPy optimizers.

    Assumes the scalar-returning mode (``wandb.config.model["scipy_opt"]``).
    """
    loss = negpreq_jointloglik_perm(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper)
    return np.array(loss)


def grad_jll_perm_sp(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper=None):
    """Return the loss gradient with respect to ``hyperparam`` as a NumPy array, for SciPy."""
    gradient = grad_jll_perm(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper)
    return np.array(gradient)


def fun_grad_jll_perm_sp(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper=None):
    """Return ``(loss, gradient)`` as NumPy arrays.

    This is the pair shape SciPy's ``minimize(..., jac=True)`` expects.
    """
    # Named ``gradient`` (not ``grad``) so the module-level ``jax.grad``
    # import is not shadowed inside this function.
    value, gradient = fun_grad_jll_perm(
        hyperparam, y_perm, d_perm_inds, n_perm_inds, helper
    )
    return np.array(value), np.array(gradient)


### Functions for computing p(y|x) on test points ###

# Update p(y|x) for a single test point and observed datum
@jit
def update_ptest_single(carry, i):
    """One ``jax.lax.scan`` step: condition a single test point's state on training datum ``i``.

    Carry layout (must match ``update_ptest_single_loop``):
        vn                        -- v history recorded by ``update_pn``; vn[i] is reused
        logcdf_conditionals_ytest -- test point's log conditional-cdf state
        logpdf_joints_ytest       -- test point's log-pdf state
        x                         -- covariates of the (permuted) training points
        x_test                    -- covariates of the test point
        rho, lengths              -- hyperparameters forwarded to ``compute_rho_d_single``

    Returns ``(carry, i)``, as ``scan`` requires.
    """
    vn, logcdf_conditionals_ytest, logpdf_joints_ytest, x, x_test, rho, lengths = carry

    x_new = x[i]
    # Same step-size schedule as update_pn: alpha_i = (2 - 1/(i+1)) / (i+2)
    logalpha = jnp.log(2.0 - (1 / (i + 1))) - jnp.log(i + 2)
    # Trailing zero padding mirrors update_pn; the last rho is used as rho_d.
    # NOTE(review): the argument order here is (datum, test point), while
    # update_pn passes (points being updated, datum). This only matches
    # training if compute_rho_d/_single are symmetric in their first two
    # arguments; confirm.
    rhos = compute_rho_d_single(
        jnp.concatenate((x_new, jnp.zeros((1))), -1),
        jnp.concatenate((x_test, jnp.zeros((1))), -1),
        rho,
        lengths,
        None,
    )
    rho_d, rho_x = rhos[..., -1], rhos[..., :-1]

    # Covariate kernel between the test point and datum i
    logk_xx = calc_logkxx_single(x_test, x_new, rho_x)

    # log of alpha*k_xx / (1 - alpha + alpha*k_xx), computed in log space
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(-jnp.exp(logalpha))
    logalpha_x = (logalphak_xx) - (jnp.logaddexp(log1alpha, logalphak_xx))

    # clip for numerical stability to prevent NaNs
    eps = 1e-5  # 1e-6 causes optimization to fail
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    u = jnp.exp(logcdf_conditionals_ytest)
    # Reuse datum i's conditional cdf value stored during training
    v = vn[i]

    logcdf_conditionals_ytest, logpdf_joints_ytest = mvcd.update_copula_single(
        logcdf_conditionals_ytest, logpdf_joints_ytest, u, v, logalpha_x, rho_d
    )

    carry = vn, logcdf_conditionals_ytest, logpdf_joints_ytest, x, x_test, rho, lengths
    return carry, i


@jit
def update_ptest_single_scan(carry, rng):
    """Fold ``update_ptest_single`` over the step indices in ``rng``.

    Returns ``(final_carry, stacked_indices)``, as ``jax.lax.scan`` does.
    """
    final_carry, visited = scan(update_ptest_single, carry, rng)
    return final_carry, visited


@jit
def update_ptest_single_loop(
    vn, rho, rho_x, y_cat, d_perm_inds, n_perm_inds, y_test_cat
):
    """Compute the predictive state of one test row under one permutation pair.

    Args:
        vn: v history from ``update_pn_loop`` for this permutation pair, with
            one row per scan step.
        rho, rho_x: hyperparameters; ``rho_x`` occupies the ``lengths`` slot
            of the ``update_ptest_single`` carry.
        y_cat: training data (covariates in ``[..., :-1]``, response last).
        d_perm_inds: column permutation; ``d_perm_inds[..., :-1]`` reorders
            the covariates.
        n_perm_inds: row ordering of the training data.
        y_test_cat: one test row, laid out like ``y_cat``.

    Returns:
        ``(logcdf_conditionals_ytest, logpdf_joints_ytest)`` after
        conditioning on every training row, in ``n_perm_inds`` order.
    """
    # The number of steps comes from the row ordering, exactly as in
    # update_pn_loop (which produced ``vn``). Using ``y_cat.shape[0]`` would
    # mis-size the scan whenever n_perm_inds is not a full permutation; JAX
    # clamps out-of-range gathers instead of raising, so that error is silent.
    n = len(n_perm_inds)

    covariate_inds = d_perm_inds[..., :-1]
    x_test, y_test = y_test_cat[..., :-1], y_test_cat[..., -1:]
    x_test = x_test[..., covariate_inds]
    x = jnp.take(y_cat[..., :-1][..., covariate_inds], n_perm_inds, axis=0)

    logcdf_conditionals_ytest, logpdf_joints_ytest = mvcd.init_marginals_single(y_test)

    carry = vn, logcdf_conditionals_ytest, logpdf_joints_ytest, x, x_test, rho, rho_x
    carry, _ = update_ptest_single_scan(carry, jnp.arange(n))
    _, logcdf_conditionals_ytest, logpdf_joints_ytest, *_ = carry

    return logcdf_conditionals_ytest, logpdf_joints_ytest


# Inner vmap: batch over data orderings (vn and n_perm_inds mapped on axis 0).
update_ptest_single_loop_perm_ = jit(
    vmap(
        update_ptest_single_loop,
        in_axes=(0, None, None, None, None, 0, None),
    )
)

# Outer vmap: batch over dimension orderings (vn, rho, rho_x and d_perm_inds
# mapped on axis 0); mirrors the axes used by update_pn_loop_perm.
update_ptest_single_loop_perm = jit(
    vmap(
        update_ptest_single_loop_perm_,
        in_axes=(0, 0, 0, None, 0, None, None),
    )
)

# Average p(y|x) over permutations
@jit
def update_ptest_single_loop_perm_av(
    vn_perm, rho, rho_x, y_perm, x_test, d_perm_inds, n_perm_inds, index, helper,
):
    """Return the predictive state for one test row, averaged over all permutations.

    ``x_test`` is a full test row laid out like ``y_perm``; it is passed on as
    ``y_test_cat``. Results are averaged in probability space
    (``logsumexp - log(count)``): first over dimension permutations (leading
    axis), then over data orderings. ``index`` and ``helper`` are accepted
    but unused.
    """
    log_cdf, log_pdf = update_ptest_single_loop_perm(
        vn_perm, rho, rho_x, y_perm, d_perm_inds, n_perm_inds, x_test
    )
    for n_averaged in (len(d_perm_inds), len(n_perm_inds)):
        log_count = jnp.log(n_averaged)
        log_cdf = logsumexp(log_cdf, axis=0) - log_count
        log_pdf = logsumexp(log_pdf, axis=0) - log_count
    return log_cdf, log_pdf


# Vmap over multiple test points
# Map the permutation-averaged predictive over test rows: only ``x_test``
# (the 5th argument) is batched, along axis 0.
_TEST_POINT_AXES = (None, None, None, None, 0, None, None, None, None)
update_ptest_loop_perm_av = jit(
    vmap(update_ptest_single_loop_perm_av, in_axes=_TEST_POINT_AXES)
)


def _apply_arnet_mask(params, mask):
    """Return ``params`` with each non-"rho" layer's ``w`` multiplied by ``mask[k]``.

    Biases (``b``) and the ``"rho"`` entry are passed through unchanged.
    """
    return {
        k: v if k == "rho" else {"w": v["w"] * mask[k], "b": v["b"]}
        for k, v in params.items()
    }


def fit(
    y,
    params: optax.Params,
    optimizer: optax.GradientTransformation,
    key_seq,
    n_optim,
    maxiter,
    early_stopping_wrapper,
    helper=None,
    d_perm_iter=0,
) -> "tuple[optax.Params, jax.Array | int]":
    """Minimise ``negpreq_jointloglik_perm`` with an optax optimizer on random mini-batches.

    Args:
        y: training data; each step takes the first ``n_optim`` rows of a
            random row permutation.
        params: initial hyperparameters (pytree).
        optimizer: optax gradient transformation.
        key_seq: iterator of PRNG keys (consumed with ``next``).
        n_optim: mini-batch size.
        maxiter: maximum number of optimisation steps.
        early_stopping_wrapper: object providing ``y_val``, ``vn_perm``,
            ``y_perm``, ``helper``, ``callback(params) -> bool`` and
            ``best_params``.
        helper: model-specific extra input.
            - ``diff == "arnet"``: a per-layer weight mask keyed like ``params``.
            - ``diff == "net"`` with ``perm_while_training``: a sequence whose
              element 1 has its columns permuted together with the data.
            - Otherwise: passed through to the loss.
        d_perm_iter: unused; kept for backward compatibility.

    Returns:
        ``(params, loss_value)``: the final parameters (or the best ones, if
        early stopping fired) and the last mini-batch loss (``-999`` if
        ``maxiter == 0``).
    """
    opt_state = optimizer.init(params)
    y_val_ori = early_stopping_wrapper.y_val
    diff_mode = wandb.config.model["diff"]

    @jax.jit
    def step(params, opt_state, y_batch, helper_batch):
        # Use a single identity dimension permutation and a single identity
        # data ordering.
        (loss_value, vn), grads = jax.value_and_grad(
            negpreq_jointloglik_perm, has_aux=True
        )(
            params,
            y_batch,
            jnp.arange(y_batch.shape[-1])[None],
            jnp.arange(y_batch.shape[-2])[None],
            helper_batch,
        )
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value, vn

    loss_value = -999  # returned unchanged if the loop never runs
    for i in range(maxiter):
        y_batch = permutation(next(key_seq), y)[0:n_optim]
        if diff_mode == "arnet":
            # Every step starts from masked weights.
            params = _apply_arnet_mask(params, helper)
            helper_batch = None
        elif diff_mode == "net" and wandb.config.model["perm_while_training"]:
            perm = jax.random.permutation(next(key_seq), wandb.config.data["d"])
            helper_batch = (perm, helper[1][:, perm])
            y_batch = y_batch[:, perm]
            early_stopping_wrapper.y_val = y_val_ori[:, perm]
        else:
            helper_batch = helper

        params, opt_state, loss_value, vn = step(
            params, opt_state, y_batch, helper_batch
        )

        early_stopping_wrapper.vn_perm = vn
        early_stopping_wrapper.y_perm = y_batch
        early_stopping_wrapper.helper = helper_batch
        wandb.log({"loss": loss_value})
        if i % 10 == 0:
            print(f"step {i}, loss: {loss_value}")
        # NOTE(review): for "arnet", ``params`` here is post-update and not
        # yet re-masked, so the callback sees weights the mask would zero.
        # Confirm this is intended.
        if early_stopping_wrapper.callback(params):
            params = early_stopping_wrapper.best_params
            break

    if diff_mode == "arnet":
        params = _apply_arnet_mask(params, helper)

    return params, loss_value


### ###

## ##
