import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, value_and_grad, vmap
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu
from jax.lax import scan
from jax.lib import xla_bridge
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm
import haiku as hk
from utils.bivariate_copula import norm_copula_logdistribution_logdensity
from utils.sampling import resample_is
import wandb
from .copula_ar_base import *
from .copula_ar_train import *
from .copula_ar_test import update_ptest_loop_perm
from jax.random import PRNGKey, permutation, split
import jax.lax as lax
import optax
import math
from tqdm import tqdm
from utils import progress_bar_scan

# Compute the permutation-averaged prequential log-likelihood
# * add n_eval if OOM

# region helper
def _to_unit_interval(raw):
    """Map unconstrained values into (0, 1) as ``1 / (1 + exp(raw))``.

    Implemented as ``sigmoid(-raw)``, which is the same function but avoids
    an overflowing ``exp``: differentiating the hand-written form yields
    ``0 * inf = nan`` as soon as ``exp(raw)`` overflows.
    """
    return jax.nn.sigmoid(-raw)


# The layout of the hyperparameters depends on the configured "diff" mode.
# Every branch defines the same three helpers:
#   get_rho_params              -> (rho, rest) after the constraining transform
#   get_rho_params_wo_transform -> the same split applied to the raw values
#   slice_lengths               -> index the second component with slice_idx
if "net" in wandb.config.model["diff"]:
    # Parameters are a dict: "rho" plus per-key {"w", "b"} entries.

    @jit
    def get_rho_params(hyperparam):
        """Return (rho in (0, 1) flattened to 2-D, all entries except "rho")."""
        rho_raw = hyperparam["rho"]
        return (
            _to_unit_interval(rho_raw.reshape((rho_raw.shape[0], -1))),
            {k: v for k, v in hyperparam.items() if k != "rho"},
        )

    @jit
    def get_rho_params_wo_transform(hyperparam):
        """Return (raw rho flattened to 2-D, all entries except "rho")."""
        rho_raw = hyperparam["rho"]
        return (
            rho_raw.reshape((rho_raw.shape[0], -1)),
            {k: v for k, v in hyperparam.items() if k != "rho"},
        )

    @jit
    def slice_lengths(lengths, slice_idx):
        """Index the "w" and "b" arrays of every entry with ``slice_idx``."""
        return {
            k: {"w": v["w"][slice_idx], "b": v["b"][slice_idx]}
            for k, v in lengths.items()
        }


elif wandb.config.model["diff"] == "dim":
    # Parameters are an array; both parts are squashed into (0, 1).

    def get_rho_params(hyperparam):
        """Return (first trailing entry, remaining entries), both in (0, 1)."""
        return (
            _to_unit_interval(hyperparam[..., 0]),
            _to_unit_interval(hyperparam[..., 1:]),
        )

    def get_rho_params_wo_transform(hyperparam):
        """Return (first trailing entry, remaining entries) untransformed."""
        return hyperparam[..., 0], hyperparam[..., 1:]

    @jit
    def slice_lengths(lengths, slice_idx):
        """Index ``lengths`` along its leading axis."""
        return lengths[slice_idx]


elif wandb.config.model["diff"] == "eucl-dim":
    # Parameters are an array; the first data["d"] trailing entries form rho.

    @jit
    def get_rho_params(hyperparam):
        """Return (rho in (0, 1), exp of the remaining entries)."""
        # Read lazily (at trace time), exactly like the original expression.
        d = wandb.config.data["d"]
        return (
            _to_unit_interval(hyperparam[..., :d]),
            jnp.exp(hyperparam[..., d:]),
        )

    @jit
    def get_rho_params_wo_transform(hyperparam):
        """Return (first d trailing entries, remaining entries) untransformed."""
        d = wandb.config.data["d"]
        return hyperparam[..., :d], hyperparam[..., d:]

    @jit
    def slice_lengths(lengths, slice_idx):
        """Index ``lengths`` along its leading axis."""
        return lengths[slice_idx]


else:
    # Parameters are an array; one rho entry followed by the lengths.

    @jit
    def get_rho_params(hyperparam):
        """Return (rho in (0, 1) with a kept trailing axis, exp of the rest)."""
        return _to_unit_interval(hyperparam[..., :1]), jnp.exp(hyperparam[..., 1:])

    @jit
    def get_rho_params_wo_transform(hyperparam):
        """Return (first trailing entry with kept axis, rest) untransformed."""
        return hyperparam[..., :1], hyperparam[..., 1:]

    @jit
    def slice_lengths(lengths, slice_idx):
        """Index ``lengths`` along its leading axis."""
        return lengths[slice_idx]


# endregion


# region ori optim
# When optimising with SciPy, the objective may only return a scalar (no auxiliary outputs)


@jit
def negpreq_jointloglik_perm(hyperparam, y_perm, d_perm_inds, n_perm_inds, helper=None):
    """Negative prequential joint log-likelihood, averaged and scaled by n.

    The hyperparameters are split with ``get_rho_params`` and passed to
    ``update_pn_loop_perm``. Its last output is averaged over its two leading
    axes (presumably the permutation axes - confirm against
    ``update_pn_loop_perm``), the last column is summed, and the result is
    negated and divided by the size of the last axis of ``n_perm_inds``.

    Returns:
        The scalar loss when ``wandb.config.model["scipy_opt"]`` is truthy,
        otherwise the pair ``(loss, vn)``.
    """
    # Array-valued parameterisations arrive flat: one row per entry of
    # d_perm_inds.
    array_modes = ["eucl", "none", "extreme", "dim", "eucl-dim"]
    if wandb.config.model["diff"] in array_modes:
        hyperparam = hyperparam.reshape((len(d_perm_inds), -1))
    rho, lengths = get_rho_params(hyperparam)

    n_obs = jnp.shape(n_perm_inds)[-1]

    vn, _, joint_logpdfs, scores = update_pn_loop_perm(
        rho, lengths, y_perm, d_perm_inds, n_perm_inds, helper
    )
    # Optionally score on the joint log-densities instead
    # (original author's note: seems like extreme overfitting).
    if wandb.config.model["min_logpdf_joints"]:
        scores = joint_logpdfs

    # Two successive means over the leading axis.
    averaged = jnp.mean(jnp.mean(scores, axis=0), axis=0)

    # Only the last column (the joint pdf) enters the objective.
    loss = -jnp.sum(averaged[:, -1]) / n_obs

    if wandb.config.model["scipy_opt"]:
        return loss
    return loss, vn


# Derivatives of the objective with respect to the hyperparameters.
# Outside SciPy mode the objective returns ``(loss, vn)``, so every
# differentiation wrapper must be told about the auxiliary output.
_objective_has_aux = not wandb.config.model["scipy_opt"]

# SciPy mode: returns the gradient only; otherwise ``(gradient, vn)``.
grad_jll_perm = jit(grad(negpreq_jointloglik_perm, has_aux=_objective_has_aux))

# SciPy mode: returns ``(loss, gradient)``; otherwise ``((loss, vn), gradient)``.
# This wrapper used to be built without ``has_aux`` and therefore raised
# whenever the objective returned a tuple.
fun_grad_jll_perm = jit(
    value_and_grad(negpreq_jointloglik_perm, has_aux=_objective_has_aux)
)

# Functions for scipy (convert to numpy array for scipy.optimize)
def fun_jll_perm_sp(hyperparam, z, d_perm_inds, n_perm_inds, helper=None):
    """Evaluate ``negpreq_jointloglik_perm`` and return the result as NumPy.

    Arguments are forwarded unchanged, with ``z`` passed in the objective's
    ``y_perm`` position.
    """
    objective = negpreq_jointloglik_perm(
        hyperparam, z, d_perm_inds, n_perm_inds, helper
    )
    return np.asarray(objective)


def grad_jll_perm_sp(hyperparam, z, d_perm_inds, n_perm_inds, helper=None):
    """Evaluate ``grad_jll_perm`` and return the result as NumPy.

    Arguments are forwarded unchanged, with ``z`` passed in the objective's
    ``y_perm`` position.
    """
    gradient = grad_jll_perm(hyperparam, z, d_perm_inds, n_perm_inds, helper)
    return np.asarray(gradient)


# endregion


# region Adam fit
def fit(
    y,
    params: optax.Params,
    optimizer: optax.GradientTransformation,
    key_seq,
    n_optim,
    maxiter,
    early_stopping_wrapper,
    helper=None,
    d_perm_iter=0,
) -> tuple:
    """Optimise ``params`` on random subsamples of ``y`` with early stopping.

    Each iteration shuffles ``y`` along its leading axis with a key drawn from
    ``key_seq``, keeps the first ``n_optim`` entries, takes one optimiser step
    on ``negpreq_jointloglik_perm`` and asks
    ``early_stopping_wrapper.callback`` whether to stop.

    Args:
        y: training data; indexed as ``y[:, perm]`` in "net" permutation mode.
        params: initial parameters (a dict with a "rho" entry in "arnet" mode).
        optimizer: optax gradient transformation.
        key_seq: iterator yielding PRNG keys.
        n_optim: number of shuffled entries used per step.
        maxiter: maximum number of optimiser steps.
        early_stopping_wrapper: object exposing ``y_val``, ``callback`` and
            ``best_params``; its ``vn_perm``, ``y_perm`` and ``helper``
            attributes are overwritten every iteration.
        helper: mode-dependent extra input. In "arnet" mode it holds per-key
            weight masks; in "net" mode with ``perm_while_training`` its
            second element is column-permuted; otherwise it is forwarded
            unchanged to the objective.
        d_perm_iter: unused.

    Returns:
        ``(params, loss_value)``. ``loss_value`` is the loss of the last step
        taken, or ``-999`` if no step ran.

    NOTE(review): in "net" permutation mode ``early_stopping_wrapper.y_val``
    stays column-permuted after this function returns - confirm callers
    expect that.
    """
    opt_state = optimizer.init(params)
    y_val_ori = early_stopping_wrapper.y_val
    diff = wandb.config.model["diff"]

    def apply_masks(tree):
        """Multiply each layer's weights by its mask; "rho" is left untouched."""
        return {
            k: {"w": v["w"] * helper[k], "b": v["b"]} if k != "rho" else v
            for k, v in tree.items()
        }

    @jax.jit
    def step(params, opt_state, y_batch, helper_batch):
        # has_aux=True relies on the objective returning (loss, vn), i.e. on
        # wandb.config.model["scipy_opt"] being falsy.
        (loss_value, vn), grads = jax.value_and_grad(
            negpreq_jointloglik_perm, has_aux=True,
        )(
            params,
            y_batch,
            jnp.arange(y_batch.shape[-1])[None],
            jnp.arange(y_batch.shape[-2])[None],
            helper_batch,
        )
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value, vn

    loss_value = -999
    for i in range(maxiter):
        y_batch = permutation(next(key_seq), y)[0:n_optim]
        if diff == "arnet":
            # Re-apply the masks before every step so masked weights stay at
            # zero.
            params = apply_masks(params)
            helper_batch = None
        elif diff == "net" and wandb.config.model["perm_while_training"]:
            # Train on a fresh column order; validation data and helper are
            # permuted consistently.
            perm = jax.random.permutation(next(key_seq), wandb.config.data["d"])
            helper_batch = (perm, helper[1][:, perm])
            y_batch = y_batch[:, perm]
            early_stopping_wrapper.y_val = y_val_ori[:, perm]
        else:
            helper_batch = helper

        params, opt_state, loss_value, vn = step(
            params, opt_state, y_batch, helper_batch
        )

        # Expose the state of this step to the early-stopping callback.
        early_stopping_wrapper.vn_perm = vn
        early_stopping_wrapper.y_perm = y_batch
        early_stopping_wrapper.helper = helper_batch
        wandb.log({"loss": loss_value})
        if i % 50 == 0:
            print(f"step {i}, loss: {loss_value}")
        if early_stopping_wrapper.callback(params):
            params = early_stopping_wrapper.best_params
            break

    if diff == "arnet":
        # The last update (or the restored best params) may be unmasked.
        params = apply_masks(params)

    return params, loss_value


# endregion


# region batched optimisation
@jit
def update_pn_loop_from_checkpoint_optim(
    rho,
    lengths,
    y,
    vn,
    logcdf_conditionals_yn,
    logpdf_joints_yn,
    preq_loglik,
    index,
    helper,
):
    """Run ``update_pn_scan`` over one optimisation batch from a saved state.

    All arguments are packed into the scan carry; the scan is driven by
    ``arange(wandb.config.data["batch_size_optim"])``.

    Returns:
        The first four carry entries after the scan:
        ``(vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik)``.
    """
    # One scan step per element of the batch.
    steps = jnp.arange(wandb.config.data["batch_size_optim"])

    state = (
        vn,
        logcdf_conditionals_yn,
        logpdf_joints_yn,
        preq_loglik,
        rho,
        lengths,
        y,
        index,
        helper,
    )
    state, _ = update_pn_scan(state, steps)

    # Only the four running quantities are returned; the rest of the carry
    # is discarded.
    vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik, *_ = state
    return vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik


# Vectorised versions of update_pn_loop_from_checkpoint_optim.
#
# in_axes needs one entry per positional argument:
#   (rho, lengths, y, vn, logcdf_conditionals_yn, logpdf_joints_yn,
#    preq_loglik, index, helper)
# The earlier 8-entry specs had no slot for ``helper``, which vmap rejects
# when the function is called with all nine arguments. ``helper`` is
# broadcast (None) at both levels.
# NOTE(review): this assumes helper is shared across both mapped axes -
# confirm if a per-permutation helper is ever passed.

# Inner map: rho, lengths, index and helper are shared; the data and running
# state are mapped over their leading axis.
update_pn_loop_from_checkpoint_perm_optim_ = jit(
    vmap(
        update_pn_loop_from_checkpoint_optim,
        (None, None, 0, 0, 0, 0, 0, None, None),
    )
)

# Outer map: rho and lengths are mapped over their leading axis as well.
# (An earlier variant, previously left here as commented-out code, kept rho
# and lengths shared at this level for modes other than "eucl".)
update_pn_loop_from_checkpoint_perm_optim = jit(
    vmap(
        update_pn_loop_from_checkpoint_perm_optim_,
        (0, 0, 0, 0, 0, 0, 0, None, None),
    )
)


@jit
def update_pn_from_checkpoint_optim(
    params,
    y_perm_batch,
    vn,
    logcdf_conditionals,
    logpdf_joints,
    preq_loglik_save,
    i,
    helper=None,
):
    """Loss of batch ``i`` when the update loop resumes from a saved state.

    Args:
        params: hyperparameters, split by ``get_rho_params``.
        y_perm_batch: data of the batch being processed.
        vn, logcdf_conditionals, logpdf_joints, preq_loglik_save: state passed
            to ``update_pn_loop_from_checkpoint_perm_optim``.
        i: batch index; the loop starts at offset
            ``i * wandb.config.data["batch_size_optim"]``.
        helper: optional extra input forwarded to the update loop. It defaults
            to ``None`` so callers that do not use a helper can omit it.

    Returns:
        ``(loss, (vn_, preq_loglik_))``, a layout suitable for
        ``value_and_grad(..., has_aux=True)``.
    """
    rho, lengths = get_rho_params(params)

    vn_, _, _, preq_loglik_ = update_pn_loop_from_checkpoint_perm_optim(
        rho,
        lengths,
        y_perm_batch,
        vn,
        logcdf_conditionals,
        logpdf_joints,
        preq_loglik_save,
        i * wandb.config.data["batch_size_optim"],
        helper,
    )

    # Average over the two leading axes, keep the last column, then average
    # what remains and negate.
    loss_ = jnp.mean(preq_loglik_, axis=0)
    loss = -jnp.mean(loss_, axis=0)[:, -1].mean()

    return loss, (vn_, preq_loglik_)


@jit
def grad_values_pn_checkpoint(
    params, y_perm, vn_, preq_loglik_, i, vn_perm, preq_loglik_perm, helper=None
):
    """Loss and gradient for batch ``i``, resuming from earlier batches.

    The outputs of the previous batch (``vn_``, ``preq_loglik_``) are written
    into slot ``i - 1`` of the running buffers, ``update_ptest_loop_perm`` is
    evaluated for batch ``i`` against the flattened buffers, and
    ``update_pn_from_checkpoint_optim`` is differentiated with respect to
    ``params``.

    Args:
        params: hyperparameters, split by ``get_rho_params``.
        y_perm: 5-D data array; axis 2 indexes batches.
        vn_, preq_loglik_: outputs of the previous batch.
        i: index of the batch to process.
        vn_perm, preq_loglik_perm: per-batch buffers, updated functionally.
        helper: optional extra input forwarded to
            ``update_pn_from_checkpoint_optim``. The call used to omit it,
            which raised a ``TypeError`` for the missing argument.

    Returns:
        ``((loss, (vn_, preq_loglik_, vn_perm, preq_loglik_perm)), grads)``.
    """
    rho, lengths = get_rho_params(params)
    a, b, _, _, d = y_perm.shape

    # Zero initial state with the shape of a single batch.
    vn = jnp.zeros_like(y_perm)[:, :, 0]
    preq_loglik_save = vn[..., :2]

    # Store the previous batch's results before they are overwritten below.
    vn_perm = vn_perm.at[:, :, i - 1].set(vn_)
    preq_loglik_perm = preq_loglik_perm.at[:, :, i - 1].set(preq_loglik_)

    # The whole buffers are passed (original note: was [:, :, :i]); slots
    # that have not been filled yet are still zero.
    logcdf_conditionals, logpdf_joints = update_ptest_loop_perm(
        vn_perm.reshape((a, b, -1, d)),
        rho,
        lengths,
        y_perm.reshape((a, b, -1, d)),
        y_perm[:, :, i],
    )

    # No extra jit here: this function and the differentiated one are both
    # already jitted.
    (loss_value, (vn_, preq_loglik_)), grads = jax.value_and_grad(
        update_pn_from_checkpoint_optim, has_aux=True
    )(
        params,
        y_perm[:, :, i],
        vn,
        logcdf_conditionals,
        logpdf_joints,
        preq_loglik_save,
        i,
        helper,
    )

    return ((loss_value, (vn_, preq_loglik_, vn_perm, preq_loglik_perm)), grads)


@jit
def update_pn_full_optim(params, y_perm):
    """Loss of a batch processed from scratch by ``update_pn_loop_perm_full``.

    Index arrays are ``arange`` over the last two axes of ``y_perm``, each
    with a leading axis of size one.

    Returns:
        ``(loss, (vn_, preq_loglik_))``, a layout suitable for
        ``value_and_grad(..., has_aux=True)``.
    """
    rho, lengths = get_rho_params(params)

    d_inds = jnp.arange(y_perm.shape[-1])[None]
    n_inds = jnp.arange(y_perm.shape[-2])[None]
    vn_, _, _, preq_loglik_ = update_pn_loop_perm_full(
        rho, lengths, y_perm, d_inds, n_inds
    )

    # Average over the two leading axes, keep the last column, then average
    # what remains and negate.
    averaged = jnp.mean(jnp.mean(preq_loglik_, axis=0), axis=0)
    loss = -averaged[:, -1].mean()

    return loss, (vn_, preq_loglik_)


# neural net fit
def fit_batches(
    y,
    params: optax.Params,
    optimizer: optax.GradientTransformation,
    key_seq,
    n_optim,
    maxiter,
    early_stopping_wrapper,
    masks=None,
    d_perm_i=0,
) -> tuple:
    """Optimise ``params`` batch by batch over shuffled subsamples of ``y``.

    Each outer pass shuffles ``y`` with a key from ``key_seq``, keeps the
    first ``n_optim`` entries and splits them into batches of
    ``wandb.config.data["batch_size_optim"]`` (``n_optim`` must be a multiple
    of it, otherwise the reshape fails). The first batch is handled by
    ``update_pn_full_optim``; every later batch by
    ``grad_values_pn_checkpoint``. One optimiser step is taken per batch.

    Args:
        y: 2-D data array.
        params: initial parameters (a dict with a "rho" entry when ``masks``
            is given).
        optimizer: optax gradient transformation.
        key_seq: iterator yielding PRNG keys.
        n_optim: number of shuffled entries used per outer pass.
        maxiter: maximum total number of optimiser steps.
        early_stopping_wrapper: object exposing ``early_stop``, ``callback``
            and ``best_params``; its ``vn_perm`` and ``y_perm`` attributes are
            overwritten after every step.
        masks: optional per-key weight masks, applied before every outer pass
            and once more before returning.
        d_perm_i: only used in the progress message.

    Returns:
        ``(params, wandb.run.summary["loss"])``.
    """
    opt_state = optimizer.init(params)
    this_iter = 0

    # * use device_put to keep array on GPU
    assert len(y.shape) == 2

    def apply_masks(tree):
        """Multiply each layer's weights by its mask; "rho" is left untouched."""
        return {
            k: {"w": v["w"] * masks[k], "b": v["b"]} if k != "rho" else v
            for k, v in tree.items()
        }

    @jit
    def batches_step_init(params, opt_state, y_perm):
        """Optimiser step on the first batch (no previous state)."""
        (loss_value, (vn_, preq_loglik_)), grads = jax.value_and_grad(
            update_pn_full_optim, has_aux=True
        )(params, y_perm)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value, vn_, preq_loglik_

    @jit
    def batches_step_update(
        params, opt_state, y_perm, vn_, preq_loglik_, i, vn_perm, preq_loglik_perm,
    ):
        """Optimiser step on batch ``i``, resuming from the stored state."""
        (
            (loss_value, (vn_, preq_loglik_, vn_perm, preq_loglik_perm)),
            grads,
        ) = grad_values_pn_checkpoint(
            params, y_perm, vn_, preq_loglik_, i, vn_perm, preq_loglik_perm,
        )
        # Already inside a jitted function, so no nested jit is needed.
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (
            params,
            opt_state,
            loss_value,
            (vn_, preq_loglik_, vn_perm, preq_loglik_perm),
        )

    def update_pn_loop_perm_batches_optim(
        params, opt_state, y_perm, optimizer, this_iter, maxiter, early_stopping_wrapper
    ):
        """Run one pass over all batches of ``y_perm``; return the new state."""
        batch_size = wandb.config.data["batch_size_optim"]
        a, b, c, d = y_perm.shape
        n_batches = c // batch_size
        y_perm = jnp.reshape(y_perm, (a, b, n_batches, batch_size, d))
        vn_perm = jnp.zeros_like(y_perm)
        preq_loglik_perm = jnp.zeros_like(y_perm)[..., :2]

        # First batch: start from scratch.
        params, opt_state, loss_value, vn_, preq_loglik_ = batches_step_init(
            params, opt_state, y_perm[:, :, 0]
        )
        this_iter += 1
        wandb.log({"loss": loss_value})
        early_stopping_wrapper.vn_perm = vn_
        early_stopping_wrapper.y_perm = y_perm[:, :, :1].reshape((a, b, -1, d))
        if early_stopping_wrapper.callback(params):
            params = early_stopping_wrapper.best_params
            # Report the real step count (this used to be a hard-coded 1).
            return params, opt_state, this_iter

        # Remaining batches. The upper bound is the number of batches: the
        # former max(n_batches, this_iter) let i run past the batch axis on
        # later passes.
        for i in tqdm(range(1, n_batches)):
            (
                params,
                opt_state,
                loss_value,
                (vn_, preq_loglik_, vn_perm, preq_loglik_perm),
            ) = batches_step_update(
                params,
                opt_state,
                y_perm,
                vn_,
                preq_loglik_,
                i,
                vn_perm,
                preq_loglik_perm,
            )

            this_iter += 1
            wandb.log({"loss": loss_value})
            if this_iter % 10 == 0:
                print(f"d:{d_perm_i} step:{this_iter}, loss:{loss_value}")

            # Expose the batches stored so far to the early-stopping callback.
            early_stopping_wrapper.vn_perm = vn_perm[:, :, :i].reshape((a, b, -1, d))
            early_stopping_wrapper.y_perm = y_perm[:, :, :i].reshape((a, b, -1, d))
            if early_stopping_wrapper.callback(params):
                params = early_stopping_wrapper.best_params
                break

            if this_iter >= maxiter:
                break

        return params, opt_state, this_iter

    while this_iter < maxiter and early_stopping_wrapper.early_stop is False:
        if masks is not None:
            params = apply_masks(params)
        # Shuffle, keep the first n_optim entries, and add two leading axes
        # of size one.
        y_batch = vmap(permutation, (0, None))(next(key_seq)[None], y)[:, 0:n_optim][
            None
        ]
        params, opt_state, this_iter = update_pn_loop_perm_batches_optim(
            params,
            opt_state,
            y_batch,
            optimizer,
            this_iter,
            maxiter,
            early_stopping_wrapper,
        )

    if masks is not None:
        # The last update (or the restored best params) may be unmasked.
        params = apply_masks(params)
    return params, wandb.run.summary["loss"]


# endregion
