import jax
import functools
import numpy as np
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def logsumexp(a, axis=None):
    """Return log(sum(exp(a))) reduced over ``axis``.

    The function is wrapped in ``jax.custom_vjp`` with ``axis`` marked as a
    non-differentiable argument; the value itself is simply the first output
    of ``_logsumexp_fwd``.

    Args:
        a: array to reduce.
        axis: axis (or axes) to reduce over; ``None`` reduces over every
            element.

    Returns:
        The reduced array, with the reduced axes removed.
    """
    value, _residuals = _logsumexp_fwd(a, axis)
    return value

def _logsumexp_fwd(a, axis):
    """Compute log-sum-exp of ``a`` over ``axis`` plus the backward residuals.

    Args:
        a: array to reduce.
        axis: axis (or axes) to reduce over; ``None`` reduces over everything.

    Returns:
        ``(result, (shifted_exp, total))`` where ``shifted_exp`` is
        ``exp(a - shift)`` with the shape of ``a`` and ``total`` is its sum
        over ``axis`` with the reduced axes kept as size 1.
    """
    peak = jnp.max(a, axis=axis, keepdims=True)
    # Shift by the maximum so the exponentials cannot overflow. A non-finite
    # maximum (inf or NaN) is replaced by 0 so that ``a - shift`` does not
    # produce ``inf - inf``.
    shift = jnp.where(jnp.isfinite(peak), peak, 0)
    shifted_exp = jnp.exp(a - shift)
    total = jnp.sum(shifted_exp, axis=axis, keepdims=True)
    log_total = jnp.log(jnp.squeeze(total, axis=axis))
    result = jnp.squeeze(shift, axis=axis) + log_total
    return result, (shifted_exp, total)

def _logsumexp_bwd(axis, res, g):
    """Backward rule for log-sum-exp: scale the cotangent by ``e / z``.

    Args:
        axis: the reduction axis passed to the forward rule (non-differentiable).
        res: the ``(e, z)`` residual pair; expected to be the shifted
            exponentials and their kept-dims sum saved by the forward rule.
        g: cotangent of the reduced output.

    Returns:
        A one-element tuple holding the cotangent with respect to the input.
    """
    shifted_exp, total = res
    # Where the sum is zero, divide by 1 instead so the result is not 0/0.
    denom = jnp.where(total != 0, total, 1)
    # Restore the reduced axis so the cotangent broadcasts against the input;
    # for a full reduction it is left untouched and broadcasts as-is.
    cotangent = g if axis is None else jnp.expand_dims(g, axis=axis)
    return (cotangent / denom * shifted_exp,)

# Attach the forward and backward rules that make up logsumexp's custom VJP.
logsumexp.defvjp(_logsumexp_fwd, _logsumexp_bwd)

@jax.jit
def logp(log_a, mu, log_s, uL_inv_, x):
    """Return the summed mixture log-density of the rows of ``x``.

    For every row ``x[n]`` and component ``k`` this forms
    ``y = uL_inv_[k].T @ (x[n] - mu[k])`` and the per-component term
    ``-0.5 * dim * log(2 * pi)
    - sum_i(log_s[k, i] + 0.5 * y[i]**2 * exp(-2 * log_s[k, i]))``.
    The terms are combined with ``log_softmax(log_a)`` through ``logsumexp``
    over the component axis and then summed over rows. This has the form of
    a Gaussian-mixture log-likelihood.

    Args:
        log_a: unnormalised log mixture weights, one per component
            (normalised here with ``log_softmax``).
        mu: per-component offsets, expected shape ``(nmix, dim)``.
        log_s: per-component log scales, expected shape ``(nmix, dim)``.
        uL_inv_: per-component matrices, expected shape ``(nmix, dim, dim)``;
            each one is applied transposed.
        x: samples, expected shape ``(nb, dim)``.

    Returns:
        A scalar: the sum over rows of ``x`` of the mixture log-density.
    """
    dim = x.shape[-1]

    # (nb, nmix, dim): every sample centred on every component.
    centered = jnp.expand_dims(x, 1) - jnp.expand_dims(mu, 0)

    # whitened[n, k, i] = sum_j uL_inv_[k, j, i] * centered[n, k, j].
    # Contracting directly against the per-component matrices avoids
    # materialising one copy of them per sample (and the explicit transpose).
    whitened = jnp.einsum('kji,nkj->nki', uL_inv_, centered)

    quad = (0.5 * jax.lax.square(whitened)
            * jax.lax.exp(jnp.expand_dims(-2.0 * log_s, 0)))
    log_norm = -0.5 * dim * jax.lax.log(2.0 * jnp.pi)
    component_logp = log_norm - jnp.sum(jnp.expand_dims(log_s, 0) + quad, -1)

    log_weights = jnp.expand_dims(jax.nn.log_softmax(log_a), 0)
    return jnp.sum(logsumexp(log_weights + component_logp, -1))

@jax.jit
def logp_vjp(log_a, mu, log_s, uL_inv, x):
    """Evaluate ``logp`` and build its vector-Jacobian product.

    Returns:
        The ``(value, vjp_fn)`` pair produced by ``jax.vjp`` for ``logp`` at
        the given arguments; ``vjp_fn`` yields one cotangent per argument.
    """
    primals = (log_a, mu, log_s, uL_inv, x)
    return jax.vjp(logp, *primals)

def logp_adj(log_a, mu, log_s, uL_inv, x):
    """Evaluate ``logp`` and return its value with a NumPy-returning adjoint.

    Returns:
        ``(loss, adj_fn)`` where ``loss`` is the result of ``logp_vjp``
        converted with ``.tolist()`` and ``adj_fn(cotangent)`` returns a
        5-tuple of adjoints as NumPy arrays for ``log_a``, ``mu``, ``log_s``
        and ``uL_inv`` followed by ``None`` for ``x``. All but the first are
        returned with their axes reversed (``.T``), and the ``uL_inv``
        adjoint is masked to its strictly upper-triangular entries.
    """
    dim = x.shape[-1]
    loss, pullback = logp_vjp(log_a, mu, log_s, uL_inv, x)

    def adj_fn(x):
        # Here ``x`` is the output cotangent; the adjoint with respect to the
        # samples (last entry) is discarded.
        d_log_a, d_mu, d_log_s, d_uL_inv, _ = pullback(x)
        # Ones strictly above the diagonal, with a leading broadcast axis.
        strict_upper = jnp.expand_dims(jnp.triu(jnp.ones((dim, dim)), 1), 0)
        return (
            np.array(d_log_a),
            np.array(d_mu).T,
            np.array(d_log_s).T,
            np.array(strict_upper * d_uL_inv).T,
            None,
        )

    return loss.tolist(), adj_fn


def logp_singular(log_a, mu, log_s, uL_inv, x):
    """Return the mixture log-density of a single point ``x``.

    For each component ``k`` this forms ``y = uL_inv[k] @ (x - mu[k])`` and
    the term ``-0.5 * dim * log(2 * pi)
    - sum_i(log_s[k, i] + 0.5 * y[i]**2 * exp(-2 * log_s[k, i]))``, then
    combines the terms with ``log_softmax(log_a)`` through ``logsumexp``
    over all components. Unlike ``logp``, the matrices are used as given
    (no transpose is applied here).

    Args:
        log_a: unnormalised log mixture weights, one per component.
        mu: per-component offsets, expected shape ``(nmix, dim)``.
        log_s: per-component log scales, expected shape ``(nmix, dim)``.
        uL_inv: per-component matrices, expected shape ``(nmix, dim, dim)``.
        x: a single point, expected shape ``(dim,)``.

    Returns:
        The ``logsumexp`` over components of weight plus component term.
    """
    # (nmix, dim): the point centred on every component.
    centered = jnp.expand_dims(x, 0) - mu
    # Batched matrix-vector product over the component axis.
    projected = jax.lax.dot_general(
        uL_inv, centered, (((2,), (1,)), ((0,), (0,))))
    scaled_sq = 0.5 * jax.lax.square(projected) * jax.lax.exp(-2.0 * log_s)
    log_norm = -0.5 * x.shape[-1] * jax.lax.log(2.0 * jnp.pi)
    component_logp = log_norm - jnp.sum(log_s + scaled_sq, -1)
    return logsumexp(jax.nn.log_softmax(log_a) + component_logp)

def vi(log_a1, log_a2, mu1, mu2, log_s1, log_s2, uL1_, uL1_inv_, uL2_inv_, x):
    """Build a scalar surrogate objective from sample statistics of two mixtures.

    The difference ``logp_singular(mixture 1) - logp_singular(mixture 2)`` is
    evaluated, together with its gradient and Hessian with respect to the
    point, at every sample in ``x``; the results are averaged over the
    leading axis of ``x``. Those averages enter the returned scalar only
    under ``stop_gradient``, so the result is differentiable solely through
    ``softmax(log_a1)``, ``mu1`` and ``S1`` (built from ``log_s1`` and
    ``uL1_``).

    Args:
        log_a1, log_a2: unnormalised log weights of the two mixtures.
        mu1, mu2: per-component offsets of the two mixtures.
        log_s1, log_s2: per-component log scales of the two mixtures.
        uL1_: per-component matrices of mixture 1, transposed before use.
        uL1_inv_, uL2_inv_: per-component matrices passed (after transposing
            the last two axes) to ``logp_singular``.
        x: samples of shape ``(nb, nmix, dim)``.

    Returns:
        A scalar surrogate value.
    """
    swap_last_two = (0, 2, 1)
    uL1 = jax.lax.transpose(uL1_, swap_last_two)
    uL1_inv = jax.lax.transpose(uL1_inv_, swap_last_two)
    uL2_inv = jax.lax.transpose(uL2_inv_, swap_last_two)

    # S1[k] = L1[k] @ L1[k].T with L1[k] = uL1[k] scaled column-wise by exp(log_s1[k]).
    L1 = jnp.expand_dims(jax.lax.exp(log_s1), -2) * uL1
    S1 = jax.lax.dot_general(L1, L1, (((2,), (2,)), ((0,), (0,))))

    def logp_diff(point):
        first = logp_singular(log_a1, mu1, log_s1, uL1_inv, point)
        second = logp_singular(log_a2, mu2, log_s2, uL2_inv, point)
        return first - second

    nb, nmix, dim = x.shape
    flat_points = jnp.reshape(x, (-1, dim))

    flat_value, flat_grad = jax.vmap(jax.value_and_grad(logp_diff))(flat_points)
    flat_hess = jax.vmap(jax.hessian(logp_diff))(flat_points)

    # Average each statistic over the leading (nb) axis.
    value = jnp.mean(jnp.reshape(flat_value, (nb, nmix)), 0)
    grad = jnp.mean(jnp.reshape(flat_grad, (nb, nmix, dim)), 0)
    hess = jnp.mean(jnp.reshape(flat_hess, (nb, nmix, dim, dim)), 0)

    weights = jax.nn.softmax(log_a1)

    weight_coeff = 2 * (value - jnp.sum(weights * value)) * weights
    weight_term = jnp.mean(jax.lax.stop_gradient(weight_coeff) * weights)

    mean_term = jnp.mean(jax.lax.stop_gradient(grad) * mu1) * dim

    # Per component: hess @ S1 and hess.T @ S1.T.
    hess_s = jax.lax.dot_general(hess, S1, (((2,), (1,)), ((0,), (0,))))
    hess_t_s = jax.lax.dot_general(hess, S1, (((1,), (2,)), ((0,), (0,))))
    cov_term = (jnp.mean(jax.lax.stop_gradient(hess_s + hess_t_s) * S1)
                * jax.lax.square(dim))

    return weight_term + mean_term + cov_term

@jax.jit
def vi_vjp(log_a1, log_a2, mu1, mu2, log_s1, log_s2, uL1_, uL1_inv_, uL2_inv_, x):
    """Evaluate ``vi`` and build its vector-Jacobian product.

    Returns:
        The ``(value, vjp_fn)`` pair produced by ``jax.vjp`` for ``vi`` at
        the given arguments; ``vjp_fn`` yields one cotangent per argument.
    """
    primals = (log_a1, log_a2, mu1, mu2, log_s1, log_s2,
               uL1_, uL1_inv_, uL2_inv_, x)
    return jax.vjp(vi, *primals)

def vi_adj(log_a1, log_a2, mu1, mu2, log_s1, log_s2, uL1_, uL1_inv_, uL2_inv_, x):
    """Evaluate ``vi`` and return its value with a NumPy-returning adjoint.

    Returns:
        ``(loss, adj_fn)`` where ``loss`` is the result of ``vi_vjp``
        converted with ``.tolist()`` and ``adj_fn(cotangent)`` returns a
        10-tuple aligned with this function's arguments. Adjoints are given
        as NumPy arrays for ``log_a1``, ``mu1``, ``log_s1``, ``uL1_`` and
        ``uL1_inv_``; every other slot is ``None``. All arrays except the
        ``log_a1`` adjoint have their axes reversed (``.T``), and the two
        matrix adjoints are masked to their strictly upper-triangular entries.
    """
    dim = mu1.shape[-1]
    loss, pullback = vi_vjp(
        log_a1, log_a2, mu1, mu2, log_s1, log_s2, uL1_, uL1_inv_, uL2_inv_, x)

    def adj_fn(x):
        # Here ``x`` is the output cotangent. Adjoints for the second
        # mixture's parameters and for the samples are discarded.
        (d_log_a1, _, d_mu1, _, d_log_s1, _,
         d_uL1, d_uL1_inv, _, _) = pullback(x)

        # Ones strictly above the diagonal, with a leading broadcast axis.
        strict_upper = jnp.expand_dims(jnp.triu(jnp.ones((dim, dim)), 1), 0)
        return (
            np.array(d_log_a1), None,
            np.array(d_mu1).T, None,
            np.array(d_log_s1).T, None,
            np.array(strict_upper * d_uL1).T,
            np.array(strict_upper * d_uL1_inv).T,
            None, None,
        )

    return loss.tolist(), adj_fn
