import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, value_and_grad, vmap
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu
from jax.lax import scan
from jax.lib import xla_bridge
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm
import haiku as hk
from utils.bivariate_copula import norm_copula_logdistribution_logdensity
from utils.sampling import resample_is
import wandb
from jax import lax

# import flows

# todo: add net layers

if wandb.config.data["init"] == "normal":

    def init_marginals_single(y_test):
        """Initialise log-CDFs and cumulative log-PDFs from Normal(0, 1) marginals.

        Every entry of ``y_test`` is scored under a normal distribution with
        mean 0 and standard deviation 1.  With ``eps = 1e-6`` the log-CDFs are
        clipped to ``[log(eps), log(1 - eps)]`` and the log-PDFs are clipped
        below at ``log(eps)``.

        Args:
            y_test: array of observations for a single data point.

        Returns:
            Tuple ``(logcdf_init_conditionals, logpdf_init_joints)``: the
            clipped marginal log-CDFs, and the running sum (``jnp.cumsum``) of
            the clipped marginal log-PDFs.
        """
        mean0 = 0.0
        std0 = 1.0
        eps = 1e-6

        # Marginal initial log-CDFs, clipped to tame outliers.
        logcdf_init_marginals = jnp.clip(
            norm.logcdf(y_test, loc=mean0, scale=std0),
            jnp.log(eps),
            jnp.log(1 - eps),
        )
        # Marginal initial log-PDFs, clipped from below only.
        logpdf_init_marginals = jnp.clip(
            norm.logpdf(y_test, loc=mean0, scale=std0), jnp.log(eps), jnp.inf
        )

        # Joint/conditional from marginals: the joints are the running sums of
        # the marginal log-PDFs and the conditionals are the marginals as-is.
        logpdf_init_joints = jnp.cumsum(logpdf_init_marginals)
        logcdf_init_conditionals = logcdf_init_marginals

        return logcdf_init_conditionals, logpdf_init_joints


if wandb.config.data["init"] == "uniform":

    def init_marginals_single(y_test):
        """Initialise log-CDFs and cumulative log-PDFs for the "uniform" setting.

        The log-CDF of each entry is ``log(y)``, with ``y`` first clipped below
        at ``eps = 1e-6`` and the result clipped to
        ``[log(eps), log(1 - eps)]``.  The marginal log-PDFs are all zero.

        Args:
            y_test: array of observations for a single data point; its leading
                axis length sets the size of the returned log-PDF vector.

        Returns:
            Tuple ``(logcdf_init_conditionals, logpdf_init_joints)``: the
            clipped log-CDFs, and the running sum of the (all-zero) marginal
            log-PDFs.
        """
        d = jnp.shape(y_test)[0]
        eps = 1e-6

        # Marginal initial log-CDFs: clip y before the log so non-positive
        # inputs cannot produce -inf/NaN, then clip the log itself.
        logcdf_init_marginals = jnp.clip(
            jnp.log(jnp.clip(y_test, eps, None)),
            jnp.log(eps),
            jnp.log(1 - eps),
        )

        # Marginal initial log-PDFs are identically zero, so no clipping is
        # needed before accumulating them.
        logpdf_init_marginals = jnp.zeros(d)

        # Joint/conditional from marginals.
        logpdf_init_joints = jnp.cumsum(logpdf_init_marginals)
        logcdf_init_conditionals = logcdf_init_marginals

        return logcdf_init_conditionals, logpdf_init_joints


# Batched initialiser over the leading axis of its input.  Unless `low_mem`
# exceeds 3 the batch is vectorised with vmap; otherwise it is processed with
# lax.map instead.
if not (wandb.config.data["low_mem"] > 3):
    init_marginals = jit(vmap(init_marginals_single, in_axes=0))
else:

    @jit
    def init_marginals(y):
        """Apply ``init_marginals_single`` to each leading-axis slice of ``y``."""
        return lax.map(init_marginals_single, y)


# Each wrapper below maps over one further leading axis of the input.
init_marginals_perm_ = jit(vmap(init_marginals, in_axes=0))
init_marginals_perm = jit(vmap(init_marginals_perm_, in_axes=0))

# Compute copula update for a single data point
def update_copula_single(logcdf_conditionals, logpdf_joints, u, v, logalpha, rho_d):
    """Apply one copula update to the conditional log-CDFs and joint log-PDFs.

    Args:
        logcdf_conditionals: current conditional log-CDFs.
        logpdf_joints: current joint log-PDFs; its leading length ``d`` sets
            how many cumulative copula terms are carried into the shifted
            product.
        u: first argument forwarded to
            ``norm_copula_logdistribution_logdensity``.
        v: second argument forwarded to the same helper.
        logalpha: log of the update weight ``alpha``.
        rho_d: third argument forwarded to the copula helper.

    Returns:
        Tuple ``(logcdf_conditionals, logpdf_joints)`` after the update.
    """
    n_dims = jnp.shape(logpdf_joints)[0]

    # Eq. (4.5)
    log_cop_cdf, log_cop_pdf = norm_copula_logdistribution_logdensity(u, v, rho_d)

    # Running product of copula densities, in log space.
    cum_log_cop_pdf = jnp.cumsum(log_cop_pdf)

    # Shift the running product by one position (leading zero, last entry
    # dropped); this is the form used for the conditional CDFs.
    shifted_cum = jnp.concatenate((jnp.zeros(1), cum_log_cop_pdf[0 : n_dims - 1]))

    # log(1 - alpha)
    log_one_minus_alpha = jnp.log1p(-jnp.exp(logalpha))

    # Conditional CDF update: weighted mixture divided by its normaliser,
    # both evaluated in log space.
    log_numerator = jnp.logaddexp(
        (log_one_minus_alpha + logcdf_conditionals),
        (logalpha + shifted_cum + log_cop_cdf),
    )
    log_normaliser = jnp.logaddexp(log_one_minus_alpha, (logalpha + shifted_cum))
    updated_logcdf = log_numerator - log_normaliser

    # Density update; needed for the prequential likelihood, which is the
    # density when optimising over u.
    updated_logpdf = (
        jnp.logaddexp(log_one_minus_alpha, (logalpha + cum_log_cop_pdf))
        + logpdf_joints
    )

    return updated_logcdf, updated_logpdf


# Batched copula update.  Arguments whose in_axes entry is 0 are mapped over
# their leading axis; `v` and `logalpha` (in_axes None) are shared unchanged by
# every mapped element.  With `low_mem` above 1 the batch is processed with
# lax.map instead of vmap.
if not (wandb.config.data["low_mem"] > 1):
    update_copula = jit(
        vmap(update_copula_single, in_axes=(0, 0, 0, None, None, 0))
    )
else:

    @jit
    def update_copula(logcdf_conditionals, logpdf_joints, u, v, logalpha, rho_d):
        """Run ``update_copula_single`` over the leading axis via ``lax.map``."""

        def _update_row(row):
            row_logcdf, row_logpdf, row_u, row_rho = row
            return update_copula_single(
                row_logcdf, row_logpdf, row_u, v, logalpha, row_rho
            )

        return lax.map(
            _update_row, (logcdf_conditionals, logpdf_joints, u, rho_d)
        )


# Compute bandwidths for each dimension
if wandb.config.model["diff"] == "extreme":

    @jit
    def compute_rho_d_single(y_plot, y_new, rho, lengths, helper):
        """Return a constant bandwidth vector gated by squared distance.

        The squared Euclidean distance between ``y_plot`` and ``y_new`` is
        compared against ``wandb.config.model["extreme_dist"]`` and against
        0.8; each comparison that holds contributes ``rho`` to a scalar that
        fills the whole output vector.

        Args:
            y_plot: point at which to evaluate.
            y_new: point used for the update; its leading length sets the
                output length.
            rho: base bandwidth.
            lengths: unused in this variant.
            helper: unused in this variant.

        Returns:
            Vector with one (identical) entry per leading element of ``y_new``.
        """
        sq_dist = ((y_plot - y_new) ** 2).sum()

        is_near = sq_dist < wandb.config.model["extreme_dist"]
        is_far = sq_dist > 0.8

        # NOTE(review): when 0.8 < sq_dist < extreme_dist both indicators are
        # true and the value is 2 * rho -- confirm this is intended.
        gated_rho = is_near * rho + is_far * rho

        return jnp.zeros(jnp.shape(y_new)[0]) + gated_rho


elif wandb.config.model["diff"] == "none":

    @jit
    def compute_rho_d_single(y_plot, y_new, rho, lengths, helper):
        """Return ``rho`` repeated once per leading element of ``y_new``.

        Args:
            y_plot: point at which to evaluate (unused in this variant).
            y_new: point used for the update; only its leading length is read.
            rho: base bandwidth.
            lengths: unused in this variant.
            helper: unused in this variant.

        Returns:
            Vector of length ``y_new.shape[0]`` filled with ``rho``.
        """
        n_dims = jnp.shape(y_new)[0]
        return jnp.ones(n_dims) * rho


elif wandb.config.model["diff"] == "eucl":

    @jit
    def compute_rho_d_single(y_plot, y_new, rho, lengths, helper):
        """Compute per-dimension bandwidths that decay with scaled distance.

        The first entry is ``rho``.  Entry ``k`` (``k >= 1``) is ``rho`` times
        ``exp`` of minus the cumulative sum, over the first ``k`` dimensions,
        of ``((y_plot - y_new) / lengths) ** 2``.  The last dimension of the
        inputs never enters the distance.

        Args:
            y_plot: point at which to evaluate.
            y_new: point used for the update.
            rho: base bandwidth (shared by every dimension).
            lengths: divisor applied to the first ``d - 1`` differences.
            helper: unused in this variant.

        Returns:
            Vector of bandwidths with one entry per leading element of
            ``y_new``.
        """
        n_lead = np.shape(y_plot)[0] - 1

        # The update term is autoregressive: only the first d-1 dimensions
        # contribute, so the last coordinate is left out of the distance.
        scaled_gap = (y_plot[0:n_lead] - y_new[0:n_lead]) / lengths
        cum_sq_dist = jnp.cumsum(scaled_gap ** 2)  # \ref{eq:rho}

        bandwidths = jnp.zeros(jnp.shape(y_new)[0])
        bandwidths = bandwidths.at[1:].set(rho * jnp.exp(-cum_sq_dist))

        # The first dimension depends on no other dimension: plain rho.
        return bandwidths.at[:1].set(rho)


elif wandb.config.model["diff"] == "eucl-dim":

    @jit
    def compute_rho_d_single(y_plot, y_new, rho, lengths, helper):
        """Compute decaying bandwidths from a per-dimension ``rho`` vector.

        The first entry is ``rho[0]``.  Entry ``k`` (``k >= 1``) is ``rho[k]``
        times ``exp`` of minus the cumulative sum, over the first ``k``
        dimensions, of ``((y_plot - y_new) / lengths) ** 2``.

        Args:
            y_plot: point at which to evaluate.
            y_new: point used for the update.
            rho: indexable collection of base bandwidths, one per dimension.
            lengths: divisor applied to the first ``d - 1`` differences.
            helper: unused in this variant.

        Returns:
            Vector of bandwidths with one entry per leading element of
            ``y_new``.
        """
        n_lead = np.shape(y_plot)[0] - 1

        # The update term is autoregressive: only the first d-1 dimensions
        # contribute, so the last coordinate is left out of the distance.
        scaled_gap = (y_plot[0:n_lead] - y_new[0:n_lead]) / lengths
        cum_sq_dist = jnp.cumsum(scaled_gap ** 2)  # \ref{eq:rho}

        bandwidths = jnp.zeros(jnp.shape(y_new)[0])
        bandwidths = bandwidths.at[1:].set(rho[1:] * jnp.exp(-cum_sq_dist))

        # The first dimension depends on no other dimension: plain rho[0].
        return bandwidths.at[:1].set(rho[0])


elif wandb.config.model["diff"] == "knn":

    @jit
    def compute_rho_d_single(y_plot, y_new, rho, lengths, helper=None):
        """Compute a constant bandwidth vector from a count of equal entries.

        ``y_plot`` is repeated ``n_knn`` times along a new trailing axis and
        compared element-wise with ``y_new[None]``; the number of equal entries
        is subtracted from ``n_knn`` and the result ``diff2`` scales ``rho`` by
        ``exp(-diff2)``.

        Args:
            y_plot: 1-D array to evaluate at.
            y_new: array to update with; assumed to broadcast against shape
                ``(len(y_plot), n_knn)`` -- TODO confirm against the caller.
            rho: base bandwidth.
            lengths: unused in this variant.
            helper: unused in this variant.

        Returns:
            Vector of length ``wandb.config.data["d"]`` whose entries all equal
            ``rho * exp(-diff2)``.
        """
        n_knn = wandb.config.model["n_knn"]

        # The matrix product with a row of ones copies every entry of y_plot
        # across n_knn columns.
        plot_repeated = y_plot[:, None] @ jnp.ones((1, n_knn))
        n_matches = (plot_repeated == y_new[None]).sum()
        diff2 = n_knn - n_matches

        # Every dimension receives the same rescaled bandwidth.
        bandwidths = jnp.zeros(wandb.config.data["d"])
        return bandwidths.at[:].set(rho * jnp.exp(-diff2))


elif wandb.config.model["diff"] in [
    "net",
    "joint_net",
    "joint_net_zeroed",
    "arnet",
    "net-dim",
    "arnet-dim",
]:

    if (wandb.config.model["diff"] == "net") or (
        wandb.config.model["diff"] == "net-dim"
    ):

        def forward(x):
            """Apply a Haiku MLP with the configured layer sizes to ``x``.

            The final layer is followed by the activation
            (``activate_final=True``).
            """
            # NOTE(review): this branch only runs for "net" / "net-dim", so the
            # comparison below is False here and the extra head is an empty list.
            extra_head = [1] * (wandb.config.model["diff"] == "joint_net")
            layer_sizes = wandb.config.model["net_layers"] + extra_head
            return hk.nets.MLP(layer_sizes, activate_final=True)(x)

        # Pure init/apply pair whose apply function takes no RNG key.
        network = hk.without_apply_rng(hk.transform(forward))

        if wandb.config.model["perm_while_training"]:

            @jit
            def get_model_output(net_params, x, helper):
                """Run the network on masked, tiled copies of ``x``.

                ``helper`` unpacks to ``(perm, mask)``.  The entries of ``x``
                at indices ``perm`` are overwritten with the index values
                themselves, the result is tiled into ``d`` rows, multiplied by
                ``mask`` and fed to ``network.apply``.
                """
                perm, mask = helper
                x_marked = jnp.atleast_2d(x.at[perm].set(perm))
                x_rows = jnp.tile(x_marked, (wandb.config.data["d"], 1))
                # NOTE: the original author flagged this masking as wrong in
                # some branches.
                return network.apply(net_params, mask * x_rows)

        else:

            @jit
            def get_model_output(net_params, x, helper):
                """Run the network on masked, tiled copies of ``x``.

                ``helper`` unpacks to ``(perm, mask)``; only ``mask`` is used.
                ``x`` is tiled into ``d`` rows, multiplied by ``mask`` and fed
                to ``network.apply``.
                """
                _, mask = helper
                x_rows = jnp.tile(jnp.atleast_2d(x), (wandb.config.data["d"], 1))
                # NOTE: the original author flagged this masking as wrong in
                # some branches.
                return network.apply(net_params, mask * x_rows)

        @jit
        def get_diff2(net_params, x_plot, x_new, mask):
            """Squared distance between the network outputs of two points.

            ``mask`` is forwarded unchanged as the ``helper`` argument of
            ``get_model_output``.  The squared differences are summed over the
            last axis and the final row of the result is dropped.
            """
            output_gap = get_model_output(net_params, x_plot, mask) - get_model_output(
                net_params, x_new, mask
            )
            row_sq_dist = (output_gap ** 2).sum(axis=-1)
            return row_sq_dist[:-1]

        @jit
        def compute_rho_d_single(y_plot, y_new, rho, net_params, mask):
            """Compute bandwidths that decay with a learned squared distance.

            Args:
                y_plot: point at which to evaluate.
                y_new: point used for the update; its leading length sets the
                    output length.
                rho: base bandwidth; indexed per dimension when the configured
                    ``diff`` is "net-dim", otherwise used as a single value.
                net_params: parameters passed to the network.
                mask: forwarded to ``get_diff2``.

            Returns:
                Vector whose first entry is the first-dimension bandwidth and
                whose remaining entries are the base bandwidth times
                ``exp(-diff2)``.
            """
            # The update term is autoregressive: the first dimension depends
            # only on its own bandwidth, the rest decay with the distance.
            decay = jnp.exp(-get_diff2(net_params, y_plot, y_new, mask))

            # The config is read while tracing, so this is an ordinary Python
            # branch rather than a traced conditional.
            if wandb.config.model["diff"] == "net-dim":
                rho_first, rho_rest = rho[:1], rho[1:]
            else:
                rho_first = rho_rest = rho

            bandwidths = jnp.zeros(jnp.shape(y_new)[0])
            bandwidths = bandwidths.at[1:].set(rho_rest * decay)
            return bandwidths.at[:1].set(rho_first)

    if (wandb.config.model["diff"] == "arnet") or (
        wandb.config.model["diff"] == "arnet-dim"
    ):

        def forward(x):
            """Apply a Haiku MLP whose layer widths are scaled by ``d``.

            Each configured layer size is multiplied by
            ``wandb.config.data["d"]``; the final layer is followed by the
            activation (``activate_final=True``).
            """
            n_dims = wandb.config.data["d"]
            scaled_widths = [
                width * n_dims for width in wandb.config.model["net_layers"]
            ]
            return hk.nets.MLP(scaled_widths, activate_final=True)(x)

        # Pure init/apply pair whose apply function takes no RNG key.
        network = hk.without_apply_rng(hk.transform(forward))

        @jit
        def get_model_output(net_params, x):
            """Return the network's output for ``x`` under ``net_params``."""
            output = network.apply(net_params, x)
            return output

        @jit
        def get_diff2(net_params, z_plot, z_new):
            """Per-dimension squared distance between two network outputs.

            Both outputs are reshaped to ``(d, net_layers[-1])`` and the
            squared differences are summed over the last axis.  ``net_params``
            is accepted but not read.
            """
            per_dim_shape = (
                wandb.config.data["d"],
                wandb.config.model["net_layers"][-1],
            )
            output_gap = z_plot.reshape(per_dim_shape) - z_new.reshape(per_dim_shape)
            return (output_gap ** 2).sum(axis=-1)

        @jit
        def compute_rho_d_single(
            y_plot, y_new, rho, net_params, mask=None, helper=None
        ):
            """Compute bandwidths that decay with a learned squared distance.

            Args:
                y_plot: point at which to evaluate.
                y_new: point used for the update; its leading length sets the
                    output length.
                rho: base bandwidth; indexed per dimension when the configured
                    ``diff`` is "arnet-dim", otherwise used as a single value.
                net_params: parameters passed to the network.
                mask: unused in this variant.
                helper: unused in this variant.

            Returns:
                Vector whose first entry is the first-dimension bandwidth and
                whose remaining entries are the base bandwidth times
                ``exp(-diff2[1:])``.
            """
            embed_plot = network.apply(net_params, y_plot)
            embed_new = network.apply(net_params, y_new)

            # The update term is autoregressive: the first dimension depends
            # only on its own bandwidth, so the first distance is skipped.
            diff2 = get_diff2(net_params, embed_plot, embed_new)
            decay = jnp.exp(-diff2[1:])

            # The config is read while tracing, so this is an ordinary Python
            # branch rather than a traced conditional.
            if wandb.config.model["diff"] == "arnet-dim":
                rho_first, rho_rest = rho[:1], rho[1:]
            else:
                rho_first = rho_rest = rho

            bandwidths = jnp.zeros(jnp.shape(y_new)[0])
            bandwidths = bandwidths.at[1:].set(rho_rest * decay)
            return bandwidths.at[:1].set(rho_first)

    elif wandb.config.model["diff"] == "joint_net":

        @jit
        def get_model_output(net_params, x_plot, x_new):
            """Return ``exp`` of the network output for a pair of points.

            For each input the last coordinate is dropped, the remainder is
            tiled into ``d - 1`` rows and the lower triangle is kept; the two
            results are concatenated along the last axis and fed to
            ``network.apply``.

            NOTE(review): no assignment to ``network`` is visible on the
            "joint_net" configuration path -- confirm where it is defined.
            """
            n_rows = wandb.config.data["d"] - 1

            def _lower_tri_rows(point):
                return jnp.tril(jnp.tile(point[..., :-1], (n_rows, 1)))

            pair_input = jnp.concatenate(
                [_lower_tri_rows(x_plot), _lower_tri_rows(x_new)], axis=-1
            )
            return jnp.exp(network.apply(net_params, pair_input))

        @jit
        def compute_rho_d_single(
            y_plot, y_new, rho, net_params, mask=None, helper=None
        ):
            """Compute bandwidths from a network applied to both points jointly.

            Args:
                y_plot: point at which to evaluate.
                y_new: point used for the update; its leading length sets the
                    output length.
                rho: base bandwidth.
                net_params: parameters passed to the network.
                mask: unused in this variant.
                helper: unused in this variant.

            Returns:
                Vector whose first entry is ``rho`` and whose remaining entries
                are ``rho * exp(-diff2)``, where ``diff2`` is the network
                output summed over axis 1.

            NOTE(review): no assignment to ``network`` is visible on the
            "joint_net" configuration path -- confirm where it is defined.
            """
            n_rows = wandb.config.data["d"] - 1

            def _lower_tri_rows(point):
                # Drop the last coordinate, tile into d-1 rows and keep the
                # lower triangle.
                return jnp.tril(jnp.tile(point[..., :-1], (n_rows, 1)))

            pair_input = jnp.concatenate(
                [_lower_tri_rows(y_plot), _lower_tri_rows(y_new)], axis=-1
            )
            diff2 = network.apply(net_params, pair_input).sum(1)

            bandwidths = jnp.zeros(jnp.shape(y_new)[0])
            bandwidths = bandwidths.at[1:].set(rho * jnp.exp(-diff2).squeeze())

            # The first dimension depends on no other dimension: plain rho.
            return bandwidths.at[:1].set(rho)

    elif wandb.config.model["diff"] == "joint_net_zeroed":

        @jit
        def compute_rho_d_single(y_plot, y_new, rho, net_params, helper=None):
            """Compute bandwidths as the sigmoid of a joint network output.

            Args:
                y_plot: point at which to evaluate.
                y_new: point used for the update; its leading length sets the
                    output length.
                rho: bandwidth written to the first entry.
                net_params: parameters passed to the network.
                helper: unused in this variant.

            Returns:
                Vector whose first entry is ``rho`` and whose remaining entries
                are the flattened ``sigmoid`` of the network output.

            NOTE(review): no assignment to ``network`` is visible on the
            "joint_net_zeroed" configuration path -- confirm where it is
            defined.
            """
            # The last coordinate of each point is never used by the update
            # term, so it is dropped before the two are concatenated.
            plot_lead = jnp.atleast_2d(y_plot)[..., :-1]
            new_lead = jnp.atleast_2d(y_new)[..., :-1]
            z_diff = network.apply(
                net_params, jnp.concatenate((plot_lead, new_lead), 1)
            )

            bandwidths = jnp.zeros(jnp.shape(y_new)[0])
            bandwidths = bandwidths.at[1:].set(jax.nn.sigmoid(z_diff).reshape((-1)))

            # The first dimension depends on no other dimension: plain rho.
            return bandwidths.at[0].set(rho)


elif wandb.config.model["diff"] == "nice_net":
    # Training-loop constants for the flow-based variant.
    num_epochs = 100
    batch_size = 100

    # %%
    def get_masks(input_dim, hidden_dim=64, num_hidden=1):
        """Build the binary connectivity masks between consecutive layers.

        Every unit is given an integer degree: inputs get ``0..input_dim-1``,
        each of the ``num_hidden + 1`` hidden layers gets
        ``arange(hidden_dim) % (input_dim - 1)``, and the outputs get the input
        degrees minus one.  A mask entry is 1.0 where the degree of the later
        unit is at least the degree of the earlier unit.

        Args:
            input_dim: number of input (and output) units.
            hidden_dim: number of units per hidden layer.
            num_hidden: one less than the number of hidden layers.

        Returns:
            List of float32 arrays, one per pair of consecutive layers, each of
            shape ``(units_in_earlier_layer, units_in_later_layer)``.
        """
        input_degrees = np.arange(input_dim)
        hidden_degrees = [
            np.arange(hidden_dim) % (input_dim - 1) for _ in range(num_hidden + 1)
        ]
        output_degrees = input_degrees % input_dim - 1
        degrees = [input_degrees, *hidden_degrees, output_degrees]

        return [
            np.transpose(
                np.expand_dims(later, -1) >= np.expand_dims(earlier, 0)
            ).astype(np.float32)
            for earlier, later in zip(degrees[:-1], degrees[1:])
        ]

    def masked_transform(rng, input_dim):
        """Build and initialise a three-layer masked dense network.

        Args:
            rng: key passed to the stax ``init_fun``.
            input_dim: number of input units; also used as the input shape
                ``(input_dim,)`` at initialisation.

        Returns:
            Tuple ``(params, apply_fun)`` from ``stax.serial``.

        NOTE(review): ``flows`` is referenced here but its import is commented
        out at the top of the file -- confirm it is available at runtime.
        """
        masks = get_masks(input_dim, hidden_dim=64, num_hidden=1)
        act = stax.Relu
        init_fun, apply_fun = stax.serial(
            flows.MaskedDense(masks[0]),
            act,
            flows.MaskedDense(masks[1]),
            act,
            # get_masks returns NumPy arrays, which have no `.tile` method;
            # np.tile(mask, 2) repeats the mask twice along its last axis.
            flows.MaskedDense(np.tile(masks[2], 2)),
        )
        _, params = init_fun(rng, (input_dim,))
        return params, apply_fun

    # Five repetitions of the same (MADE, Reverse) pair, chained in series and
    # paired with a Laplace prior.
    # NOTE(review): `flows` is referenced here but its import is commented out
    # at the top of the file -- confirm it is available at runtime.
    made_reverse_pair = (flows.MADE(masked_transform), flows.Reverse())
    network = flows.Flow(
        flows.Serial(*(made_reverse_pair * 5)),
        flows.Laplace(),
    )


elif wandb.config.model["diff"] == "dim":

    def compute_rho_d_single(y, y_new, rho, lengths, helper=None):  # TODO
        """Return ``rho`` (given a leading axis) followed by ``lengths``.

        ``y``, ``y_new`` and ``helper`` are accepted but not read.
        """
        leading_rho = rho[jnp.newaxis]
        return jnp.concatenate((leading_rho, lengths))

    # Map over the leading axis of the first argument only; all other
    # arguments are shared.
    compute_rho_d = jit(
        vmap(compute_rho_d_single, in_axes=(0, None, None, None, None))
    )

else:
    raise NotImplementedError


# Batched bandwidth computation over the leading axis of `y_plot`.
if wandb.config.model["diff"] == "none":

    @jit
    def compute_rho_d(y_plot, y_new, rho, lengths, helper):
        """Return an array shaped like ``y_plot`` with every entry ``rho``."""
        return rho * jnp.ones_like(y_plot)


elif not (wandb.config.data["low_mem"] > 2):
    # Vectorise over the first argument only; all other arguments are shared.
    compute_rho_d = jit(
        vmap(compute_rho_d_single, in_axes=(0, None, None, None, None))
    )

else:

    @jit
    def compute_rho_d(y_plot, y_new, rho, lengths, helper=None):
        """Run ``compute_rho_d_single`` over ``y_plot`` via ``lax.map``."""

        def _one_point(point):
            return compute_rho_d_single(point, y_new, rho, lengths, helper)

        return lax.map(_one_point, y_plot)

