import jax
import jax.numpy as jnp
from functools import partial
import numpy as np
import optax.assignment


@partial(jax.jit, static_argnames=('num_grid', 'num_quantile'))
def quantile_of_add(quantiles, taus,
                    evaluation_point,
                    num_quantile: int,
                    num_grid: int):
    """
    Approximate quantiles of the sum of independent random variables.

    Each row of ``quantiles`` is one variable's quantile function sampled at the
    levels ``taus``. Rows ``0 .. num_quantile - 1`` are folded in one at a time:
    both operands are turned into CDFs on a shared grid, the CDF of the sum is
    obtained as the discrete convolution of one CDF with the probability
    increments of the other, and the result is inverted back to quantiles at
    ``taus`` for the next step.

    quantiles: 2-d Array, shape (num_quantile, len(taus)).
        NOTE(review): jnp.interp needs increasing sample points, so every row is
        assumed to be non-decreasing — confirm with callers.
    taus: 1-d Array of increasing quantile levels.
    evaluation_point: 1-d Array of quantile levels at which the quantile
        function of the sum is evaluated.
    num_quantile: static int, number of rows of ``quantiles`` to add up.
    num_grid: static int (>= 2), number of grid points used per CDF.
    :return: Array shaped like ``evaluation_point`` with the approximate
        quantiles of the sum at those levels.
    """

    def scan_fn(carry, i):
        x_quantile = carry['x_quantile']
        y_quantile = quantiles[i]

        # The shared grid has to cover the support of BOTH operands; otherwise
        # one of the two CDFs is truncated before it reaches its upper level.
        minima = jnp.minimum(jnp.min(x_quantile), jnp.min(y_quantile))
        maxima = jnp.maximum(jnp.max(x_quantile), jnp.max(y_quantile))
        grid = jnp.linspace(minima, maxima, num_grid)
        dz = grid[1] - grid[0]

        Fx = jnp.interp(grid, x_quantile, taus).clip(0, 1)
        Fy = jnp.interp(grid, y_quantile, taus).clip(0, 1)
        dFy = jnp.diff(Fy, prepend=0.0)

        # F_sum[k] = sum_j Fx[k - j] * dFy[j]. To the right of the grid the CDF
        # of X keeps its last value; a plain 'full' convolution would treat it
        # as 0 there, so Fx is extended by that value before convolving.
        Fx_ext = jnp.concatenate([Fx, jnp.full((num_grid - 1,), Fx[-1])])
        F_sum = jnp.convolve(Fx_ext, dFy, mode='full')[:2 * num_grid - 1]

        # grid[i] + grid[j] = 2 * grid[0] + (i + j) * dz
        z_sum = 2.0 * grid[0] + jnp.arange(F_sum.size) * dz

        # Guard against rounding so the CDF is monotone and stays inside [0, 1].
        F_sum = jnp.clip(jax.lax.cummax(F_sum), 0.0, 1.0)

        # Invert the CDF: levels are the sample points, grid values the outputs.
        new_x_quantile = jnp.interp(taus, F_sum, z_sum).astype(x_quantile.dtype)
        return {"x_quantile": new_x_quantile}, None

    init_carry = {"x_quantile": quantiles[0]}
    last_carry, _ = jax.lax.scan(scan_fn, init_carry, jnp.arange(1, num_quantile))
    y_f = jnp.interp(evaluation_point, taus, last_carry['x_quantile'])

    return y_f


@partial(jax.jit, static_argnames=('num_grid', 'num_quantile'))
def quantile_of_add_parallel(quantiles, taus,
                             evaluation_point,
                             num_quantile: int,
                             num_grid: int):
    """
    Approximate quantiles of the sum of independent random variables in one shot.

    Every row of ``quantiles`` is turned into probability masses on one shared
    lattice; the masses of the sum are the linear convolution of all rows,
    computed as a product of zero-padded real FFTs, and then accumulated into
    a CDF that is inverted at ``evaluation_point``.

    quantiles: 2-d Array, shape (number of variables, len(taus)); all rows are
        used.
        NOTE(review): jnp.interp needs increasing sample points, so every row is
        assumed to be non-decreasing — confirm with callers.
    taus: 1-d Array of increasing quantile levels.
    evaluation_point: 1-d Array of quantile levels at which the quantile
        function of the sum is evaluated.
    num_quantile: static int, not read here (the row count comes from
        ``quantiles.shape[0]``); kept so the signature matches
        ``quantile_of_add``.
    num_grid: static int (>= 2), number of lattice points per variable.
    :return: tuple ``(y_f, z_sum, F_sum)``: quantiles of the sum at
        ``evaluation_point``, the lattice of the sum, and the CDF of the sum
        on that lattice.
    """
    num_rows = quantiles.shape[0]

    # One lattice shared by all variables, covering every row's support.
    grid = jnp.linspace(jnp.min(quantiles), jnp.max(quantiles), num_grid)
    dz = grid[1] - grid[0]

    # Per-row CDF on the lattice, then the probability mass at each lattice point.
    cdfs = jax.vmap(lambda q: jnp.interp(grid, q, taus))(quantiles).clip(0, 1)
    pmfs = jnp.concatenate([cdfs[:, :1], jnp.diff(cdfs, axis=-1)], axis=-1)

    # The sum of num_rows lattice indices ranges over 0 .. num_rows * (num_grid - 1).
    # Zero-padding every row to that length makes the FFT product a linear
    # (not circular) convolution.
    fft_len = num_rows * (num_grid - 1) + 1
    spectra = jnp.fft.rfft(pmfs, n=fft_len, axis=-1)
    pmf_sum = jnp.fft.irfft(jnp.prod(spectra, axis=0), n=fft_len)
    pmf_sum = jnp.maximum(pmf_sum, 0.0)  # remove tiny negative FFT round-off

    # cummax keeps the CDF monotone even if the prefix sum rounds unevenly.
    F_sum = jnp.clip(jax.lax.cummax(jnp.cumsum(pmf_sum)), 0.0, 1.0)
    z_sum = num_rows * grid[0] + jnp.arange(fft_len) * dz

    y_f = jnp.interp(evaluation_point, F_sum, z_sum)

    return y_f, z_sum, F_sum


def _gamma_uy(u):
    """
    Build a matrix whose columns are orthonormal and orthogonal to ``u``.

    ``u`` is stacked in front of the identity matrix with its first column
    removed, the result is QR-factorised, and every column of Q except the
    first is returned.

    :param u: vector of length m (used as the leading column of the stack)
    :return: NumPy array holding columns 1.. of the Q factor
    """
    dim = u.shape[0]
    # np.eye(dim, dim - 1, k=-1) is the identity without its first column.
    stacked = np.column_stack((u, np.eye(dim, dim - 1, k=-1)))
    q_factor = np.linalg.qr(stacked)[0]
    return q_factor[:, 1:]


def quantile_check(x, tau):
    """
    Elementwise check (pinball) loss at level ``tau``.

    Non-negative entries are weighted by ``tau`` and negative entries by
    ``tau - 1``.

    :param x: array of residuals (must provide ``.astype`` and ``.dtype``)
    :param tau: quantile level
    :return: array of the same shape as ``x``
    """
    is_negative = (x < 0).astype(x.dtype)
    slope = tau - is_negative
    return slope * x


def moqr_loss(W, Y, u_y, a, b_w, b_y, tau=0.5):
    """
    Mean check loss of a directional quantile-regression residual.

    ``u_y`` is normalised to unit length; the residual is the projection of
    ``Y`` on that direction minus a linear term in the directions returned by
    ``_gamma_uy``, a linear term in ``W`` and the intercept ``a``.

    :param W: covariates, multiplied by ``b_w``
        (NOTE(review): presumably shape (n, d_w) — confirm with callers)
    :param Y: responses, multiplied by the direction and by its complement
    :param u_y: direction vector, normalised inside the function
    :param a: intercept
    :param b_w: coefficients applied to ``W``
    :param b_y: coefficients applied to ``Y @ _gamma_uy(u_y)``
    :param tau: quantile level passed to ``quantile_check``
    :return: mean of the check loss over the residuals
    """
    direction = u_y / np.linalg.norm(u_y)
    complement = _gamma_uy(direction)

    along = Y @ direction
    across = Y @ complement @ b_y
    covariate_part = W @ b_w

    residual = along - across - covariate_part - a
    return quantile_check(residual, tau).mean()


def empirical_sinkhorn(C, gamma,
                       p: int = 1, delta: float = 1e-4, max_iter: int = 10):
    """
    Sinkhorn scaling of the kernel ``exp(-C / gamma)`` toward uniform marginals.

    Even iterations rescale the rows (``u``), odd iterations rescale the
    columns (``v``). The loop stops once the summed p-norm violation of both
    marginals is at most ``delta`` or after ``max_iter`` iterations. The kernel
    is formed directly, so a very small ``gamma`` can underflow; the log-domain
    ``stable_sinkhorn`` avoids that.

    :param C: cost matrix, shape (n, m); n and m may differ
    :param gamma: entropy regularization coefficient
    :param p: order of the norm used to measure the marginal violation
    :param delta: tolerance on the marginal violation
    :param max_iter: maximum number of iterations
    :return: final loop state, a dict with keys "k" (iterations done), "u",
        "v" (scaling vectors), "err" (last marginal violation) and "P"
        (transport plan ``diag(u) @ K @ diag(v)``)
    """
    K = jnp.exp(-C / gamma)
    K_t = K.T
    # Uniform target marginals: a over rows, b over columns.
    a = jnp.ones_like(C[:, 0])
    b = jnp.ones_like(C[0, :])
    a = a / a.shape[0]
    b = b / b.shape[0]

    def step(carry):
        k = carry['k']
        u_k = carry['u']
        v_k = carry['v']
        is_row_step = k % 2 == 0
        u_kp1 = jnp.where(is_row_step, a / (K @ v_k), u_k)
        v_kp1 = jnp.where(is_row_step, v_k, b / (K_t @ u_k))
        P_k = jnp.einsum('i, ij, j -> ij', u_kp1, K, v_kp1)
        # Row sums have length n (compare with a), column sums length m
        # (compare with b); this also holds for rectangular cost matrices.
        row_violation = jnp.linalg.norm(P_k.sum(axis=1) - a, ord=p)
        col_violation = jnp.linalg.norm(P_k.sum(axis=0) - b, ord=p)
        err = row_violation + col_violation
        return {"k": k + 1, "u": u_kp1, "v": v_kp1, "err": err, "P": P_k}

    def cond_fn(carry):
        return (carry['k'] < max_iter) & (carry['err'] > delta)

    initial_carry = {"err": jnp.inf, "k": 0, "u": jnp.ones_like(a), "v": jnp.ones_like(b), "P": jnp.ones_like(C)}
    result = jax.lax.while_loop(
        cond_fn,
        step,
        initial_carry
    )
    return result


@partial(jax.jit, static_argnames=('p', 'delta', 'max_iter'))
def stable_sinkhorn(C, gamma, p: int = 1, delta: float = 1e-4, max_iter: int = 10):
    """
    Log-domain Sinkhorn scaling toward uniform marginals.

    Same alternating scheme as ``empirical_sinkhorn`` (even iterations update
    the row potential ``u``, odd iterations the column potential ``v``), but
    every quantity is kept as an elementwise logarithm and products become
    log-sum-exp reductions, so the kernel ``exp(-C / gamma)`` is never formed.

    :param C: cost matrix, shape (n, m)
    :param gamma: entropy regularization coefficient
    :param p: order of the norm used to measure the marginal violation
    :param delta: tolerance on the marginal violation (compared in log space)
    :param max_iter: maximum number of iterations
    :return: final loop state, a dict with keys "k" (iterations done), "u",
        "v" (log scaling vectors), "log_err" (log of the last marginal
        violation) and "log_P" (elementwise log of the transport plan)
    """
    log_K = (-C / gamma)
    log_K_t = log_K.T
    # Logs of the uniform target marginals: a over rows, b over columns.
    a = jnp.ones_like(C[:, 0])
    b = jnp.ones_like(C[0, :])
    a = jnp.log(a / a.shape[0])
    b = jnp.log(b / b.shape[0])

    log_delta = jnp.log(delta)

    def log_matrix_mul(mat, vec):
        """
        calculate exp(M) @ exp(vec)
        :param mat: elementwise log matrix
        :param vec: elementwise log vector
        :return: log(exp(M) @ exp(vec)). Here, exp is elementwise exp.
        """
        add = mat + vec[None, :]
        return jax.nn.logsumexp(add, axis=-1)

    def log_sub(left, right):
        """
        :param left: elementwise log vector
        :param right: elementwise log vector
        :return: log|exp(left) - exp(right)|, elementwise
        """
        # |e^l - e^r| = e^max(l, r) * (1 - e^-|l - r|); equal inputs give -inf.
        return jnp.maximum(left, right) + jnp.log1p(-jnp.exp(-jnp.abs(left - right)))

    def log_norm(vec):
        """
        :param vec: log vector
        :return: logarithm of p-norm of exp(vec)
        """
        # log((sum_i exp(vec_i)^p)^(1/p)) = 1/p * logsumexp(p * vec)
        return 1/p * jax.nn.logsumexp(vec * p)

    def step(carry):
        k = carry['k']
        u_k = carry['u']
        v_k = carry['v']
        is_row_step = k % 2 == 0
        u_kp1 = jnp.where(is_row_step, a - log_matrix_mul(log_K, v_k), u_k)
        v_kp1 = jnp.where(is_row_step, v_k, b - log_matrix_mul(log_K_t, u_k))
        # log of diag(exp(u)) @ K @ diag(exp(v))
        log_P_k = u_kp1[:, None] + log_K + v_kp1[None, :]
        # A zero vector is log(ones), so these are the logs of row / column sums.
        log_row_sums = log_matrix_mul(log_P_k, jnp.zeros_like(b))
        log_col_sums = log_matrix_mul(log_P_k.T, jnp.zeros_like(a))
        # log(row violation + column violation)
        log_err = jax.nn.logsumexp(jnp.asarray([log_norm(log_sub(log_row_sums, a)),
                                                log_norm(log_sub(log_col_sums, b))]))
        new_carry = {"k": k + 1, "u": u_kp1, "v": v_kp1, "log_err": log_err, "log_P": log_P_k}
        return new_carry

    def cond_fn(carry):
        return (carry['k'] < max_iter) & (carry['log_err'] > log_delta)

    initial_carry = {"log_err": jnp.inf, "k": 0, "u": jnp.zeros_like(a), "v": jnp.zeros_like(b), "log_P": jnp.zeros_like(C)}
    result = jax.lax.while_loop(
        cond_fn,
        step,
        initial_carry
    )

    return result


def sinkhorn_matching(C, gamma, p: int = 1, delta: float = 1e-4, max_iter: int = 10):
    """
    Greedily extract a one-to-one matching from the log transport plan.

    Runs ``stable_sinkhorn`` and then repeatedly picks, among the rows not yet
    matched, the row whose best free column beats its second-best free column
    by the largest margin. That row is paired with its best column and both
    are marked as used. The loop ends when every row or every column is used.

    :param C: cost matrix, shape (n, m)
    :param gamma: entropy regularization coefficient passed to stable_sinkhorn
    :param p: norm order passed to stable_sinkhorn
    :param delta: tolerance passed to stable_sinkhorn
    :param max_iter: iteration cap passed to stable_sinkhorn
    :return: ``(I, J)`` int32 arrays of length ``min(n, m)``; ``I[k]`` and
        ``J[k]`` are the row and column chosen at step k (entries start at -1)
    """
    log_plan = stable_sinkhorn(C, gamma, p, delta, max_iter)["log_P"]
    n_rows, n_cols = log_plan.shape
    n_pairs = min(n_rows, n_cols)
    col_ids = jnp.arange(n_cols)

    def keep_going(state):
        rows_taken, cols_taken = state[0], state[1]
        # "some row free and some column free" == not (all rows or all columns taken)
        return jnp.logical_not(jnp.logical_or(jnp.all(rows_taken), jnp.all(cols_taken)))

    def pick_next(state):
        rows_taken, cols_taken, row_out, col_out, step = state
        row_free = jnp.logical_not(rows_taken)
        col_free = jnp.logical_not(cols_taken)
        scores = jnp.where(row_free[:, None] & col_free[None, :], log_plan, -jnp.inf)

        # Best and second-best free column per row.
        best_col = jnp.argmax(scores, axis=1)
        best_val = scores.max(axis=1)
        without_best = jnp.where(col_ids[None, :] == best_col[:, None], -jnp.inf, scores)
        second_val = without_best.max(axis=1)
        margin = jnp.where(row_free, best_val - second_val, -jnp.inf)

        chosen_row = jnp.argmax(margin)
        chosen_col = best_col[chosen_row]

        return (
            rows_taken.at[chosen_row].set(True),
            cols_taken.at[chosen_col].set(True),
            row_out.at[step].set(chosen_row.astype(jnp.int32)),
            col_out.at[step].set(chosen_col.astype(jnp.int32)),
            step + 1,
        )

    start = (
        jnp.zeros((n_rows,), dtype=jnp.bool_),
        jnp.zeros((n_cols,), dtype=jnp.bool_),
        jnp.full((n_pairs,), -1, dtype=jnp.int32),
        jnp.full((n_pairs,), -1, dtype=jnp.int32),
        jnp.array(0, dtype=jnp.int32),
    )
    final = jax.lax.while_loop(keep_going, pick_next, start)
    return final[2], final[3]

if __name__ == '__main__':
    # Demo: match two small random point clouds with the Sinkhorn-based
    # matching and compare against the Hungarian algorithm from optax.
    key = jax.random.PRNGKey(42)
    # Use independent subkeys; drawing twice from the same key would make the
    # two samples depend on the same random bits.
    key_a, key_b = jax.random.split(key)
    a = 3 * jax.random.normal(key_a, shape=(4, 2))
    b = 5 * jax.random.beta(key_b, shape=(4, 2), a=0.5, b=0.5)
    # Pairwise L1 distances between the points of a and b.
    C = jnp.linalg.norm(a[:, None] - b[None, :], axis=-1, ord=1)
    i, j = sinkhorn_matching(C, gamma=0.001, p=1, max_iter=10000000, delta=1e-6)
    print("sinkhorn", j)
    sinkhorn_result = stable_sinkhorn(C, gamma=0.01, p=1, max_iter=10000000, delta=1e-6)
    # Column and row sums of the transport plan.
    print(jnp.exp(sinkhorn_result['log_P']).sum(axis=0))
    print(jnp.exp(sinkhorn_result['log_P']).sum(axis=1))

    # NOTE(review): i and j from sinkhorn_matching are overwritten here by an
    # argsort of the log marginals, so the costs printed below are for these
    # orderings, not for the greedy matching above — confirm this is intended.
    i = jnp.argsort(jax.nn.logsumexp(sinkhorn_result['log_P'], axis=0))
    j = jnp.argsort(jax.nn.logsumexp(sinkhorn_result['log_P'], axis=1))

    i_h, j_h = optax.assignment.hungarian_algorithm(C)
    print(j)
    print(j_h)
    print(jnp.linalg.norm(a - b, ord=1, axis=-1).sum())
    print(jnp.linalg.norm(a[i] - b[j], ord=1, axis=-1).sum())
    print(C[i, j].sum())

    print(jnp.linalg.norm(a[i_h] - b[j_h], ord=1, axis=-1).sum())
    print(C[i_h, j_h].sum())








