import numpy as np
import jax
import jax.numpy as jnp

def inverse_parameterization(parameterization, lambdas):
    """Return the raw weights that ``compute_lambda`` maps to ``lambdas``.

    Each branch inverts the matching branch of ``compute_lambda``, so
    ``compute_lambda(p, inverse_parameterization(p, lambdas))`` recovers
    ``lambdas`` whenever ``lambdas`` lies inside that branch's domain.

    Args:
        parameterization: One of ``"linear"``, ``"relu"``, ``"exp"``,
            ``"softplus"``, ``"tanh"`` or ``"best"``.
        lambdas: Target values (array-like). Values outside a branch's
            domain produce ``inf``/``nan`` weights rather than an error:
            ``"linear"`` and ``"relu"`` need ``lambdas > 0``; ``"exp"`` and
            ``"softplus"`` need ``0 < lambdas < 1``; ``"tanh"`` needs
            ``-1 < lambdas < 1``; ``"best"`` needs
            ``1 / (1 - lambdas) >= 0.50001``, i.e. roughly
            ``-0.99996 <= lambdas < 1``.

    Returns:
        The weights as a JAX array.

    Raises:
        NotImplementedError: If ``parameterization`` is not recognised.
    """
    if parameterization == "linear":
        weights = -jnp.log(lambdas)
    elif parameterization == "relu":
        # relu is the identity on the non-negative values -log(lambda) takes
        # for lambda in (0, 1], so the linear inverse applies there. For
        # lambda > 1 the weight is negative, relu clamps it to 0, and the
        # forward map returns 1 instead of lambda.
        weights = -jnp.log(lambdas)
    elif parameterization == "exp":
        weights = jnp.log(-jnp.log(lambdas))
    elif parameterization == "softplus":
        # Forward map is 1 - sigmoid(w) == sigmoid(-w), so w = -logit(lambda).
        weights = -jax.scipy.special.logit(lambdas)
    elif parameterization == "tanh":
        weights = jnp.arctanh(lambdas)
    elif parameterization == "best":
        # The forward map depends only on w**2; take the non-negative root.
        # jnp.sqrt keeps the result a JAX array (nan out of domain), whereas
        # Python's ``** 0.5`` on a plain float would yield a complex number.
        weights = jnp.sqrt(1 / (1 - lambdas) - 0.50001)
    else:
        # Name the offending value, matching compute_lambda's error message.
        raise NotImplementedError(
            "Parameterization {} not implemented".format(parameterization)
        )
    return weights


def compute_lambda(parameterization, Lambda_weights):
    """Map unconstrained weights to decay factors ``lambda``.

    Supported parameterizations and the map each applies to ``w``:

    * ``"linear"``   -> ``exp(-w)``
    * ``"relu"``     -> ``exp(-relu(w))``
    * ``"exp"``      -> ``exp(-exp(w))``
    * ``"softplus"`` -> ``1 - sigmoid(w)``, which equals ``exp(-softplus(w))``
    * ``"tanh"``     -> ``tanh(w)``
    * ``"best"``     -> ``1 - 1 / (w**2 + 0.50001)``

    Args:
        parameterization: Name of the map to apply (see list above).
        Lambda_weights: Raw weights (array-like).

    Returns:
        ``lambda`` values with the same shape as ``Lambda_weights``.

    Raises:
        NotImplementedError: If ``parameterization`` is not recognised.
    """
    # Names are tested in order with ``==``, exactly like an if/elif chain.
    forward_maps = (
        ("linear", lambda w: jnp.exp(-w)),
        ("relu", lambda w: jnp.exp(-jax.nn.relu(w))),
        ("exp", lambda w: jnp.exp(-jnp.exp(w))),
        ("softplus", lambda w: 1 - jax.nn.sigmoid(w)),
        ("tanh", jnp.tanh),
        ("best", lambda w: 1 - 1 / (w ** 2 + 0.50001)),
    )
    for name, forward in forward_maps:
        if parameterization == name:
            return forward(Lambda_weights)
    raise NotImplementedError("Parameterization {} not implemented".format(parameterization))
