import jax.numpy as jnp
from jax import custom_jvp, jit
from jax.scipy.special import ndtri
from jax.scipy.stats import norm
import haiku as hk
import jax

import wandb

### Functions for normal copula ###
# Custom derivatives for Phi^{-1} to speed up autograd
@custom_jvp
def ndtri_(u):
    """Return ``ndtri(u)``, the inverse standard-normal CDF, with a custom JVP."""
    quantile = ndtri(u)
    return quantile


@ndtri_.defjvp
def f_jvp(primals, tangents):
    """JVP rule for ``ndtri_``: the tangent is ``u_dot / pdf(ndtri_(u))``.

    Args:
        primals: one-element tuple holding the evaluation point ``u``.
        tangents: one-element tuple holding the input tangent.

    Returns:
        ``(primal_out, tangent_out)`` as required by ``custom_jvp``.
    """
    (u,), (du,) = primals, tangents
    quantile = ndtri_(u)
    # The derivative of the inverse CDF is the reciprocal of the normal density
    # evaluated at the quantile itself.
    inv_density = 1 / norm.pdf(quantile)
    return quantile, inv_density * du


# Rebind the name so callers get the jit-compiled version.
ndtri_ = jit(ndtri_)

# Custom derivatives for logPhi to speed up autograd
@custom_jvp
def norm_logcdf(z):
    """Return ``norm.logcdf(z)``, the standard-normal log CDF, with a custom JVP."""
    log_cdf = norm.logcdf(z)
    return log_cdf


@norm_logcdf.defjvp
def f_jvp(primals, tangents):
    """JVP rule for ``norm_logcdf``: tangent is ``exp(logpdf(z) - logcdf(z)) * z_dot``.

    Args:
        primals: one-element tuple holding the evaluation point ``z``.
        tangents: one-element tuple holding the input tangent.

    Returns:
        ``(primal_out, tangent_out)`` as required by ``custom_jvp``.
    """
    (z,), (dz,) = primals, tangents
    log_cdf = norm_logcdf(z)
    # pdf / cdf, formed as a difference of logs before exponentiating.
    pdf_over_cdf = jnp.exp(norm.logpdf(z) - log_cdf)
    return log_cdf, pdf_over_cdf * dz


# Rebind the name so callers get the jit-compiled version.
norm_logcdf = jit(norm_logcdf)

if wandb.config.model["copula_type"] == "gaussian":

    @jit  # Calculate bivariate normal copula log H_uv and log c_uv
    def norm_copula_logdistribution_logdensity(u, v, rho):
        """Return ``(logcop_dist, logcop_dens)`` for the bivariate normal copula.

        Args:
            u, v: values clipped here to ``[1e-6, 1 - 1e-6]`` before use.
            rho: correlation parameter; enters through ``1 - rho ** 2``.

        Returns:
            logcop_dist: ``norm_logcdf((pu - rho * pv) / sqrt(1 - rho ** 2))``
                with ``pu = ndtri_(u)``, ``pv = ndtri_(v)``, clipped to
                ``[log(1e-6), log(1 - 1e-6)]`` (log H, RHS of Eq. 4.5).
            logcop_dens: log copula density log c(u, v) (LHS of Eq. 4.5).
        """
        tol = 1e-6
        # Keep both arguments strictly inside (0, 1): 0s and 1s in the CDF are
        # numerically unstable in high d.
        u_in = jnp.clip(u, tol, 1 - tol)
        v_in = jnp.clip(v, tol, 1 - tol)

        # Normal scores; ndtri_ carries the custom derivative.
        x = ndtri_(u_in)
        y = ndtri_(v_in)

        one_minus_rho_sq = 1 - rho ** 2

        # log H_rho(u, v)
        cond_z = (x - rho * y) / jnp.sqrt(one_minus_rho_sq)
        log_dist = jnp.clip(norm_logcdf(cond_z), jnp.log(tol), jnp.log(1 - tol))

        # log c(u, v)
        quad_form = -(rho ** 2) * (x ** 2 + y ** 2) + 2 * rho * x * y
        log_dens = -0.5 * jnp.log(one_minus_rho_sq) + (
            0.5 / one_minus_rho_sq
        ) * quad_form

        return log_dist, log_dens


elif wandb.config.model["copula_type"] == "bernoulli":

    @jit  # Calculate bivariate bernoulli copula log F_uv and log d_uv
    def bern_copula_logdistribution_logdensity(log_v, logpmf1, y_new, rho):
        """Return ``(logcop_dist, logcop_dens)`` for the bivariate Bernoulli copula.

        All operations are element-wise, so scalar inputs and mutually
        broadcastable array inputs are both supported.

        Args:
            log_v: log of ``v``; clipped to ``[log(eps), log(1 - eps)]``.
            logpmf1: log-probability term, clipped the same way
                (presumably log p(y = 1) -- confirm against the caller).
            y_new: weight selecting ``frac0`` (``y_new``) versus ``frac1``
                (``1 - y_new``); presumably a label in {0, 1} -- confirm.
            rho: mixing weight in ``1 - rho + rho * frac``.

        Returns:
            logcop_dist: log of the copula distribution term.
            logcop_dens: ``log(max(1 - rho + rho * frac, eps))``.
        """
        eps = 5e-5
        log_lo, log_hi = jnp.log(eps), jnp.log(1 - eps)

        # clip both log-probabilities before passing to the bivariate copula
        logpmf1 = jnp.clip(logpmf1, log_lo, log_hi)
        log_v = jnp.clip(log_v, log_lo, log_hi)

        # log(1 - v), with the log1p argument kept away from -1
        log1_v = jnp.log1p(jnp.clip(-jnp.exp(log_v), -1 + eps, jnp.inf))

        # Element-wise minimum: stacking the inputs and reducing with jnp.min
        # would collapse array inputs to a single global minimum.
        min_logu1v1 = jnp.minimum(logpmf1, log_v)

        # Bernoulli copula
        frac0 = 1 / jnp.exp(log1_v) - jnp.exp(min_logu1v1 - logpmf1 - log1_v)
        frac1 = jnp.exp(min_logu1v1 - logpmf1 - log_v)
        frac = y_new * frac0 + (1 - y_new) * frac1  # make this more accurate?
        kyy_ = 1 - rho + rho * frac
        kyy_ = jnp.clip(kyy_, eps, jnp.inf)

        logcop_dens = jnp.log(kyy_)
        # NOTE(review): unlike `frac` above, the first term here is not weighted
        # by y_new -- confirm against the derivation before relying on it.
        logcop_dist = jnp.log(
            (1 - rho + rho * frac0) + (1 - y_new) * (1 - rho + rho * frac1)
        )

        return logcop_dist, logcop_dens


elif wandb.config.model["copula_type"] == "mlp":

    # Requirements for the learned copula:
    # - it has to be symmetric;
    # - c_n has to converge to 1 for large n (that is, the copula converges to
    #   the independence copula as the sample size increases).
    # Independence copula distribution: C(u, v) = u * v

    class NetModuleU(hk.Module):
        """Haiku module: an ``hk.nets.MLP`` followed by an element-wise sigmoid."""

        def __init__(
            self,
            # Immutable default (a tuple rather than a shared list); still
            # evaluated once, when the class body is executed.
            output_sizes=(32, int(wandb.config.data["d"])),
            name="custom_linear",
        ):
            """Build the wrapped MLP.

            Args:
                output_sizes: layer widths forwarded to ``hk.nets.MLP``.
                name: module name forwarded to ``hk.Module``.
            """
            super().__init__(name=name)
            self._internal_linear_1 = hk.nets.MLP(
                output_sizes=output_sizes, name="hk_internal_linear"
            )

        def __call__(self, x):
            """Apply the MLP to ``x`` and squash the result with a sigmoid."""
            return jax.nn.sigmoid(self._internal_linear_1(x))

    def _custom_forward_fn(x):
        """Forward function for ``hk.transform``: run ``x`` through ``NetModuleU``."""
        # No `NetModule` is defined or imported in this file; the module class
        # declared directly above is NetModuleU.
        module = NetModuleU()
        return module(x)

    class NetModuleV(hk.Module):
        """Haiku module: an ``hk.nets.MLP`` followed by an element-wise sigmoid."""

        def __init__(
            self,
            # Immutable default (a tuple rather than a shared list); still
            # evaluated once, when the class body is executed.
            output_sizes=(32, int(wandb.config.data["d"])),
            name="custom_linear",
        ):
            """Build the wrapped MLP.

            Args:
                output_sizes: layer widths forwarded to ``hk.nets.MLP``.
                name: module name forwarded to ``hk.Module``.
            """
            super().__init__(name=name)
            self._internal_linear_1 = hk.nets.MLP(
                output_sizes=output_sizes, name="hk_internal_linear"
            )

        def __call__(self, x):
            """Apply the MLP to ``x`` and squash the result with a sigmoid."""
            return jax.nn.sigmoid(self._internal_linear_1(x))

    def _custom_forward_fn(x):
        """Forward function for ``hk.transform``: run ``x`` through ``NetModuleV``."""
        # No `NetModule` is defined or imported in this file; the module class
        # declared directly above is NetModuleV.
        module = NetModuleV()
        return module(x)

    # Transform the forward function into an init/apply pair whose apply takes
    # no RNG key (it is called below as network.apply(params, x)).
    # At this point _custom_forward_fn is the second definition above, which
    # shadows the first.
    network = hk.without_apply_rng(hk.transform(_custom_forward_fn))

    @jit  # Calculate normal-copula log H_uv and a network-based log c_uv
    def norm_copula_logdistribution_logdensity(u, v, rho):
        """Return ``(logcop_dist, logcop_dens)`` for the MLP copula variant.

        Args:
            u, v: values clipped here to ``[1e-6, 1 - 1e-6]`` before use.
            rho: correlation parameter used for the distribution term only.

        Returns:
            logcop_dist: ``norm_logcdf((pu - rho * pv) / sqrt(1 - rho ** 2))``
                with ``pu = ndtri_(u)``, ``pv = ndtri_(v)``, clipped to
                ``[log(1e-6), log(1 - 1e-6)]`` (log H, RHS of Eq. 4.5).
            logcop_dens: output of ``network.apply`` on the transformed
                ``(u, v)`` pair, used in place of log c(u, v) (LHS of Eq. 4.5).
        """
        # clip to prevent 0s and 1s in CDF, needed for numerical stability in high d.
        eps = 1e-6
        u = jnp.clip(u, eps, 1 - eps)
        v = jnp.clip(v, eps, 1 - eps)

        # Normal scores needed by z below, computed the same way as in the
        # gaussian branch.
        pu = ndtri_(u)
        pv = ndtri_(v)

        z = (pu - rho * pv) / jnp.sqrt(1 - rho ** 2)
        logcop_dist = norm_logcdf(z)  #! RHS of Eq. 4.5 log Hp(u, v)
        logcop_dist = jnp.clip(logcop_dist, jnp.log(eps), jnp.log(1 - eps))

        # NOTE(review): net_params is resolved from module scope and is not
        # defined in the visible part of this file -- confirm where it is set.
        logcop_dens = network.apply(
            net_params,
            1 / eps - 1 / jnp.concatenate((jnp.atleast_2d(u), jnp.atleast_2d(v)), 1),
        )

        #! LHS of Eq 4.5 log c(u, v)

        return logcop_dist, logcop_dens
