import jax
import functools
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def logsumexp(a, axis=None):
    """Return the log-sum-exp of ``a`` reduced over ``axis``.

    The function is wrapped in ``jax.custom_vjp`` with ``axis`` declared as a
    non-differentiable argument. The primal value is the first element of the
    pair produced by ``_logsumexp_fwd``; the residuals in the second element
    are not needed here and are dropped.

    Args:
        a: Array to reduce.
        axis: Reduction axis forwarded to ``_logsumexp_fwd``. Defaults to
            ``None``.

    Returns:
        The first element of ``_logsumexp_fwd(a, axis)``.
    """
    value, _residuals = _logsumexp_fwd(a, axis)
    return value

def _logsumexp_fwd(a, axis):
    """Forward pass of the log-sum-exp: primal value plus residuals.

    The maximum along ``axis`` is subtracted before exponentiating. Where
    that maximum is not finite it is replaced by 0, so ``a - shift`` never
    evaluates ``inf - inf``.

    Args:
        a: Array to reduce.
        axis: Axis passed to ``jnp.max``, ``jnp.sum`` and ``jnp.squeeze``.

    Returns:
        A pair ``(result, (shifted_exp, total))`` where ``shifted_exp`` is
        ``exp(a - shift)``, ``total`` is its sum along ``axis`` with the
        reduced dimension kept, and ``result`` is
        ``squeeze(shift) + log(squeeze(total))``.
    """
    peak = jnp.max(a, axis=axis, keepdims=True)
    # Fall back to a zero shift wherever the maximum is +/-inf or nan.
    shift = jnp.where(jnp.isfinite(peak), peak, 0)
    shifted_exp = jnp.exp(a - shift)
    total = jnp.sum(shifted_exp, axis=axis, keepdims=True)
    log_total = jnp.log(jnp.squeeze(total, axis=axis))
    result = jnp.squeeze(shift, axis=axis) + log_total
    return result, (shifted_exp, total)

def _logsumexp_bwd(axis, res, g):
    """Backward pass of the log-sum-exp.

    Args:
        axis: The non-differentiable reduction axis (``None`` or a value
            accepted by ``jnp.expand_dims``).
        res: The residual pair ``(shifted_exp, total)`` saved by the forward
            pass.
        g: Cotangent of the forward result.

    Returns:
        A one-element tuple holding ``g / total * shifted_exp``, with ``g``
        re-expanded along ``axis`` when an axis was given.
    """
    shifted_exp, total = res
    # Divide by 1 instead of 0 where the sum of exponentials is exactly zero.
    denom = jnp.where(total != 0, total, 1)
    cotangent = g if axis is None else jnp.expand_dims(g, axis=axis)
    return (cotangent / denom * shifted_exp,)

# Register the forward/backward pair as the custom reverse-mode rule of logsumexp.
logsumexp.defvjp(_logsumexp_fwd, _logsumexp_bwd)

def logp_diag(eps, log_a, mu, s, log_s, x):
    """Log-density of a diagonal Gaussian mixture at a single point ``x``.

    Per component the code evaluates
    ``-0.5 * (d*log(2*pi) + sum(log(eps) + log_s + (x - mu)**2 / (eps*s)))``
    with ``d = x.shape[-1]``, adds ``log_softmax(log_a)`` and reduces
    everything with ``logsumexp``. This is a Gaussian with covariance
    ``eps * diag(s)`` provided ``log_s == log(s)`` (callers are expected to
    pass consistent values -- not checked here).

    Args:
        eps: Scalar multiplying ``s`` in the variance.
        log_a: Unnormalised log mixture weights (normalised by
            ``log_softmax``).
        mu: Component means; broadcast against ``x`` with a new leading axis
            (presumably ``(nmix, dim)`` -- confirm against callers).
        s: Per-dimension scale factors of each component.
        log_s: Logarithm of ``s``.
        x: Evaluation point; its last axis is the feature dimension.

    Returns:
        The ``logsumexp`` over all entries of the per-component log terms.
    """
    dim = x.shape[-1]
    sq_dev = jax.lax.square(jnp.expand_dims(x, 0) - mu)
    per_dim = jax.lax.log(eps) + log_s + sq_dev / (eps * s)
    log_norm = -dim * jax.lax.log(2.0 * jnp.pi)
    log_density = 0.5 * (log_norm - jnp.sum(per_dim, -1))
    return logsumexp(jax.nn.log_softmax(log_a) + log_density)

def logp_full(eps, log_a, mu, S_inv, log_s, x):
    """Log-density of a full-covariance Gaussian mixture at a point ``x``.

    Each component is treated as a Gaussian with covariance ``eps * S``.
    ``S_inv`` is the inverse of ``S`` and ``log_s`` holds the log-eigenvalues
    of ``S``, so ``sum(log(eps) + log_s)`` is the log-determinant of
    ``eps * S`` (the callers in this file build both from the same
    rotation / log-diagonal parameters; this is not checked here).

    Args:
        eps: Scalar covariance multiplier.
        log_a: Unnormalised log mixture weights (normalised by
            ``log_softmax``).
        mu: Component means; broadcast against ``x`` with a new leading axis
            (presumably ``(nmix, dim)`` -- confirm against callers).
        S_inv: Per-component inverse of ``S``; the last two axes are the
            matrix axes.
        log_s: Per-component log-eigenvalues of ``S``.
        x: Evaluation point; its last axis is the feature dimension.

    Returns:
        The ``logsumexp`` over all entries of the per-component log terms.
    """
    dim = x.shape[-1]
    diff = jnp.expand_dims(x, 0) - mu
    # Row-wise matrix-vector product: (S_inv @ diff) for every component.
    S_inv_diff = jnp.sum(S_inv * jnp.expand_dims(diff, -2), -1)
    # The covariance is eps * S, hence its inverse is S_inv / eps. Without
    # this division the quadratic form does not match the log(eps) + log_s
    # normaliser (logp_diag divides by eps * s for the same reason).
    per_dim = jax.lax.log(eps) + log_s + diff * S_inv_diff / eps
    log_norm = -dim * jax.lax.log(2.0 * jnp.pi)
    log_density = 0.5 * (log_norm - jnp.sum(per_dim, -1))
    return logsumexp(jax.nn.log_softmax(log_a) + log_density)

def vi0_diag(eps, ps1, ps2, y):
    """Scalar surrogate built from the log-ratio of two diagonal mixtures.

    The log-ratio ``logp_diag(ps1) - logp_diag(ps2)`` is evaluated at every
    row of ``y`` together with its gradient and Hessian with respect to the
    evaluation point. Those statistics are averaged over the leading axis of
    ``y`` and wrapped in ``stop_gradient``; they are then multiplied by
    ``softmax(log_alpha1)``, ``r1`` and ``eps * s1`` respectively. As a
    consequence ``ps2`` only enters through stop-gradient terms.

    Args:
        eps: Scalar; divides ``log_alpha_raw`` and scales the variances.
        ps1: Mapping with keys ``"r"``, ``"log_alpha_raw"`` and
            ``"S_log_diagonal_matrix"``.
        ps2: Mapping with the same keys as ``ps1``.
        y: Array whose shape unpacks as ``(nb, nmix, dim)``.

    Returns:
        The sum of the weight, mean and scale terms.
    """
    r1 = ps1["r"]
    r2 = ps2["r"]
    log_alpha1 = ps1["log_alpha_raw"] / eps
    log_alpha2 = ps2["log_alpha_raw"] / eps
    log_s1 = ps1["S_log_diagonal_matrix"]
    log_s2 = ps2["S_log_diagonal_matrix"]
    s1 = jax.lax.exp(log_s1)
    s2 = jax.lax.exp(log_s2)

    def log_ratio(point):
        lp1 = logp_diag(eps, log_alpha1, r1, s1, log_s1, point)
        lp2 = logp_diag(eps, log_alpha2, r2, s2, log_s2, point)
        return lp1 - lp2

    nb, nmix, dim = y.shape
    flat = jnp.reshape(y, (-1, dim))

    vals, grads = jax.vmap(jax.value_and_grad(log_ratio))(flat)
    hessians = jax.vmap(jax.hessian(log_ratio))(flat)

    # Average every statistic over the leading (nb) axis of y.
    value = jnp.mean(jnp.reshape(vals, (nb, nmix)), 0)
    grad = jnp.mean(jnp.reshape(grads, (nb, nmix, dim)), 0)
    hess = jnp.mean(jnp.reshape(hessians, (nb, nmix, dim, dim)), 0)

    weights = jax.nn.softmax(log_alpha1)
    weight_coef = jax.lax.stop_gradient(
        2 * (value - jnp.sum(weights * value)) * weights)
    weight_term = jnp.mean(weight_coef * weights)

    # jnp.mean also divides by the feature dimension; multiplying by dim
    # (dim**2 for the Hessian term) undoes that part of the averaging.
    mean_term = jnp.mean(jax.lax.stop_gradient(grad) * r1) * dim

    hess_coef = jax.lax.stop_gradient(
        (hess * jnp.expand_dims(s1, -2)) +
        (hess * jnp.expand_dims(s1, -1)))
    scale_term = jnp.mean(
        eps * hess_coef * jnp.expand_dims(eps * s1, -2)) * jax.lax.square(dim)

    return weight_term + mean_term + scale_term

def vi0_full(eps, ps1, ps2, y):
    """Scalar surrogate built from the log-ratio of two full-covariance mixtures.

    The log-ratio ``logp_full(ps1) - logp_full(ps2)`` is evaluated at every
    row of ``y`` together with its gradient and Hessian with respect to the
    evaluation point. Those statistics are averaged over the leading axis of
    ``y`` and wrapped in ``stop_gradient``; they are then multiplied by
    ``softmax(log_alpha1)``, ``r1`` and ``eps * S1`` respectively. As a
    consequence ``ps2`` only enters through stop-gradient terms.

    Args:
        eps: Scalar; divides ``log_alpha_raw`` and scales the covariances.
        ps1: Mapping with keys ``"r"``, ``"log_alpha_raw"``,
            ``"S_log_diagonal_matrix"`` and ``"S_rotation_matrix"``.
        ps2: Mapping with the same keys as ``ps1``.
        y: Array whose shape unpacks as ``(nb, nmix, dim)``.

    Returns:
        The sum of the weight, mean and scale terms.
    """
    r1, log_alpha1 = ps1["r"], ps1["log_alpha_raw"] / eps
    r2, log_alpha2 = ps2["r"], ps2["log_alpha_raw"] / eps
    log_s1 = ps1["S_log_diagonal_matrix"]
    log_s2 = ps2["S_log_diagonal_matrix"]
    S1_rot = ps1["S_rotation_matrix"]
    S2_rot = ps2["S_rotation_matrix"]

    def _assemble(rot, log_eig):
        # Batched rot @ diag(exp(log_eig)) @ rot^T with batch axis 0.
        return jax.lax.dot_general(
            rot * jnp.expand_dims(jax.lax.exp(log_eig), 1), rot,
            (((2,), (2,)), ((0,), (0,))))

    # Only S1 appears outside the log-densities, so S2 itself is not built;
    # the second mixture needs just its inverse.
    S1 = _assemble(S1_rot, log_s1)
    S1_inv = _assemble(S1_rot, -log_s1)
    S2_inv = _assemble(S2_rot, -log_s2)

    def logp_diff(y):
        return logp_full(eps, log_alpha1, r1, S1_inv, log_s1, y) - \
               logp_full(eps, log_alpha2, r2, S2_inv, log_s2, y)

    nb, nmix, dim = y.shape
    y_merge = jnp.reshape(y, (-1, dim))

    value_merge, grad_merge = jax.vmap(jax.value_and_grad(logp_diff))(y_merge)
    hess_merge = jax.vmap(jax.hessian(logp_diff))(y_merge)

    # Average every statistic over the leading (nb) axis of y.
    value = jnp.mean(jnp.reshape(value_merge, (nb, nmix)), 0)
    grad = jnp.mean(jnp.reshape(grad_merge, (nb, nmix, dim)), 0)
    hess = jnp.mean(jnp.reshape(hess_merge, (nb, nmix, dim, dim)), 0)

    weights = jax.nn.softmax(log_alpha1)
    weight_coef = jax.lax.stop_gradient(
        2 * (value - jnp.sum(weights * value)) * weights)
    weight_term = jnp.mean(weight_coef * weights)

    # jnp.mean also divides by the feature dimension; multiplying by dim
    # (dim**2 for the Hessian term) undoes that part of the averaging.
    mean_term = jnp.mean(jax.lax.stop_gradient(grad) * r1) * dim

    hess_coef = jax.lax.stop_gradient(
        jax.lax.dot_general(hess, S1, (((2,), (1,)), ((0,), (0,)))) +
        jax.lax.dot_general(hess, S1, (((1,), (2,)), ((0,), (0,)))))
    scale_term = jnp.mean(eps * hess_coef * (eps * S1)) * jax.lax.square(dim)

    return weight_term + mean_term + scale_term

@functools.partial(jax.jit, static_argnums=0)
def vi0(is_diag, eps, vs1, vs2, y):
    """Jitted dispatcher to ``vi0_diag`` or ``vi0_full``.

    Args:
        is_diag: Static flag (``static_argnums=0``); truthy selects the
            diagonal implementation, otherwise the full one.
        eps: Forwarded unchanged.
        vs1: Mapping whose ``"params"`` entry is forwarded as ``ps1``.
        vs2: Mapping whose ``"params"`` entry is forwarded as ``ps2``.
        y: Forwarded unchanged.

    Returns:
        Whatever the selected implementation returns.
    """
    impl = vi0_diag if is_diag else vi0_full
    return impl(eps, vs1["params"], vs2["params"], y)


def vi_diag(eps, ps1, ps2, x, y):
    """Per-``x`` surrogate for the diagonal mixtures, vmapped over ``x``.

    For every row of ``x`` the mixture parameters are shifted to
    ``mu = r + s * x`` and
    ``log_a = (x.(s*x) + 2 x.r) / (2*eps) + log_alpha``. The same
    three-term stop-gradient surrogate as in ``vi0_diag`` is then evaluated
    on the matching slice of ``y``, with ``softmax(log_a1)`` and ``mu1`` in
    place of ``softmax(log_alpha1)`` and ``r1``.

    Args:
        eps: Scalar; divides ``log_alpha_raw`` and scales the variances.
        ps1: Mapping with keys ``"r"``, ``"log_alpha_raw"`` and
            ``"S_log_diagonal_matrix"``.
        ps2: Mapping with the same keys as ``ps1``.
        x: Conditioning points, ``nx x dim``.
        y: Samples; the leading axis is vmapped together with the rows of
            ``x`` and each slice unpacks as ``(nb, nmix, dim)``.

    Returns:
        One surrogate value per leading index (the output of ``jax.vmap``).
    """
    r1, log_alpha1 = ps1["r"], ps1["log_alpha_raw"] / eps
    r2, log_alpha2 = ps2["r"], ps2["log_alpha_raw"] / eps
    log_s1 = ps1["S_log_diagonal_matrix"]
    log_s2 = ps2["S_log_diagonal_matrix"]
    s1 = jax.lax.exp(log_s1)
    s2 = jax.lax.exp(log_s2)

    # x_:        nx x 1 x dim
    # _r*, _s*:  1 x nmix x dim
    x_ = jnp.expand_dims(x, 1)
    _r1 = jnp.expand_dims(r1, 0)
    _r2 = jnp.expand_dims(r2, 0)
    _s1 = jnp.expand_dims(s1, 0)
    _s2 = jnp.expand_dims(s2, 0)

    s1_x = _s1 * x_
    s2_x = _s2 * x_
    # mu*: nx x nmix x dim
    mu1 = _r1 + s1_x
    mu2 = _r2 + s2_x
    log_a1 = (jnp.sum(x_ * s1_x, -1) + 2 * jnp.sum(x_ * _r1, -1)) / (2 * eps) \
        + jnp.expand_dims(log_alpha1, 0)
    log_a2 = (jnp.sum(x_ * s2_x, -1) + 2 * jnp.sum(x_ * _r2, -1)) / (2 * eps) \
        + jnp.expand_dims(log_alpha2, 0)

    # Named _vi_single so it does not shadow the module-level ``vi``.
    def _vi_single(log_a1, log_a2, mu1, mu2, y):
        def logp_diff(y):
            return logp_diag(eps, log_a1, mu1, s1, log_s1, y) - \
                   logp_diag(eps, log_a2, mu2, s2, log_s2, y)

        # y: nb x nmix x dim  ->  y_merge: (nb * nmix) x dim
        nb, nmix, dim = y.shape
        y_merge = jnp.reshape(y, (-1, dim))

        value_merge, grad_merge = jax.vmap(jax.value_and_grad(logp_diff))(y_merge)
        hess_merge = jax.vmap(jax.hessian(logp_diff))(y_merge)

        # Average every statistic over the leading (nb) axis of y.
        value = jnp.mean(jnp.reshape(value_merge, (nb, nmix)), 0)
        grad = jnp.mean(jnp.reshape(grad_merge, (nb, nmix, dim)), 0)
        hess = jnp.mean(jnp.reshape(hess_merge, (nb, nmix, dim, dim)), 0)

        weights = jax.nn.softmax(log_a1)
        weight_coef = jax.lax.stop_gradient(
            2 * (value - jnp.sum(weights * value)) * weights)
        weight_term = jnp.mean(weight_coef * weights)

        # jnp.mean also divides by the feature dimension; multiplying by dim
        # (dim**2 for the Hessian term) undoes that part of the averaging.
        mean_term = jnp.mean(jax.lax.stop_gradient(grad) * mu1) * dim

        hess_coef = jax.lax.stop_gradient(
            (hess * jnp.expand_dims(s1, -2)) +
            (hess * jnp.expand_dims(s1, -1)))
        scale_term = jnp.mean(
            eps * hess_coef * jnp.expand_dims(eps * s1, -2)) * jax.lax.square(dim)

        return weight_term + mean_term + scale_term

    return jax.vmap(_vi_single)(log_a1, log_a2, mu1, mu2, y)

def vi_full(eps, ps1, ps2, x, y):
    """Per-``x`` surrogate for the full-covariance mixtures, vmapped over ``x``.

    For every row of ``x`` the mixture parameters are shifted to
    ``mu = r + S @ x`` and
    ``log_a = (x.(S x) + 2 x.r) / (2*eps) + log_alpha``. The same
    three-term stop-gradient surrogate as in ``vi0_full`` is then evaluated
    on the matching slice of ``y``, with ``softmax(log_a1)`` and ``mu1`` in
    place of ``softmax(log_alpha1)`` and ``r1``.

    Args:
        eps: Scalar; divides ``log_alpha_raw`` and scales the covariances.
        ps1: Mapping with keys ``"r"``, ``"log_alpha_raw"``,
            ``"S_log_diagonal_matrix"`` and ``"S_rotation_matrix"``.
        ps2: Mapping with the same keys as ``ps1``.
        x: Conditioning points, ``nx x dim``.
        y: Samples; the leading axis is vmapped together with the rows of
            ``x`` and each slice unpacks as ``(nb, nmix, dim)``.

    Returns:
        One surrogate value per leading index (the output of ``jax.vmap``).
    """
    r1, log_alpha1 = ps1["r"], ps1["log_alpha_raw"] / eps
    r2, log_alpha2 = ps2["r"], ps2["log_alpha_raw"] / eps
    log_s1 = ps1["S_log_diagonal_matrix"]
    log_s2 = ps2["S_log_diagonal_matrix"]
    S1_rot = ps1["S_rotation_matrix"]
    S2_rot = ps2["S_rotation_matrix"]

    def _assemble(rot, log_eig):
        # Batched rot @ diag(exp(log_eig)) @ rot^T with batch axis 0.
        return jax.lax.dot_general(
            rot * jnp.expand_dims(jax.lax.exp(log_eig), 1), rot,
            (((2,), (2,)), ((0,), (0,))))

    S1 = _assemble(S1_rot, log_s1)
    S2 = _assemble(S2_rot, log_s2)
    S1_inv = _assemble(S1_rot, -log_s1)
    S2_inv = _assemble(S2_rot, -log_s2)

    # x_:   nx x 1 x dim
    # _r*:  1 x nmix x dim
    x_ = jnp.expand_dims(x, 1)
    _r1 = jnp.expand_dims(r1, 0)
    _r2 = jnp.expand_dims(r2, 0)

    # S*_x[n, k, :] = S*[k] @ x[n]; einsum avoids tiling S to
    # nx x nmix x dim x dim just to run a batched mat-vec.
    S1_x = jnp.einsum("kij,nj->nki", S1, x)
    S2_x = jnp.einsum("kij,nj->nki", S2, x)

    # mu*: nx x nmix x dim
    mu1 = _r1 + S1_x
    mu2 = _r2 + S2_x

    log_a1 = (jnp.sum(x_ * S1_x, -1) + 2 * jnp.sum(x_ * _r1, -1)) / (2 * eps) \
        + jnp.expand_dims(log_alpha1, 0)
    log_a2 = (jnp.sum(x_ * S2_x, -1) + 2 * jnp.sum(x_ * _r2, -1)) / (2 * eps) \
        + jnp.expand_dims(log_alpha2, 0)

    # Named _vi_single so it does not shadow the module-level ``vi``.
    def _vi_single(log_a1, log_a2, mu1, mu2, y):
        def logp_diff(y):
            return logp_full(eps, log_a1, mu1, S1_inv, log_s1, y) - \
                   logp_full(eps, log_a2, mu2, S2_inv, log_s2, y)

        nb, nmix, dim = y.shape
        y_merge = jnp.reshape(y, (-1, dim))

        value_merge, grad_merge = jax.vmap(jax.value_and_grad(logp_diff))(y_merge)
        hess_merge = jax.vmap(jax.hessian(logp_diff))(y_merge)

        # Average every statistic over the leading (nb) axis of y.
        value = jnp.mean(jnp.reshape(value_merge, (nb, nmix)), 0)
        grad = jnp.mean(jnp.reshape(grad_merge, (nb, nmix, dim)), 0)
        hess = jnp.mean(jnp.reshape(hess_merge, (nb, nmix, dim, dim)), 0)

        weights = jax.nn.softmax(log_a1)
        weight_coef = jax.lax.stop_gradient(
            2 * (value - jnp.sum(weights * value)) * weights)
        weight_term = jnp.mean(weight_coef * weights)

        # jnp.mean also divides by the feature dimension; multiplying by dim
        # (dim**2 for the Hessian term) undoes that part of the averaging.
        mean_term = jnp.mean(jax.lax.stop_gradient(grad) * mu1) * dim

        hess_coef = jax.lax.stop_gradient(
            jax.lax.dot_general(hess, S1, (((2,), (1,)), ((0,), (0,)))) +
            jax.lax.dot_general(hess, S1, (((1,), (2,)), ((0,), (0,)))))
        scale_term = jnp.mean(eps * hess_coef * (eps * S1)) * jax.lax.square(dim)

        return weight_term + mean_term + scale_term

    return jax.vmap(_vi_single)(log_a1, log_a2, mu1, mu2, y)


@functools.partial(jax.jit, static_argnums=0)
def vi(is_diag, eps, vs1, vs2, x, y):
    """Jitted dispatcher to ``vi_diag`` or ``vi_full``.

    Args:
        is_diag: Static flag (``static_argnums=0``); truthy selects the
            diagonal implementation, otherwise the full one.
        eps: Forwarded unchanged.
        vs1: Mapping whose ``"params"`` entry is forwarded as ``ps1``.
        vs2: Mapping whose ``"params"`` entry is forwarded as ``ps2``.
        x: Forwarded unchanged.
        y: Forwarded unchanged.

    Returns:
        Whatever the selected implementation returns.
    """
    impl = vi_diag if is_diag else vi_full
    return impl(eps, vs1["params"], vs2["params"], x, y)
