import numpy as np
import time
import scipy
from sklearn.cluster import KMeans


def LR_Dykstra_Sin(K1, K2, K3, a, b, alpha, max_iter=1000, delta=1e-9, lam=0):
    """Dykstra-style scaling iterations for the kernels ``(K1, K2, K3)``.

    Each iteration performs two multiplicative "projections":

    1. rescale the row factors ``u1``/``u2`` so that they match ``a``/``b``
       and clip the inner vector ``g`` from below by ``alpha``;
    2. rescale the column factors ``v1``/``v2`` so that both couplings share
       the same column sums ``g``.

    The vectors ``q_*`` are the correction terms carried from one projection
    to the next.

    Parameters
    ----------
    K1, K2 : 2-D arrays with ``len(a)`` / ``len(b)`` rows and ``len(K3)``
        columns.
    K3 : 1-D array of length ``r``.
    a, b : 1-D arrays, target row sums of the two couplings.
    alpha : lower bound applied element-wise to ``g``.
    max_iter : maximum number of iterations.
    delta : the loop stops once the L1 violation of the row sums drops to
        ``delta`` or below.
    lam : value added to every denominator (0 means no safeguard).

    Returns
    -------
    Q, R : ``diag(u1) K1 diag(v1)`` and ``diag(u2) K2 diag(v2)``.
    g : the inner vector of length ``r``.
    count_op : operation count derived from the number of iterations.

    If a scaling vector becomes NaN or infinite, a warning is printed and the
    values from the start of that iteration are used instead.
    """
    r = np.shape(K3)[0]
    n, m = np.shape(K1)[0], np.shape(K2)[0]

    g_old = K3
    v1_old, v2_old = np.ones(r), np.ones(r)
    u1, u2 = np.ones(np.shape(a)[0]), np.ones(np.shape(b)[0])

    # Correction terms of the two projections.
    q_gi, q_gp = np.ones(r), np.ones(r)
    q_Q, q_R = np.ones(r), np.ones(r)

    # Bind the outputs up front so the function still returns the initial
    # state (instead of raising UnboundLocalError) when the loop body never
    # runs, i.e. when max_iter <= 0 or delta >= 1.
    v1, v2, g = v1_old, v2_old, g_old

    err = 1
    n_iter = 0
    while n_iter < max_iter and err > delta:
        # Snapshot used to roll back if this iteration blows up.
        u1_prev, v1_prev = u1, v1_old
        u2_prev, v2_prev = u2, v2_old
        g_prev = g_old
        n_iter = n_iter + 1

        # First projection: row sums and lower bound on g.
        u1 = a / (np.dot(K1, v1_old) + lam)
        u2 = b / (np.dot(K2, v2_old) + lam)
        g = np.maximum(alpha, g_old * q_gi)
        q_gi = (g_old * q_gi) / (g + lam)
        g_old = g.copy()

        # Second projection: common column sums.
        v1_trans = np.dot(K1.T, u1)
        v2_trans = np.dot(K2.T, u2)
        g = (g_old * q_gp * v1_old * q_Q * v1_trans * v2_old * q_R * v2_trans) ** (
            1 / 3
        )
        v1 = g / (v1_trans + lam)
        v2 = g / (v2_trans + lam)
        q_gp = (g_old * q_gp) / (g + lam)
        q_Q = (q_Q * v1_old) / (v1 + lam)
        q_R = (q_R * v2_old) / (v2 + lam)
        v1_old = v1.copy()
        v2_old = v2.copy()
        g_old = g.copy()

        # Update the error (L1 violation of the row sums).
        err_1 = np.sum(np.abs(u1 * np.dot(K1, v1) - a))
        err_2 = np.sum(np.abs(u2 * np.dot(K2, v2) - b))
        err = err_1 + err_2

        if not all(np.all(np.isfinite(s)) for s in (u1, v1, u2, v2)):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print("Warning: numerical errors at iteration", n_iter)
            u1, v1 = u1_prev, v1_prev
            u2, v2 = u2_prev, v2_prev
            g = g_prev
            break

    Q = u1.reshape((-1, 1)) * K1 * v1.reshape((1, -1))
    R = u2.reshape((-1, 1)) * K2 * v2.reshape((1, -1))
    count_op = (
        (n_iter + 1) * (20 * r + 2 * n * r + 2 * m * r + n + m) + 2 * n * r + 2 * m * r
    )
    return Q, R, g, count_op


def LR_Dykstra_LSE_Sin(
    C1, C2, C3, a, b, alpha, gamma, max_iter=1000, delta=1e-9, lam=0
):
    """Log-domain (log-sum-exp) variant of the Dykstra scaling iterations.

    The iteration works on potentials ``f1, g1, f2, g2, h`` and builds the
    couplings as ``Q = exp(gamma * (f1[:, None] + g1[None, :] - C1))``,
    ``R = exp(gamma * (f2[:, None] + g2[None, :] - C2))`` and
    ``g = exp(gamma * h)``.

    Parameters
    ----------
    C1, C2 : 2-D arrays with ``len(a)`` / ``len(b)`` rows and ``len(C3)``
        columns.
    C3 : 1-D array of length ``r``; initial value of ``h``.
    a, b : 1-D arrays, target row sums of ``Q`` and ``R``.
    alpha : lower bound on ``g`` (applied as ``log(alpha) / gamma`` on ``h``).
    gamma : scaling applied to the potentials before exponentiation.
    max_iter : maximum number of iterations.
    delta : the loop stops once the L1 violation of the row sums drops to
        ``delta`` or below.
    lam : accepted for signature compatibility; not used by this variant.

    Returns
    -------
    Q, R, g, count_op

    If a potential becomes NaN or infinite, a warning is printed and the
    couplings are rebuilt from the potentials held at the start of that
    iteration.
    """
    # Local import: a bare ``import scipy`` does not guarantee that the
    # ``scipy.special`` submodule has been loaded.
    from scipy.special import logsumexp

    n, m = np.shape(C1)[0], np.shape(C2)[0]
    r = np.shape(C3)[0]
    ops_per_iter = 8 * n * r + 8 * m * r + 4 * n + 4 * m + 27 * r
    ops_fixed = 4 * n * r + 4 * m * r

    def build_couplings(f1_, g1_, f2_, g2_, h_):
        # Map potentials back to (Q, R, g).
        Q_ = np.exp((f1_[:, None] + g1_[None, :] - C1) * gamma)
        R_ = np.exp((f2_[:, None] + g2_[None, :] - C2) * gamma)
        return Q_, R_, np.exp(gamma * h_)

    h_old = C3
    g1_old, g2_old = np.zeros(r), np.zeros(r)
    f1, f2 = np.zeros(np.shape(a)[0]), np.zeros(np.shape(b)[0])

    # Log-domain correction terms: q_* = exp(gamma * w_*).
    w_gi, w_gp = np.zeros(r), np.zeros(r)
    w_Q, w_R = np.zeros(r), np.zeros(r)

    Q = R = g = None
    err = 1
    n_iter = 0
    while n_iter < max_iter and err > delta:
        # Snapshot used to roll back if this iteration blows up.
        f1_prev, g1_prev = f1, g1_old
        f2_prev, g2_prev = f2, g2_old
        h_prev = h_old
        n_iter = n_iter + 1

        # First projection: row sums and lower bound on g.
        C1_tilde = (f1[:, None] + g1_old[None, :] - C1) * gamma  # 3 * n * r
        f1 = (
            (1 / gamma) * np.log(a)
            + f1
            - (1 / gamma) * logsumexp(C1_tilde, axis=1)
        )  # 2 * n + 2 * n + n * r

        C2_tilde = (f2[:, None] + g2_old[None, :] - C2) * gamma  # 3 * m * r
        f2 = (
            (1 / gamma) * np.log(b)
            + f2
            - (1 / gamma) * logsumexp(C2_tilde, axis=1)
        )  # 2 * m + 2 * m + m * r

        h = h_old + w_gi  # 2 * r
        h = np.maximum((np.log(alpha) / gamma), h)  # r
        w_gi = h_old + w_gi - h  # 2 * r
        h_old = h.copy()

        # Log column sums of the couplings after the row update.
        C1_tilde = (f1[:, None] + g1_old[None, :] - C1) * gamma  # 3 * n * r
        alpha_1_trans = logsumexp(C1_tilde, axis=0)  # n * r

        C2_tilde = (f2[:, None] + g2_old[None, :] - C2) * gamma  # 3 * m * r
        alpha_2_trans = logsumexp(C2_tilde, axis=0)  # m * r

        # Second projection: common column sums.
        h = (1 / 3) * (h_old + w_gp + w_Q + w_R)  # 4 * r
        h = h + (1 / (3 * gamma)) * alpha_1_trans  # 2 * r
        h = h + (1 / (3 * gamma)) * alpha_2_trans  # 2 * r
        g1 = h + g1_old - (1 / gamma) * alpha_1_trans  # 3 * r
        g2 = h + g2_old - (1 / gamma) * alpha_2_trans  # 3 * r

        w_Q = w_Q + g1_old - g1  # 2 * r
        w_R = w_R + g2_old - g2  # 2 * r
        w_gp = h_old + w_gp - h  # 2 * r

        g1_old = g1.copy()
        g2_old = g2.copy()
        h_old = h.copy()

        # Update couplings
        Q, R, g = build_couplings(f1, g1, f2, g2, h)

        # Update the error (L1 violation of the row sums).
        err_1 = np.sum(np.abs(np.sum(Q, axis=1) - a))
        err_2 = np.sum(np.abs(np.sum(R, axis=1) - b))
        err = err_1 + err_2

        if not all(np.all(np.isfinite(p)) for p in (f1, g1, f2, g2)):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print("Warning: numerical errors at iteration", n_iter)
            # Rebuild from the *previous* potentials; g1_old / g2_old already
            # hold the rejected iterate at this point and must not be used.
            Q, R, g = build_couplings(f1_prev, g1_prev, f2_prev, g2_prev, h_prev)
            count_op = n_iter * ops_per_iter + ops_fixed
            return Q, R, g, count_op

    if Q is None:
        # The loop body never ran (max_iter <= 0 or delta >= 1): return the
        # couplings induced by the initial potentials.
        Q, R, g = build_couplings(f1, g1_old, f2, g2_old, h_old)

    count_op = (n_iter + 1) * ops_per_iter + ops_fixed
    return Q, R, g, count_op


def LR_IBP_Sin(K1, K2, K3, a, b, max_iter=1000, delta=1e-9, lam=0):
    """Iterative scaling of the kernels ``(K1, K2, K3)`` without a lower bound.

    Each iteration rescales the row factors ``u1``/``u2`` to match ``a``/``b``,
    updates ``g`` as a geometric mean, and rescales the column factors
    ``v1``/``v2`` so that both couplings have column sums ``g``.

    Parameters
    ----------
    K1, K2 : 2-D arrays with ``len(a)`` / ``len(b)`` rows and ``len(K3)``
        columns.
    K3 : 1-D array of length ``r``; initial value of ``g``.
    a, b : 1-D arrays, target row sums of the two couplings.
    max_iter : maximum number of iterations.
    delta : the loop stops once the L1 violation of the row sums drops to
        ``delta`` or below.
    lam : accepted for signature compatibility; not used by this variant.

    Returns
    -------
    Q, R : ``diag(u1) K1 diag(v1)`` and ``diag(u2) K2 diag(v2)``.
    g : the inner vector of length ``r``.
    count_op : operation count derived from the number of iterations.

    If a scaling vector becomes NaN or infinite, a warning is printed and the
    values from the start of that iteration are used instead.
    """
    g = K3
    r = np.shape(K3)[0]
    n, m = np.shape(K1)[0], np.shape(K2)[0]

    v1, v2 = np.ones(r), np.ones(r)
    # u2 scales the rows of K2, so it must be sized (and later updated) from
    # b, not from a.
    u1, u2 = np.ones(np.shape(a)[0]), np.ones(np.shape(b)[0])

    u1_trans = np.dot(K1, v1)  # n * r
    u2_trans = np.dot(K2, v2)  # m * r

    err = 1
    n_iter = 0
    while n_iter < max_iter and err > delta:
        # Snapshot used to roll back if this iteration blows up.
        u1_prev, v1_prev = u1, v1
        u2_prev, v2_prev = u2, v2
        g_prev = g
        n_iter = n_iter + 1

        # Update u1
        u1 = a / u1_trans  # n
        v1_trans = np.dot(K1.T, u1)  # n * r

        # Update u2
        u2 = b / u2_trans  # m
        v2_trans = np.dot(K2.T, u2)  # m * r

        # Update g
        g = (g * v1 * v1_trans * v2 * v2_trans) ** (1 / 3)  # 5 * r

        # Update v1, v2
        v1 = g / v1_trans  # r
        v2 = g / v2_trans  # r

        # Update the error (L1 violation of the row sums).
        u1_trans = np.dot(K1, v1)
        err_1 = np.sum(np.abs(u1 * u1_trans - a))
        u2_trans = np.dot(K2, v2)
        err_2 = np.sum(np.abs(u2 * u2_trans - b))
        err = err_1 + err_2

        if not all(np.all(np.isfinite(s)) for s in (u1, v1, u2, v2)):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print("Warning: numerical errors at iteration", n_iter)
            u1, v1 = u1_prev, v1_prev
            u2, v2 = u2_prev, v2_prev
            g = g_prev
            break

    Q = u1.reshape((-1, 1)) * K1 * v1.reshape((1, -1))
    R = u2.reshape((-1, 1)) * K2 * v2.reshape((1, -1))
    count_op = (n_iter + 1) * (2 * n * r + 2 * m * r + 7 * r) + 3 * n * r + 3 * m * r
    return Q, R, g, count_op


# Here cost is a function
# Here we assume that computing each entry of the cost matrix takes O(d)
def UpdatePlans(X, Y, Z, a, b, reg, cost, max_iter=1000, delta=1e-9, lam=0):
    """Scale the kernels between ``Z`` and ``X`` / ``Y`` to a common weight.

    Builds ``K1 = exp(-cost(Z, X) / reg)`` and ``K2 = exp(-cost(Z, Y) / reg)``
    and alternates updates of the column factors (matched to ``a`` and ``b``)
    with updates of the shared weight vector ``w`` and the row factors.

    Parameters
    ----------
    X, Y, Z : arrays passed to ``cost``; ``Z.shape[0]`` is the rank ``r`` and
        ``Z.shape[1]`` is used as ``d`` in the operation count.
    a, b : 1-D arrays, target column sums of the two plans.
    reg : divisor applied to the cost before exponentiation.
    cost : callable ``cost(Z, X)`` returning a matrix with one row per row of
        ``Z``.
    max_iter, delta : iteration cap and tolerance on the L1 violation of the
        column sums.
    lam : accepted but not used.

    Returns
    -------
    gamma_1, gamma_2 : ``diag(u1) K1 diag(v1)`` and ``diag(u2) K2 diag(v2)``.
    w : weight vector of length ``r``.
    count_op : operation count derived from the number of iterations.
    """
    kernel_x = np.exp(-cost(Z, X) / reg)  # size: r x n
    kernel_y = np.exp(-cost(Z, Y) / reg)  # size: r x m

    rank = np.shape(Z)[0]
    row_x, row_y = np.ones(rank), np.ones(rank)
    col_x, col_y = np.ones(np.shape(a)[0]), np.ones(np.shape(b)[0])

    col_x_den = np.dot(kernel_x.T, row_x)
    col_y_den = np.dot(kernel_y.T, row_y)

    weights = np.ones(rank) / rank

    residual = 1
    steps = 0
    while steps < max_iter and residual > delta:
        # Remember the current state so it can be restored on overflow.
        saved = (row_x, col_x, row_y, col_y, weights)
        steps = steps + 1

        # Column factors.
        col_x = a / col_x_den
        row_x_den = np.dot(kernel_x, col_x)

        col_y = b / col_y_den
        row_y_den = np.dot(kernel_y, col_y)

        # Shared weights, then row factors.
        weights = (row_x * row_x_den * row_y * row_y_den) ** (1 / 2)
        row_x = weights / row_x_den
        row_y = weights / row_y_den

        # L1 violation of the column sums.
        col_x_den = np.dot(kernel_x.T, row_x)
        col_y_den = np.dot(kernel_y.T, row_y)
        residual = np.sum(np.abs(col_x * col_x_den - a)) + np.sum(
            np.abs(col_y * col_y_den - b)
        )

        broken = False
        for vec in (row_x, col_x, row_y, col_y):
            if np.any(np.isnan(vec)) or np.any(np.isinf(vec)):
                broken = True
                break
        if broken:
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print("Warning: numerical errors at iteration", steps)
            row_x, col_x, row_y, col_y, weights = saved
            break

    gamma_1 = row_x.reshape((-1, 1)) * kernel_x * col_x.reshape((1, -1))
    gamma_2 = row_y.reshape((-1, 1)) * kernel_y * col_y.reshape((1, -1))
    n, m, d = np.shape(X)[0], np.shape(Y)[0], np.shape(Z)[1]
    per_step = 2 * n * rank + 2 * m * rank + 6 * rank + n + m
    count_op = (steps + 1) * per_step + (d + 2) * n * rank + (d + 2) * m * rank + rank
    return gamma_1, gamma_2, weights, count_op


# gamma_init = 'theory', 'regularization', 'arbitrary'
# method = 'IBP', 'Dykstra', 'Dykstra_LSE'
# If C_init = True: cost_factorized = C1,C2
# If C_init = False: cost_factorized is a function
# Init = 'trivial', 'kmeans', 'random'
def Lin_LOT_MD(
    X,
    Y,
    a,
    b,
    rank,
    reg,
    alpha,
    cost,
    cost_factorized,
    Init="trivial",
    seed_init=49,
    C_init=False,
    reg_init=1e-1,
    gamma_init="theory",
    gamma_0=1e-1,
    method="IBP",
    max_iter=1000,
    delta=1e-3,
    max_iter_IBP=1000,
    delta_IBP=1e-9,
    lam_IBP=0,
    time_out=200,
):
    """Mirror-descent loop for low-rank OT with a factorized cost ``C1 @ C2``.

    Parameters
    ----------
    X, Y : point clouds; used by the "kmeans" and "random" initialisations
        and passed to ``cost_factorized`` when it is a callable.
    a, b : 1-D marginal weight vectors.
    rank : number of columns ``r`` of the couplings ``Q`` and ``R``.
    reg : regularisation strength used in the gradient step.
    alpha : lower bound on ``g`` (used by the Dykstra solvers and by the
        "theory" step size).
    cost : callable forwarded to ``UpdatePlans`` ("kmeans" init only).
    cost_factorized : callable returning ``(C1, C2)`` when ``C_init`` is
        False, otherwise the pair ``(C1, C2)`` itself.
    Init : "trivial", "kmeans" or "random".
    seed_init : seed for the "random" initialisation.
    C_init : see ``cost_factorized``.
    reg_init : regularisation passed to ``UpdatePlans`` ("kmeans" init).
    gamma_init : "theory", "regularization" or "arbitrary" step-size rule.
    gamma_0 : step size used when ``gamma_init == "arbitrary"``.
    method : inner solver, "IBP", "Dykstra" or "Dykstra_LSE".
    max_iter, delta : outer iteration cap and relative-change tolerance
        (the tolerance is only evaluated after the 10th iteration).
    max_iter_IBP, delta_IBP, lam_IBP : forwarded to the inner solver.
    time_out : wall-clock budget in seconds.

    Returns
    -------
    ``(acc[-1], acc, times, list_num_op, Q, R, g)`` where ``acc`` holds the
    cost after every outer iteration, or the string ``"Error"`` when the cost
    cannot be factorized in two parts or the cost value becomes NaN.
    """
    start = time.time()
    num_op = 0
    acc = []
    times = []
    list_num_op = []

    r = rank
    n, m = np.shape(a)[0], np.shape(b)[0]

    if C_init == False:
        C = cost_factorized(X, Y)
        if len(C) == 2:
            C1, C2 = C
        else:
            print("Error: cost not adapted")
            return "Error"
    else:
        C1, C2 = cost_factorized

    n, d = np.shape(C1)

    ########### Initialization ###########

    ## Init with K-means
    if Init == "kmeans":
        g = np.ones(rank) / rank
        kmeans = KMeans(n_clusters=rank, random_state=0).fit(X)
        Z = kmeans.cluster_centers_
        num_iter_kmeans = kmeans.n_iter_
        num_op = num_op + r + num_iter_kmeans * r * n
        gamma1, gamma2, g, count_op_Barycenter = UpdatePlans(
            X,
            Y,
            Z,
            a,
            b,
            reg_init,
            cost,
            max_iter=max_iter_IBP,
            delta=delta_IBP,
            lam=lam_IBP,
        )
        Q, R = gamma1.T, gamma2.T
        num_op = num_op + count_op_Barycenter

    ## Init random
    if Init == "random":
        np.random.seed(seed_init)
        g = np.abs(np.random.randn(rank))
        g = g + 1
        g = g / np.sum(g)
        n, d = np.shape(X)
        m, d = np.shape(Y)

        seed_init = seed_init + 1000
        np.random.seed(seed_init)
        Q = np.abs(np.random.randn(n, rank))
        Q = Q + 1
        Q = (Q.T * (a / np.sum(Q, axis=1))).T

        seed_init = seed_init + 1000
        np.random.seed(seed_init)
        R = np.abs(np.random.randn(m, rank))
        R = R + 1
        R = (R.T * (b / np.sum(R, axis=1))).T

        num_op = num_op + 2 * n * r + 2 * m * r + m + n + 2 * r

    ## Init trivial
    if Init == "trivial":
        # Mixture of two rank-one couplings; lambda_1 is chosen small enough
        # that a2, b2 and g2 stay non-negative.
        g = np.ones(rank) / rank
        lambda_1 = min(np.min(a), np.min(g), np.min(b)) / 2

        a1 = np.arange(1, np.shape(a)[0] + 1)
        a1 = a1 / np.sum(a1)
        a2 = (a - lambda_1 * a1) / (1 - lambda_1)

        b1 = np.arange(1, np.shape(b)[0] + 1)
        b1 = b1 / np.sum(b1)
        b2 = (b - lambda_1 * b1) / (1 - lambda_1)

        g1 = np.arange(1, rank + 1)
        g1 = g1 / np.sum(g1)
        g2 = (g - lambda_1 * g1) / (1 - lambda_1)

        Q = lambda_1 * np.dot(a1[:, None], g1.reshape(1, -1)) + (1 - lambda_1) * np.dot(
            a2[:, None], g2.reshape(1, -1)
        )
        R = lambda_1 * np.dot(b1[:, None], g1.reshape(1, -1)) + (1 - lambda_1) * np.dot(
            b2[:, None], g2.reshape(1, -1)
        )

        num_op = num_op + 4 * n * r + 4 * m * r + 3 * n + 3 * m + 3 * r
    #####################################

    if gamma_init == "theory":
        # NOTE(review): both factors below use norm(C1); the second one may
        # have been meant to be norm(C2) -- confirm before changing.
        L_trans = (
            (2 / (alpha) ** 4) * (np.linalg.norm(C1) ** 2) * (np.linalg.norm(C1) ** 2)
        )
        L_trans = (
            L_trans
            + ((reg + 2 * np.linalg.norm(C1) * np.linalg.norm(C1)) / (alpha ** 3)) ** 2
        )
        L = np.sqrt(3 * L_trans)
        gamma = 1 / L
        print(gamma)

    if gamma_init == "regularization":
        gamma = 1 / reg

    if gamma_init == "arbitrary":
        gamma = gamma_0

    err = 1
    niter = 0
    while niter < max_iter and err > delta:
        # Snapshot used to roll back if the stopping criterion becomes NaN.
        Q_prev = Q
        R_prev = R
        g_prev = g
        niter = niter + 1

        K1_trans_0 = np.dot(C2, R)  # d * m * r
        K1_trans_0 = np.dot(C1, K1_trans_0)  # n * d * r
        C1_trans = K1_trans_0 / g + (reg - (1 / gamma)) * np.log(Q)  # 3 * n * r

        K2_trans_0 = np.dot(C1.T, Q)  # d * n * r
        K2_trans_0 = np.dot(C2.T, K2_trans_0)  # m * d * r
        C2_trans = K2_trans_0 / g + (reg - (1 / gamma)) * np.log(R)  # 3 * m * r

        omega = np.diag(np.dot(Q.T, K1_trans_0))  # r * n * r
        C3_trans = (omega / (g ** 2)) - (reg - (1 / gamma)) * np.log(g)  # 4 * r

        num_op = (
            num_op
            + 2 * n * d * r
            + 2 * m * d * r
            + r * n * r
            + 3 * n * r
            + 3 * m * r
            + 4 * r
        )

        # Update the coupling
        if method == "IBP":
            K1 = np.exp((-gamma) * C1_trans)
            K2 = np.exp((-gamma) * C2_trans)
            K3 = np.exp(gamma * C3_trans)
            # LR_IBP_Sin returns four values (Q, R, g, count_op).
            Q, R, g, count_op_IBP = LR_IBP_Sin(
                K1,
                K2,
                K3,
                a,
                b,
                max_iter=max_iter_IBP,
                delta=delta_IBP,
                lam=lam_IBP,
            )

            num_op = num_op + count_op_IBP

        if method == "Dykstra":
            K1 = np.exp((-gamma) * C1_trans)
            K2 = np.exp((-gamma) * C2_trans)
            K3 = np.exp(gamma * C3_trans)
            num_op = num_op + 2 * n * r + 2 * m * r + 2 * r
            Q, R, g, count_op_Dysktra = LR_Dykstra_Sin(
                K1,
                K2,
                K3,
                a,
                b,
                alpha,
                max_iter=max_iter_IBP,
                delta=delta_IBP,
                lam=lam_IBP,
            )

            num_op = num_op + count_op_Dysktra

        if method == "Dykstra_LSE":
            Q, R, g, count_op_Dysktra_LSE = LR_Dykstra_LSE_Sin(
                C1_trans,
                C2_trans,
                C3_trans,
                a,
                b,
                alpha,
                gamma,
                max_iter=max_iter_IBP,
                delta=delta_IBP,
                lam=lam_IBP,
            )

            num_op = num_op + count_op_Dysktra_LSE

        # Classical OT cost: trace(Q^T C1 C2 R diag(1 / g))
        C_trans = np.dot(C2, R)
        C_trans = np.dot(C1, C_trans)
        C_trans = C_trans / g
        G = np.dot(Q.T, C_trans)
        OT_trans = np.trace(G)

        if niter > 10:
            # Practical stopping criterion: relative change of the cost.
            err = np.abs(OT_trans - acc[-1]) / acc[-1]

            if np.isnan(err):
                print("Error computation of the stopping criterion", niter)
                Q = Q_prev
                R = R_prev
                g = g_prev
                break

        if np.isnan(OT_trans):
            print("Error: NaN OT value")
            return "Error"

        acc.append(OT_trans)
        end = time.time()
        tim_actual = end - start
        times.append(end - start)
        list_num_op.append(num_op)
        if tim_actual > time_out:
            break

    return acc[-1], np.array(acc), np.array(times), np.array(list_num_op), Q, R, g


def update_Quad_cost_GW(D1, D2, Q, R, g):
    """Return the two factors of the linearised GW cost and their op count.

    The first factor is ``-4 * (D1 @ Q) / g`` (division broadcast over the
    columns), the second is ``R.T @ D2``.  The third return value is the
    operation count ``n*n*r + 2*n*r + r*m*m`` with ``n = D1.shape[0]``,
    ``m = D2.shape[0]`` and ``r = g.shape[0]``.
    """
    size_x = np.shape(D1)[0]
    size_y = np.shape(D2)[0]
    rank = np.shape(g)[0]

    left_factor = -4 * np.dot(D1, Q) / g
    right_factor = np.dot(R.T, D2)

    ops = size_x * size_x * rank + 2 * size_x * rank + rank * size_y * size_y
    return left_factor, right_factor, ops


# gamma_init = 'theory', 'regularization', 'arbitrary'
# method = 'IBP', 'Dykstra', 'Dykstra_LSE'
# If C_init = True, cost is a tuple of matrices
# If C_init = False, cost is a function
# Init = 'trivial', 'random', 'lower_bound'
def Quad_LGW_MD(
    X,
    Y,
    a,
    b,
    rank,
    reg,
    alpha,
    cost,
    C_init=False,
    Init="trivial",
    seed_init=49,
    reg_init=1e-1,
    gamma_init="theory",
    gamma_0=1e-1,
    method="IBP",
    max_iter=1000,
    delta=1e-3,
    max_iter_IBP=1000,
    delta_IBP=1e-9,
    lam_IBP=0,
    time_out=200,
):
    """Mirror-descent loop for low-rank Gromov-Wasserstein with full costs.

    Parameters
    ----------
    X, Y : point clouds; only used to build ``D1 = cost(X, X)`` and
        ``D2 = cost(Y, Y)`` when ``C_init`` is False.
    a, b : 1-D marginal weight vectors.
    rank : number of columns ``r`` of the couplings ``Q`` and ``R``.
    reg : regularisation strength used in the gradient step.
    alpha : lower bound on ``g`` (used by the Dykstra solvers).
    cost : callable when ``C_init`` is False, otherwise the pair
        ``(D1, D2)``.
    Init : "trivial", "random" or "lower_bound".
    seed_init : seed for the "random" initialisation.
    reg_init : forwarded to ``Lin_LOT_MD`` ("lower_bound" init).
    gamma_init : "theory" (currently the constant 1), "regularization"
        (``1 / reg``) or "arbitrary" (``gamma_0``).
    method : inner solver, "IBP", "Dykstra" or "Dykstra_LSE".
    max_iter, delta : outer iteration cap and tolerance.  The stopping error
        is reset to 1 after every evaluation, so in practice the loop ends on
        ``max_iter``, ``time_out`` or a NaN.
    max_iter_IBP, delta_IBP, lam_IBP : forwarded to the inner solver.
    time_out : wall-clock budget in seconds.

    Returns
    -------
    ``(acc[-1], acc, times, list_num_op, Couplings)`` where ``Couplings`` is
    the list of ``(Q, R, g)`` triples (initial one first), or the string
    ``"Error"`` when the cost is not a pair, the initialisation fails, or the
    cost value becomes NaN.
    """
    start = time.time()
    num_op = 0
    acc = []
    times = []
    list_num_op = []
    Couplings = []

    r = rank
    n, m = np.shape(a)[0], np.shape(b)[0]

    if C_init == True:
        if len(cost) == 2:
            D1, D2 = cost
        else:
            print("Error: cost not adapted")
            return "Error"
    else:
        D1, D2 = cost(X, X), cost(Y, Y)

    ########### Initialization ###########

    ## Init Lower bound
    if Init == "lower_bound":
        X_new = np.sqrt(np.dot(D1 ** 2, a).reshape(-1, 1))  # 2 * n * n + n
        Y_new = np.sqrt(np.dot(D2 ** 2, b).reshape(-1, 1))  # 2 * m * m + m
        C1_init, C2_init = factorized_square_Euclidean(X_new, Y_new)  # 3 * m + 3 * n
        num_op = num_op + 2 * n * n + 2 * m * m + 4 * n + 4 * m
        cost_factorized_init = (C1_init, C2_init)
        cost_init = lambda z1, z2: Square_Euclidean_Distance(z1, z2)

        # Solve the one-dimensional linear problem to get a starting point.
        results = Lin_LOT_MD(
            X_new,
            Y_new,
            a,
            b,
            rank,
            reg,
            alpha,
            cost_init,
            cost_factorized_init,
            Init="kmeans",
            seed_init=seed_init,
            C_init=True,
            reg_init=reg_init,
            gamma_init="arbitrary",
            gamma_0=gamma_0,
            method=method,
            max_iter=max_iter,
            delta=delta_IBP,
            max_iter_IBP=max_iter_IBP,
            delta_IBP=delta_IBP,
            lam_IBP=lam_IBP,
            time_out=100,
        )

        if results == "Error":
            return "Error"

        else:
            res_init, acc_init, times_init, num_op_init, Q, R, g = results
            Couplings.append((Q, R, g))
            num_op = num_op + num_op_init[-1]

    ## Init random
    if Init == "random":
        np.random.seed(seed_init)
        g = np.abs(np.random.randn(rank))
        g = g + 1  # r
        g = g / np.sum(g)  # r

        seed_init = seed_init + 1000
        np.random.seed(seed_init)
        Q = np.abs(np.random.randn(n, rank))
        Q = Q + 1  # n * r
        Q = (Q.T * (a / np.sum(Q, axis=1))).T  # n + 2 * n * r

        seed_init = seed_init + 1000
        np.random.seed(seed_init)
        R = np.abs(np.random.randn(m, rank))
        R = R + 1  # m * r
        R = (R.T * (b / np.sum(R, axis=1))).T  # m + 2 * m * r

        Couplings.append((Q, R, g))
        num_op = num_op + 2 * n * r + 2 * m * r + n + m + 2 * r

    ## Init trivial
    if Init == "trivial":
        # Mixture of two rank-one couplings; lambda_1 is chosen small enough
        # that a2, b2 and g2 stay non-negative.
        g = np.ones(rank) / rank
        lambda_1 = min(np.min(a), np.min(g), np.min(b)) / 2

        a1 = np.arange(1, np.shape(a)[0] + 1)
        a1 = a1 / np.sum(a1)
        a2 = (a - lambda_1 * a1) / (1 - lambda_1)

        b1 = np.arange(1, np.shape(b)[0] + 1)
        b1 = b1 / np.sum(b1)
        b2 = (b - lambda_1 * b1) / (1 - lambda_1)

        g1 = np.arange(1, rank + 1)
        g1 = g1 / np.sum(g1)
        g2 = (g - lambda_1 * g1) / (1 - lambda_1)

        Q = lambda_1 * np.dot(a1[:, None], g1.reshape(1, -1)) + (1 - lambda_1) * np.dot(
            a2[:, None], g2.reshape(1, -1)
        )
        R = lambda_1 * np.dot(b1[:, None], g1.reshape(1, -1)) + (1 - lambda_1) * np.dot(
            b2[:, None], g2.reshape(1, -1)
        )

        Couplings.append((Q, R, g))
        num_op = num_op + 4 * n * r + 4 * m * r + 3 * n + 3 * m + 3 * r
    #####################################

    if gamma_init == "theory":
        gamma = 1  # to compute

    if gamma_init == "regularization":
        gamma = 1 / reg

    if gamma_init == "arbitrary":
        gamma = gamma_0

    # Constant part of the GW objective.
    c = np.dot(np.dot(D1 ** 2, a), a) + np.dot(
        np.dot(D2 ** 2, b), b
    )  # 2 * n * n + n + 2 * m * m + m
    C1, C2, num_op_update = update_Quad_cost_GW(D1, D2, Q, R, g)
    num_op = num_op + 2 * n * n + n + 2 * m * m + m + num_op_update

    # GW cost of the initial couplings
    C_trans = np.dot(C2, R)
    C_trans = np.dot(C1, C_trans)
    C_trans = C_trans / g
    G = np.dot(Q.T, C_trans)
    OT_trans = np.trace(G)  # \langle -4DPD',P\rangle
    GW_trans = c + OT_trans / 2
    print(GW_trans)
    acc.append(GW_trans)
    end = time.time()
    tim_actual = end - start
    times.append(tim_actual)
    list_num_op.append(num_op)

    err = 1
    niter = 0
    while niter < max_iter and err > delta:
        # Snapshot used to roll back if the stopping criterion becomes NaN.
        Q_prev = Q
        R_prev = R
        g_prev = g
        niter = niter + 1

        K1_trans_0 = np.dot(C2, R)  # r * m * r
        K1_trans_0 = np.dot(C1, K1_trans_0)  # n * r * r
        C1_trans = K1_trans_0 / g + (reg - (1 / gamma)) * np.log(Q)  # 3 * n * r

        K2_trans_0 = np.dot(C1.T, Q)  # r * n * r
        K2_trans_0 = np.dot(C2.T, K2_trans_0)  # m * r * r
        C2_trans = K2_trans_0 / g + (reg - (1 / gamma)) * np.log(R)  # 3 * m * r

        omega = np.diag(np.dot(Q.T, K1_trans_0))  # r * n * r
        C3_trans = (omega / (g ** 2)) - (reg - (1 / gamma)) * np.log(g)  # 4 * r

        num_op = (
            num_op + 3 * n * r * r + 2 * m * r * r + 3 * n * r + 3 * m * r + 4 * r
        )

        # Update the coupling
        if method == "IBP":
            K1 = np.exp((-gamma) * C1_trans)
            K2 = np.exp((-gamma) * C2_trans)
            K3 = np.exp(gamma * C3_trans)
            # LR_IBP_Sin returns four values (Q, R, g, count_op).
            Q, R, g, count_op_IBP = LR_IBP_Sin(
                K1,
                K2,
                K3,
                a,
                b,
                max_iter=max_iter_IBP,
                delta=delta_IBP,
                lam=lam_IBP,
            )

            num_op = num_op + count_op_IBP

        if method == "Dykstra":
            K1 = np.exp((-gamma) * C1_trans)
            K2 = np.exp((-gamma) * C2_trans)
            K3 = np.exp(gamma * C3_trans)
            num_op = num_op + 2 * n * r + 2 * m * r + 2 * r
            Q, R, g, count_op_Dysktra = LR_Dykstra_Sin(
                K1,
                K2,
                K3,
                a,
                b,
                alpha,
                max_iter=max_iter_IBP,
                delta=delta_IBP,
                lam=lam_IBP,
            )

            num_op = num_op + count_op_Dysktra

        if method == "Dykstra_LSE":
            Q, R, g, count_op_Dysktra_LSE = LR_Dykstra_LSE_Sin(
                C1_trans,
                C2_trans,
                C3_trans,
                a,
                b,
                alpha,
                gamma,
                max_iter=max_iter_IBP,
                delta=delta_IBP,
                lam=lam_IBP,
            )

            num_op = num_op + count_op_Dysktra_LSE

        # Update the total cost
        C1, C2, num_op_update = update_Quad_cost_GW(D1, D2, Q, R, g)
        num_op = num_op + num_op_update

        # GW cost
        C_trans = np.dot(C2, R)
        C_trans = np.dot(C1, C_trans)
        C_trans = C_trans / g
        G = np.dot(Q.T, C_trans)
        OT_trans = np.trace(G)  # \langle -4DPD',P\rangle
        GW_trans = c + OT_trans / 2
        print(GW_trans)

        if niter > 10:
            # Practical stopping criterion: relative change of the cost.  It
            # is only used here to detect a NaN.
            err = np.abs(GW_trans - acc[-1]) / acc[-1]

            if np.isnan(err):
                print("Error computation of the stopping criterion", niter)
                Q = Q_prev
                R = R_prev
                g = g_prev
                break

            # here we let the error to be one always !
            err = 1

        if np.isnan(OT_trans):
            print("Error: NaN OT value")
            return "Error"

        acc.append(GW_trans)
        Couplings.append((Q, R, g))
        end = time.time()
        tim_actual = end - start
        times.append(tim_actual)
        list_num_op.append(num_op)
        if tim_actual > time_out:
            break

    return acc[-1], np.array(acc), np.array(times), np.array(list_num_op), Couplings


def update_Lin_cost_GW(D11, D12, D21, D22, Q, R, g):
    """Return the two factors of the linearised GW cost and their op count.

    The first factor is ``-4 * D11 @ ((D12 @ Q) / g)`` and the second is
    ``(R.T @ D21) @ D22``.  The third return value is the operation-count
    estimate the callers add to their running total.
    """
    n, d1 = np.shape(D11)
    m, d2 = np.shape(D21)
    r = np.shape(g)[0]

    # Left factor: project Q through D12, rescale each column by 1 / g,
    # then lift back through D11.
    left_factor = np.dot(D12, Q) / g
    left_factor = -4 * np.dot(D11, left_factor)

    # Right factor: R^T pushed through both halves of the second cost.
    right_factor = np.dot(np.dot(R.T, D21), D22)

    op_count = 2 * n * r * d1 + 2 * r * d2 * m + d1 * r + n * r
    return left_factor, right_factor, op_count


# gamma_init = 'theory', 'regularization', 'arbitrary'
# method = 'IBP', 'Dykstra', 'Dykstra_LSE'
# If C_init = True, cost_factorized is a tuple of matrices (D11,D12,D21,D22)
# D1 = D11D12, D2 = D21D22
# If C_init = False, cost_factorized is a function
# Init = 'trivial', 'random', 'lower_bound'
def Lin_LGW_MD(
    X,
    Y,
    a,
    b,
    rank,
    reg,
    alpha,
    cost_factorized,
    C_init=False,
    Init="trivial",
    seed_init=49,
    reg_init=1e-1,
    gamma_init="arbitrary",
    gamma_0=1e-1,
    method="IBP",
    max_iter=1000,
    delta=1e-3,
    max_iter_IBP=1000,
    delta_IBP=1e-9,
    lam_IBP=0,
    time_out=200,
):
    """Iteratively refine a rank-``rank`` factorisation (Q, R, g) for a
    Gromov-Wasserstein objective whose two cost matrices are supplied in
    factorised form (D1 = D11 D12, D2 = D21 D22).

    The routine initialises (Q, R, g), then repeats: build the linearised
    cost terms from the current factors, run the inner solver selected by
    ``method``, and re-evaluate the GW value ``c + trace(...) / 2``.  The
    value, elapsed time, running operation count and the factors are
    recorded after every outer iteration.

    Parameters
    ----------
    X, Y : passed to ``cost_factorized`` when ``C_init`` is False; the
        "random" initialisation also reads ``np.shape(X)`` / ``np.shape(Y)``.
    a, b : weight vectors handed to the inner solvers as target marginals.
    rank : number of entries of ``g`` (and of columns of Q and R).
    reg : enters through ``reg - 1 / gamma`` and, when
        ``gamma_init="regularization"``, through ``gamma = 1 / reg``.
    alpha : forwarded to the Dykstra solvers and to ``Lin_LOT_MD``; not
        used by the "IBP" branch.
    cost_factorized : tuple ``(D11, D12, D21, D22)`` when ``C_init`` is
        True, otherwise a callable returning two factors for a pair of
        point sets.
    Init : "trivial", "random" or "lower_bound".
    seed_init : seed of the "random" initialisation; also forwarded to
        ``Lin_LOT_MD``.
    reg_init : forwarded to ``Lin_LOT_MD`` by the "lower_bound" init.
    gamma_init : "theory" (gamma = 1), "regularization" (gamma = 1 / reg)
        or "arbitrary" (gamma = gamma_0).
    gamma_0 : value of gamma for ``gamma_init="arbitrary"``; also forwarded
        to ``Lin_LOT_MD``.
    method : "IBP", "Dykstra" or "Dykstra_LSE".
    max_iter : maximum number of outer iterations.
    delta : outer stopping threshold (see notes).
    max_iter_IBP, delta_IBP, lam_IBP : forwarded to the inner solver.
    time_out : wall-clock budget in seconds, checked after each iteration.

    Returns
    -------
    ``(acc[-1], np.array(acc), np.array(times), np.array(list_num_op),
    Couplings)`` where ``Couplings`` is the list of recorded ``(Q, R, g)``
    tuples, or the string ``"Error"`` when the cost tuple does not have four
    entries, the "lower_bound" initialisation reports an error, or the GW
    value becomes NaN.

    Notes
    -----
    * ``err`` is reset to 1 at the end of every iteration past the tenth,
      so ``delta`` only has an effect if it is >= 1; otherwise the loop ends
      on ``max_iter``, ``time_out`` or a NaN.
    * NOTE(review): an unrecognised ``Init`` or ``gamma_init`` leaves
      ``Q``/``R``/``g`` or ``gamma`` unassigned, and the function fails when
      they are first read.
    * The current GW value is printed at every evaluation.
    """
    start = time.time()
    num_op = 0
    acc = []
    times = []
    list_num_op = []
    Couplings = []

    if C_init == True:
        if len(cost_factorized) == 4:
            D11, D12, D21, D22 = cost_factorized
        else:
            print("Error: cost not adapted")
            return "Error"
    else:
        D11, D12 = cost_factorized(X, X)
        D21, D22 = cost_factorized(Y, Y)

    r = rank
    n, d1 = np.shape(D11)
    m, d2 = np.shape(D21)
    ########### Initialization ###########

    ## Init Lower bound
    # Builds 1-D embeddings X_new / Y_new from the squared-kernel feature
    # maps and solves a low-rank OT problem between them to seed (Q, R, g).
    if Init == "lower_bound":
        tilde_D11 = Feature_Map_Poly(D11)  # n * d1 * d1
        tilde_D12_T = Feature_Map_Poly(D12.T)  # n * d1 * d1
        tilde_D12 = tilde_D12_T.T

        tilde_D21 = Feature_Map_Poly(D21)  # m * d2 * d2
        tilde_D22_T = Feature_Map_Poly(D22.T)  # m * d2 * d2
        tilde_D22 = tilde_D22_T.T

        X_new = np.dot(tilde_D12, a)  # d1 * d1 * n
        X_new = np.sqrt(np.dot(tilde_D11, X_new).reshape(-1, 1))  # n * d1 * d1 + n
        Y_new = np.dot(tilde_D22, b)  # d2 * d2 * m
        Y_new = np.sqrt(np.dot(tilde_D21, Y_new).reshape(-1, 1))  # m * d2 * d2 + m

        C1_init, C2_init = factorized_square_Euclidean(X_new, Y_new)  # 3 * m + 3 * n
        cost_factorized_init = (C1_init, C2_init)

        # NOTE(review): "4 * n + 4 * n" -- one of the two terms presumably
        # should be 4 * m; this only affects the operation-count estimate.
        num_op = num_op + 4 * n * d1 * d1 + 4 * m * d2 * d2 + 4 * n + 4 * n

        cost_init = lambda z1, z2: Square_Euclidean_Distance(z1, z2)
        results = Lin_LOT_MD(
            X_new,
            Y_new,
            a,
            b,
            rank,
            reg,
            alpha,
            cost_init,
            cost_factorized_init,
            Init="kmeans",
            seed_init=seed_init,
            C_init=True,
            reg_init=reg_init,
            gamma_init="arbitrary",
            gamma_0=gamma_0,
            method=method,
            max_iter=max_iter,
            # NOTE(review): the outer tolerance receives delta_IBP, not
            # delta -- confirm this is intended.
            delta=delta_IBP,
            max_iter_IBP=max_iter_IBP,
            delta_IBP=delta_IBP,
            lam_IBP=lam_IBP,
            time_out=5,
        )

        if results == "Error":
            return "Error"

        else:
            # Assumes Lin_LOT_MD returns a 7-tuple ending with (Q, R, g);
            # it is defined outside this section -- TODO confirm.
            res_init, acc_init, times_init, num_op_init, Q, R, g = results
            Couplings.append((Q, R, g))
            num_op = num_op + num_op_init[-1]

        # print('res: '+str(res_init))

    ## Init random
    if Init == "random":
        np.random.seed(seed_init)
        g = np.abs(np.random.randn(rank))
        g = g + 1
        g = g / np.sum(g)
        # NOTE(review): n and m are re-derived from X and Y here, replacing
        # the values read from D11 / D21 above.
        n, d = np.shape(X)
        m, d = np.shape(Y)

        seed_init = seed_init + 1000
        np.random.seed(seed_init)
        Q = np.abs(np.random.randn(n, rank))
        Q = Q + 1
        # Rescale the rows of Q so that they sum to a.
        Q = (Q.T * (a / np.sum(Q, axis=1))).T

        seed_init = seed_init + 1000
        np.random.seed(seed_init)
        R = np.abs(np.random.randn(m, rank))
        R = R + 1
        # Rescale the rows of R so that they sum to b.
        R = (R.T * (b / np.sum(R, axis=1))).T

        Couplings.append((Q, R, g))
        num_op = num_op + 2 * n * r + 2 * m * r + n + m + 2 * r

    ## Init trivial
    # Uniform g; Q and R are mixtures (weight lambda_1) of two rank-one
    # products built from a, b, g and the linear ramps a1, b1, g1.
    if Init == "trivial":
        g = np.ones(rank) / rank
        lambda_1 = min(np.min(a), np.min(g), np.min(b)) / 2

        a1 = np.arange(1, np.shape(a)[0] + 1)
        a1 = a1 / np.sum(a1)
        a2 = (a - lambda_1 * a1) / (1 - lambda_1)

        b1 = np.arange(1, np.shape(b)[0] + 1)
        b1 = b1 / np.sum(b1)
        b2 = (b - lambda_1 * b1) / (1 - lambda_1)

        g1 = np.arange(1, rank + 1)
        g1 = g1 / np.sum(g1)
        g2 = (g - lambda_1 * g1) / (1 - lambda_1)

        Q = lambda_1 * np.dot(a1[:, None], g1.reshape(1, -1)) + (1 - lambda_1) * np.dot(
            a2[:, None], g2.reshape(1, -1)
        )
        R = lambda_1 * np.dot(b1[:, None], g1.reshape(1, -1)) + (1 - lambda_1) * np.dot(
            b2[:, None], g2.reshape(1, -1)
        )

        Couplings.append((Q, R, g))
        num_op = num_op + 4 * n * r + 4 * m * r + 3 * n + 3 * m + 3 * r
    #####################################

    if gamma_init == "theory":
        gamma = 1  # to compute

    if gamma_init == "regularization":
        gamma = 1 / reg

    if gamma_init == "arbitrary":
        gamma = gamma_0

    # Constant part c of the GW value, computed from the squared-kernel
    # feature maps of the cost factors and the weights a, b.
    tilde_D11 = Feature_Map_Poly(D11)  # n * d1 * d1
    tilde_D12_T = Feature_Map_Poly(D12.T)  # n * d1 * d1
    tilde_D12 = tilde_D12_T.T

    tilde_D21 = Feature_Map_Poly(D21)  # m * d2 * d2
    tilde_D22_T = Feature_Map_Poly(D22.T)  # m * d2 * d2
    tilde_D22 = tilde_D22_T.T

    a_tilde = np.dot(
        np.dot(tilde_D12, a), np.dot(np.transpose(tilde_D11), a)
    )  # 2 * d1 * d1 * n + d1 * d1
    b_tilde = np.dot(
        np.dot(tilde_D22, b), np.dot(np.transpose(tilde_D21), b)
    )  # 2 * m * d2 * d2 + d2 * d2
    c = a_tilde + b_tilde
    num_op = num_op + 4 * n * d1 * d1 + 4 * m * d2 * d2 + d1 * d1 + d2 * d2

    C1, C2, num_op_update = update_Lin_cost_GW(D11, D12, D21, D22, Q, R, g)
    num_op = num_op + num_op_update

    # GW value of the initial factorisation.
    C_trans = np.dot(C2, R)
    C_trans = np.dot(C1, C_trans)
    C_trans = C_trans / g
    G = np.dot(Q.T, C_trans)
    OT_trans = np.trace(G)  # \langle -4DPD',P\rangle
    GW_trans = c + OT_trans / 2
    print(GW_trans)

    acc.append(GW_trans)
    end = time.time()
    tim_actual = end - start
    times.append(tim_actual)
    list_num_op.append(num_op)

    err = 1
    niter = 0
    while niter < max_iter:
        # Keep the previous factors so they can be restored if the
        # stopping criterion turns out to be NaN.
        Q_prev = Q
        R_prev = R
        g_prev = g
        if err > delta:
            niter = niter + 1

            K1_trans_0 = np.dot(C2, R)  # r * m * r
            K1_trans_0 = np.dot(C1, K1_trans_0)  # n * r * r
            C1_trans = K1_trans_0 / g + (reg - (1 / gamma)) * np.log(Q)  # 3 * n * r

            K2_trans = np.dot(C1.T, Q)  # r * n * r
            K2_trans = np.dot(C2.T, K2_trans)  # m * r * r
            C2_trans = K2_trans / g + (reg - (1 / gamma)) * np.log(R)  # 3 * m * r

            omega = np.diag(np.dot(Q.T, K1_trans_0))  # r * n * r
            C3_trans = (omega / (g ** 2)) - (reg - (1 / gamma)) * np.log(g)  # 4 * r

            num_op = (
                num_op + 3 * n * r * r + 2 * m * r * r + 3 * n * r + 3 * m * r + 4 * r
            )

            # Update the coupling
            if method == "IBP":
                K1 = np.exp((-gamma) * C1_trans)
                K2 = np.exp((-gamma) * C2_trans)
                K3 = np.exp(gamma * C3_trans)
                Q, R, g = LR_IBP_Sin(
                    K1,
                    K2,
                    K3,
                    a,
                    b,
                    max_iter=max_iter_IBP,
                    delta=delta_IBP,
                    lam=lam_IBP,
                )

            if method == "Dykstra":
                K1 = np.exp((-gamma) * C1_trans)
                K2 = np.exp((-gamma) * C2_trans)
                K3 = np.exp(gamma * C3_trans)
                num_op = num_op + 2 * n * r + 2 * m * r + 2 * r
                Q, R, g, count_op_Dysktra = LR_Dykstra_Sin(
                    K1,
                    K2,
                    K3,
                    a,
                    b,
                    alpha,
                    max_iter=max_iter_IBP,
                    delta=delta_IBP,
                    lam=lam_IBP,
                )

                num_op = num_op + count_op_Dysktra

            if method == "Dykstra_LSE":
                # The log-domain solver receives the costs and gamma
                # directly instead of exponentiated kernels.
                Q, R, g, count_op_Dysktra_LSE = LR_Dykstra_LSE_Sin(
                    C1_trans,
                    C2_trans,
                    C3_trans,
                    a,
                    b,
                    alpha,
                    gamma,
                    max_iter=max_iter_IBP,
                    delta=delta_IBP,
                    lam=lam_IBP,
                )

                num_op = num_op + count_op_Dysktra_LSE
            # Update the total cost
            C1, C2, num_op_update = update_Lin_cost_GW(D11, D12, D21, D22, Q, R, g)
            num_op = num_op + num_op_update

            # GW cost
            C_trans = np.dot(C2, R)
            C_trans = np.dot(C1, C_trans)
            C_trans = C_trans / g
            G = np.dot(Q.T, C_trans)
            OT_trans = np.trace(G)  # \langle -4DPD',P\rangle
            GW_trans = c + OT_trans / 2
            print(GW_trans)

            if niter > 10:
                ## Update the error: theoretical error
                # err_1 = ((1/gamma)**2) * (KL(Q,Q_prev) + KL(Q_prev,Q))
                # err_2 = ((1/gamma)**2) * (KL(R,R_prev) + KL(R_prev,R))
                # err_3 = ((1/gamma)**2) * (KL(g,g_prev) + KL(g_prev,g))
                # err = err_1 + err_2 + err_3

                ## Update the error: Practical error
                # Relative change of the GW value; only used for the NaN
                # check below because err is overwritten afterwards.
                err = np.abs(GW_trans - acc[-1]) / acc[-1]

                if np.isnan(err):
                    print("Error computation of the stopping criterion", niter)
                    Q = Q_prev
                    R = R_prev
                    g = g_prev
                    break

                # The error is deliberately forced to 1, which disables the
                # delta-based stopping test.
                err = 1

            if np.isnan(GW_trans) == True:
                print("Error: NaN GW value")
                return "Error"

            else:
                acc.append(GW_trans)
                Couplings.append((Q, R, g))
                end = time.time()
                tim_actual = end - start
                times.append(tim_actual)
                list_num_op.append(num_op)
                if tim_actual > time_out:
                    return (
                        acc[-1],
                        np.array(acc),
                        np.array(times),
                        np.array(list_num_op),
                        Couplings,
                    )

        else:
            return (
                acc[-1],
                np.array(acc),
                np.array(times),
                np.array(list_num_op),
                Couplings,
            )

    return acc[-1], np.array(acc), np.array(times), np.array(list_num_op), Couplings


# D1 = A_1A_2 and D2 = B_1B_2
def GW_Init(A_1, A_2, B_1, B_2, p, q):
    """Evaluate the GW objective at the product coupling of ``p`` and ``q``.

    The costs are given in factorised form (D1 = A_1 A_2, D2 = B_1 B_2).
    The value is ``c - 2 * trace(M)`` where ``c`` is built from the
    squared-kernel feature maps of the factors and ``M`` from the factors
    applied to ``p`` and ``q``.
    """
    # Feature maps of the left factors and of the transposed right factors.
    phi_A_left = Feature_Map_Poly(A_1)
    phi_A_right = Feature_Map_Poly(A_2.T).T
    phi_B_left = Feature_Map_Poly(B_1)
    phi_B_right = Feature_Map_Poly(B_2.T).T

    # Constant term: p^T (phi_A_left phi_A_right) p + q^T (phi_B_left phi_B_right) q.
    weighted_a = np.dot(phi_A_left, np.dot(phi_A_right, p))
    weighted_b = np.dot(phi_B_left, np.dot(phi_B_right, q))
    constant = np.dot(weighted_a, p) + np.dot(weighted_b, q)

    p_col = p[:, None]
    q_row = q[None, :]

    # Cross term, kept in the small factor dimensions throughout.
    forward = np.dot(np.dot(A_2, p_col), np.dot(q_row, B_1))
    backward = np.dot(np.dot(B_2, q_row.T), np.dot(p_col.T, A_1))
    cross = np.dot(forward, backward)

    return constant - 2 * np.trace(cross)


def GW_Init_Cubic(D_1, D_2, a, b):
    """Evaluate the GW objective at the product coupling ``a b^T``.

    Works on the full cost matrices ``D_1`` and ``D_2``: builds
    ``L = const - 2 * D_1 P D_2`` with ``P = a b^T`` and returns
    ``sum(L * P)``.
    """
    coupling = a[:, None] * b[None, :]

    # Row term (D_1 ** 2) a as a column, column term b^T (D_2 ** 2)^T as a
    # row; broadcasting their sum replaces the outer products with ones.
    row_term = np.dot(D_1 ** 2, a.reshape(-1, 1))
    col_term = np.dot(b.reshape(1, -1), (D_2 ** 2).T)
    constant = row_term + col_term

    linearised = constant - 2 * np.dot(np.dot(D_1, coupling), D_2)
    return np.sum(linearised * coupling)


def Sinkhorn(C, reg, a, b, delta=1e-9, num_iter=1000, lam=1e-6):
    """Run Sinkhorn scaling iterations for the kernel ``K = exp(-C / reg)``.

    Parameters
    ----------
    C : 2-D cost matrix (array-like).  Integer or boolean costs are
        accepted; the kernel is then stored as float64.
    reg : entropic regularisation strength.
    a, b : target row and column marginals.
    delta : tolerance on the L1 violation of the column marginal.
    num_iter : maximum number of iterations.
    lam : constant added to the denominators to avoid division by zero.

    Returns
    -------
    ``(u, v, K, num_op)``: the two scaling vectors, the kernel, and an
    operation-count estimate.  Callers form the coupling as
    ``u[:, None] * K * v[None, :]``.

    If a scaling vector becomes NaN or infinite, a warning is printed and
    the vectors from the previous iteration are returned.
    """
    C = np.asarray(C)
    n, m = np.shape(C)

    # K = exp(-C / reg), computed in place to avoid temporaries.  The buffer
    # keeps C's dtype only when it can hold the result; writing the quotient
    # into an integer buffer would raise a casting error.
    kernel_dtype = C.dtype if np.issubdtype(C.dtype, np.inexact) else np.float64
    K = np.empty(C.shape, dtype=kernel_dtype)
    np.divide(C, -reg, out=K)  # n * m
    np.exp(K, out=K)  # n * m

    u = np.ones(np.shape(a)[0])
    v = np.ones(np.shape(b)[0])

    v_trans = np.dot(K.T, u) + lam  # lam guards against division by zero

    err = 1
    index = 0
    while index < num_iter and err > delta:
        uprev, vprev = u, v
        index = index + 1

        v = b / v_trans
        u_trans = np.dot(K, v) + lam  # lam guards against division by zero
        u = a / u_trans

        if not (np.all(np.isfinite(u)) and np.all(np.isfinite(v))):
            # Machine precision reached: fall back to the last finite
            # scalings and stop iterating.
            print("Warning: numerical errors at iteration", index)
            u, v = uprev, vprev
            break

        v_trans = np.dot(K.T, u) + lam
        err = np.sum(np.abs(v * v_trans - b))

    num_op = 3 * n * m + (index + 1) * (2 * n * m + n + m)
    return u, v, K, num_op


def LSE_Sinkhorn(C, reg, a, b, num_iter=1000, delta=1e-3, lam=0):
    """Log-domain (log-sum-exp) Sinkhorn iterations.

    Parameters
    ----------
    C : 2-D cost matrix.
    reg : entropic regularisation strength.
    a, b : target row and column marginals.
    num_iter : maximum number of iterations.
    delta : tolerance on the L1 violation of the row marginal.
    lam : accepted for signature parity with ``Sinkhorn``; not used here.

    Returns
    -------
    ``(P, num_op)``: the coupling ``exp((f_i + g_j - C_ij) / reg)`` and an
    operation-count estimate.

    If a potential becomes NaN or infinite, a warning is printed and the
    coupling from the previous iteration is returned.
    """
    # Imported explicitly: a bare `import scipy` does not guarantee that the
    # `scipy.special` submodule has been loaded.
    from scipy.special import logsumexp

    f = np.zeros(np.shape(a)[0])
    g = np.zeros(np.shape(b)[0])

    n, m = np.shape(C)

    C_tilde = (f[:, None] + g[None, :] - C) / reg  # 3 * n * m
    P = np.exp(C_tilde)

    err = 1
    n_iter = 0
    while n_iter < num_iter and err > delta:
        P_prev = P
        n_iter = n_iter + 1

        # Update f against the current potentials.
        f = reg * np.log(a) + f - reg * logsumexp(C_tilde, axis=1)

        # Update g against the refreshed f.
        C_tilde = (f[:, None] + g[None, :] - C) / reg
        g = reg * np.log(b) + g - reg * logsumexp(C_tilde, axis=0)

        if not (np.all(np.isfinite(f)) and np.all(np.isfinite(g))):
            # Machine precision reached: fall back to the last coupling and
            # stop iterating.
            print("Warning: numerical errors at iteration", n_iter)
            P = P_prev
            break

        # Rebuild the coupling and measure the row-marginal violation.
        C_tilde = (f[:, None] + g[None, :] - C) / reg
        P = np.exp(C_tilde)
        err = np.sum(np.abs(np.sum(P, axis=1) - a))

    num_op = 4 * n * m + (n_iter + 1) * (8 * n * m + 3 * n + 3 * m) + n * m
    return P, num_op


#### CUBIC VERSION ####
## Stable version: works for every $\varepsilon$ ##
# Here the costs considered are C = 2 (constant - 2 DPD')
def GW_entropic_distance2(
    D_1,
    D_2,
    reg,
    a,
    b,
    Init="trivial",
    seed_init=49,
    I=10,
    delta_sin=1e-9,
    num_iter_sin=1000,
    lam_sin=0,
    LSE=False,
    time_out=50,
):
    """Entropic Gromov-Wasserstein on full cost matrices, solved by repeated
    Sinkhorn calls on a linearised cost.

    Starting from an initial coupling ``P``, each of the ``I`` outer
    iterations solves an entropic problem with cost ``2 * L`` (where
    ``L = const - 2 * D_1 P D_2``) and rebuilds ``L`` from the new coupling.
    The objective ``sum(L * P)``, elapsed time, running operation count and
    coupling are recorded after every iteration.

    Parameters
    ----------
    D_1, D_2 : cost matrices of the two spaces.
    reg : regularisation passed to the Sinkhorn solver.
    a, b : weight vectors used as target marginals.
    Init : "trivial" (``a b^T``), "lower_bound" (Sinkhorn on a squared
        Euclidean cost between 1-D embeddings) or "random".
    seed_init : seed of the "random" initialisation.
    I : number of outer iterations.
    delta_sin, num_iter_sin, lam_sin : forwarded to the inner solver.
    LSE : use ``LSE_Sinkhorn`` instead of ``Sinkhorn`` when true.
    time_out : wall-clock budget in seconds, checked after each iteration.

    Returns
    -------
    ``(acc[-1], np.array(acc), np.array(times), np.array(list_num_op),
    Couplings)``, or the string ``"Error"`` if the objective becomes NaN.

    NOTE(review): an unrecognised ``Init`` leaves ``P`` unassigned and the
    function fails when ``L`` is first computed.
    """
    start = time.time()
    num_op = 0
    acc = []
    times = []
    list_num_op = []
    Couplings = []

    n, m = np.shape(a)[0], np.shape(b)[0]

    if Init == "trivial":
        P = a[:, None] * b[None, :]
        Couplings.append(P)
        num_op = num_op + n * m

    if Init == "lower_bound":
        # One-dimensional embeddings of both spaces, then an entropic OT
        # problem between them provides the starting coupling.
        X_new = np.sqrt(np.dot(D_1 ** 2, a).reshape(-1, 1))  # 2 * n * n + n
        Y_new = np.sqrt(np.dot(D_2 ** 2, b).reshape(-1, 1))  # 2 * m * m + m
        C_init = Square_Euclidean_Distance(X_new, Y_new)  # n * m
        num_op = num_op + n * m + 2 * n * n + 2 * m * m + n + m

        if LSE == False:
            u, v, K, count_op_Sin = Sinkhorn(
                C_init, reg, a, b, delta=delta_sin, num_iter=num_iter_sin, lam=lam_sin
            )
            num_op = num_op + count_op_Sin
            P = u[:, None] * K * v[None, :]
            num_op = num_op + 2 * n * m
        else:
            P, count_op_Sin_LSE = LSE_Sinkhorn(
                C_init, reg, a, b, delta=delta_sin, num_iter=num_iter_sin, lam=lam_sin
            )
            num_op = num_op + count_op_Sin_LSE

        Couplings.append(P)

    if Init == "random":
        np.random.seed(seed_init)
        P = np.abs(np.random.randn(n, m))
        P = P + 1
        # Rescale the rows so that they sum to a (columns are not rescaled).
        P = (P.T * (a / np.sum(P, axis=1))).T
        Couplings.append(P)
        num_op = num_op + 3 * n * m + n

    # Coupling-independent part of the linearised cost:
    # const[i, j] = ((D_1 ** 2) a)[i] + ((D_2 ** 2) b)[j].
    const_1 = np.dot(
        np.dot(D_1 ** 2, a.reshape(-1, 1)), np.ones(len(b)).reshape(1, -1)
    )  # 2 * n * n + n * m
    const_2 = np.dot(
        np.ones(len(a)).reshape(-1, 1), np.dot(b.reshape(1, -1), (D_2 ** 2).T)
    )  # 2 * m * m + n * m
    num_op = num_op + 2 * n * m + 2 * n * n + 2 * m * m
    const = const_1 + const_2
    L = const - 2 * np.dot(np.dot(D_1, P), D_2)

    res = np.sum(L * P)
    # print(res)
    end = time.time()
    curr_time = end - start
    times.append(curr_time)
    acc.append(res)
    list_num_op.append(num_op)

    for k in range(I):
        if LSE == False:
            u, v, K, count_op_Sin = Sinkhorn(
                2 * L, reg, a, b, delta=delta_sin, num_iter=num_iter_sin, lam=lam_sin
            )
            num_op = num_op + count_op_Sin
            P = u.reshape((-1, 1)) * K * v.reshape((1, -1))
            num_op = num_op + 2 * n * m
        else:
            P, count_op_Sin_LSE = LSE_Sinkhorn(
                2 * L, reg, a, b, delta=delta_sin, num_iter=num_iter_sin, lam=lam_sin
            )
            num_op = num_op + count_op_Sin_LSE

        # Re-linearise the cost at the new coupling.
        L = const - 2 * np.dot(np.dot(D_1, P), D_2)
        num_op = num_op + n * n * m + n * m * m + 2 * n * m
        res = np.sum(L * P)
        # print(res)

        if np.isnan(res) == True:
            return "Error"
        else:
            acc.append(res)
            Couplings.append(P)

        end = time.time()
        curr_time = end - start
        times.append(curr_time)
        list_num_op.append(num_op)
        if curr_time > time_out:
            return (
                acc[-1],
                np.array(acc),
                np.array(times),
                np.array(list_num_op),
                Couplings,
            )

    return acc[-1], np.array(acc), np.array(times), np.array(list_num_op), Couplings


#### QUAD VERSION ####
## Stable version: works for every $\varepsilon$ ##
# Here the costs considered are C = 2 (constant - 2 DPD')
def Quad_GW_entropic_distance2(
    A_1,
    A_2,
    B_1,
    B_2,
    reg,
    a,
    b,
    Init="trivial",
    seed_init=49,
    I=10,
    delta_sin=1e-9,
    num_iter_sin=1000,
    lam_sin=0,
    time_out=50,
    LSE=False,
):
    """Entropic Gromov-Wasserstein with factorised costs (D1 = A_1 A_2,
    D2 = B_1 B_2), solved by repeated Sinkhorn calls on a linearised cost.

    Same outer scheme as the full-matrix version, but every product with
    the cost matrices goes through the factors: the linearised cost is
    ``L = const - 2 * A_1 (A_2 P B_1) B_2`` and the objective is
    ``c - 2 * trace((B_2 P^T A_1)(A_2 P B_1))``.  The objective, elapsed
    time, running operation count and coupling are recorded after every
    iteration, and the objective is printed.

    Parameters
    ----------
    A_1, A_2 : factors of the first cost matrix.
    B_1, B_2 : factors of the second cost matrix.
    reg : regularisation passed to the Sinkhorn solver.
    a, b : weight vectors used as target marginals.
    Init : "trivial" (``a b^T``), "lower_bound" (Sinkhorn on a squared
        Euclidean cost between 1-D embeddings) or "random".
    seed_init : seed of the "random" initialisation.
    I : number of outer iterations.
    delta_sin, num_iter_sin, lam_sin : forwarded to the inner solver.
    time_out : wall-clock budget in seconds, checked after each iteration.
    LSE : use ``LSE_Sinkhorn`` instead of ``Sinkhorn`` when true.

    Returns
    -------
    ``(acc[-1], np.array(acc), np.array(times), np.array(list_num_op),
    Couplings)``, or the string ``"Error"`` if the objective becomes NaN.

    NOTE(review): an unrecognised ``Init`` leaves ``P`` unassigned and the
    function fails when ``C_trans`` is first computed.
    """
    start = time.time()
    num_op = 0

    acc = []
    times = []
    list_num_op = []
    Couplings = []

    n, d1 = np.shape(A_1)
    m, d2 = np.shape(B_1)

    # Squared-kernel feature maps of the factors.
    tilde_A_1 = Feature_Map_Poly(A_1)
    tilde_A_2_T = Feature_Map_Poly(A_2.T)
    tilde_A_2 = tilde_A_2_T.T

    tilde_B_1 = Feature_Map_Poly(B_1)
    tilde_B_2_T = Feature_Map_Poly(B_2.T)
    tilde_B_2 = tilde_B_2_T.T

    num_op = num_op + 2 * n * d1 * d1 + 2 * m * d2 * d2

    tilde_a = np.dot(tilde_A_1, np.dot(tilde_A_2, a))  # 2 * d1 * d1 * n
    tilde_b = np.dot(tilde_B_1, np.dot(tilde_B_2, b))  # 2 * d2 * d2 * m

    # Constant part of the objective.
    c = np.dot(tilde_a, a) + np.dot(tilde_b, b)  # n + m

    # Coupling-independent part of the linearised cost:
    # const[i, j] = tilde_a[i] + tilde_b[j].
    const_1 = np.dot(tilde_a.reshape(-1, 1), np.ones(len(b)).reshape(1, -1))  # n * m
    const_2 = np.dot(np.ones(len(a)).reshape(-1, 1), tilde_b.reshape(1, -1))  # n * m
    const = const_1 + const_2

    num_op = num_op + 2 * d1 * d1 * n + 2 * d2 * d2 * m + 3 * n * m

    if Init == "trivial":
        P = a[:, None] * b[None, :]
        Couplings.append(P)
        num_op = num_op + n * m

    if Init == "lower_bound":
        # The two products below repeat the computation of tilde_a and
        # tilde_b above; X_new and Y_new are their element-wise square roots.
        X_new = np.dot(tilde_A_2, a)
        X_new = np.sqrt(np.dot(tilde_A_1, X_new).reshape(-1, 1))
        Y_new = np.dot(tilde_B_2, b)
        Y_new = np.sqrt(np.dot(tilde_B_1, Y_new).reshape(-1, 1))

        C_init = Square_Euclidean_Distance(X_new, Y_new)
        num_op = num_op + n * m + 2 * d1 * d1 * n + 2 * d2 * d2 * m + n + m

        if LSE == False:
            u, v, K, count_op_Sin = Sinkhorn(
                C_init, reg, a, b, delta=delta_sin, num_iter=num_iter_sin, lam=lam_sin
            )
            num_op = num_op + count_op_Sin
            P = u[:, None] * K * v[None, :]
            num_op = num_op + 2 * n * m
        else:
            P, count_op_Sin_LSE = LSE_Sinkhorn(
                C_init, reg, a, b, delta=delta_sin, num_iter=num_iter_sin, lam=lam_sin
            )
            num_op = num_op + count_op_Sin_LSE

        Couplings.append(P)

    if Init == "random":
        np.random.seed(seed_init)
        P = np.abs(np.random.randn(n, m))
        P = P + 1
        # Rescale the rows so that they sum to a (columns are not rescaled).
        P = (P.T * (a / np.sum(P, axis=1))).T
        Couplings.append(P)
        num_op = num_op + 3 * n * m + n

    # Small (d1 x d2) core A_2 P B_1 shared by the objective and the cost.
    C_trans = np.dot(np.dot(A_2, P), B_1)  # d1 * n * m + d1 * m * d2
    num_op = num_op + d1 * n * m + d1 * d2 * m

    C_trans_2 = np.dot(np.dot(B_2, P.T), A_1)
    C_f = np.dot(C_trans_2, C_trans)
    res = c - 2 * np.trace(C_f)
    print(res)

    acc.append(res)
    end = time.time()
    curr_time = end - start
    times.append(curr_time)
    list_num_op.append(num_op)

    L = const - 2 * np.dot(
        np.dot(A_1, C_trans), B_2
    )  # n * m + n * d1 * d2 + n * d2 * m
    num_op = num_op + n * m + n * d1 * d2 + n * d2 * m

    for k in range(I):
        if LSE == False:
            u, v, K, count_op_Sin = Sinkhorn(
                2 * L, reg, a, b, delta=delta_sin, num_iter=num_iter_sin, lam=lam_sin
            )
            P = u.reshape((-1, 1)) * K * v.reshape((1, -1))
            num_op = num_op + count_op_Sin + 2 * n * m
        else:
            P, count_op_Sin_LSE = LSE_Sinkhorn(
                2 * L, reg, a, b, delta=delta_sin, num_iter=num_iter_sin, lam=lam_sin
            )
            num_op = num_op + count_op_Sin_LSE

        # Re-linearise the cost at the new coupling.
        C_trans = np.dot(np.dot(A_2, P), B_1)
        L = const - 2 * np.dot(np.dot(A_1, C_trans), B_2)
        num_op = num_op + d1 * n * m + d2 * n * m + d1 * d2 * n + d1 * d2 * m + n * m

        C_trans_2 = np.dot(np.dot(B_2, P.T), A_1)
        C_f = np.dot(C_trans_2, C_trans)
        res = c - 2 * np.trace(C_f)
        print(res)

        if np.isnan(res) == True:
            return "Error"
        else:
            acc.append(res)
            Couplings.append(P)

        end = time.time()
        curr_time = end - start
        times.append(curr_time)
        list_num_op.append(num_op)
        if curr_time > time_out:
            return (
                acc[-1],
                np.array(acc),
                np.array(times),
                np.array(list_num_op),
                Couplings,
            )

    return acc[-1], np.array(acc), np.array(times), np.array(list_num_op), Couplings


## Feature map of k(x,y) = \langle x,y\rangle ** 2 ##
def Feature_Map_Poly(X):
    """Explicit feature map of the kernel ``k(x, y) = <x, y> ** 2``.

    Each row ``x`` of the 2-D input ``X`` (shape ``(n, d)``) is mapped to
    the flattened outer product ``x x^T``, so that
    ``Feature_Map_Poly(X) @ Feature_Map_Poly(Y).T == (X @ Y.T) ** 2``.

    Returns a float64 array of shape ``(n, d ** 2)``.
    """
    n, d = np.shape(X)
    X = np.asarray(X)
    # All row-wise outer products in one vectorised call instead of a
    # Python-level loop over the rows.
    outer = np.einsum("ij,ik->ijk", X, X)
    return outer.reshape(n, d * d).astype(np.float64, copy=False)


def Square_Euclidean_Distance(X, Y):
    """Return the matrix of squared Euclidean distances $|x_i - y_j|^2$.

    ``X`` and ``Y`` hold one point per row; entry ``(i, j)`` of the result
    compares row ``i`` of ``X`` with row ``j`` of ``Y``.
    """
    # Broadcast every row of X against every row of Y, then reduce over
    # the coordinate axis.
    pairwise_diff = X[:, None] - Y[None]
    return np.sum(pairwise_diff ** 2, axis=2)


# shape of xs: num_samples * dimension
def factorized_square_Euclidean(xs, xt):
    """Factorise the squared Euclidean cost between ``xs`` and ``xt``.

    Both inputs hold one sample per row.  Returns ``(A_1, A_2)`` such that
    ``A_1 @ A_2`` is the matrix of squared distances
    ``|xs_i|^2 + |xt_j|^2 - 2 <xs_i, xt_j>``.
    """
    n_source, dim = np.shape(xs)[0], np.shape(xs)[1]
    n_target = np.shape(xt)[0]

    sq_norm_source = np.sum(xs ** 2, axis=1)  # 2 * n * d
    sq_norm_target = np.sum(xt ** 2, axis=1)  # 2 * m * d

    # Left factor, one row per source point: [ |x|^2, 1, -2 x ].
    A_1 = np.empty((n_source, dim + 2))
    A_1[:, 0] = sq_norm_source
    A_1[:, 1] = 1.0
    A_1[:, 2:] = -2 * xs  # n * d

    # Right factor, one column per target point: [ 1, |y|^2, y ]^T.
    A_2 = np.empty((dim + 2, n_target))
    A_2[0, :] = 1.0
    A_2[1, :] = sq_norm_target
    A_2[2:, :] = xt.T

    return A_1, A_2


def Euclidean_Distance(X, Y):
    """Return the matrix of Euclidean distances between rows of X and Y.

    Entry ``(i, j)`` is the distance between row ``i`` of ``X`` and row
    ``j`` of ``Y``.
    """
    pairwise_diff = X[:, None] - Y[None]
    squared = np.sum(pairwise_diff ** 2, axis=2)
    return np.sqrt(squared)


def Lp_Distance(X, Y, p=1):
    """Return the matrix of L^p distances between rows of X and Y.

    Entry ``(i, j)`` is ``(sum_k |X[i, k] - Y[j, k]| ** p) ** (1 / p)``.
    """
    abs_diff = np.abs(X[:, None] - Y[None])
    powered_sum = np.sum(abs_diff ** p, axis=2)
    return powered_sum ** (1 / p)


def Learning_linear_subspace(X, Y, cost, U, C_init=False, tol=1e-3):
    """Fit the left factor ``V`` so that ``V @ U`` approximates the cost.

    ``int(rank / tol)`` columns are drawn uniformly (with replacement) from
    the global NumPy RNG; the fit is a least-squares solve restricted to
    those columns, carried out in the eigenbasis of ``U U^T``.

    ``cost`` is the cost matrix when ``C_init`` is true, otherwise a
    callable ``cost(X, Y_subset)``.  ``U`` has one row per rank component.
    Returns an array with one row per row of the cost matrix and ``rank``
    columns.
    """
    rank, m = np.shape(U)

    # Eigenbasis of the Gram matrix U U^T, each eigenvector rescaled by
    # 1 / sqrt(eigenvalue).
    gram = np.dot(U, U.T)  # k x k
    eigenvalues, basis = np.linalg.eigh(gram)
    basis = basis / np.sqrt(eigenvalues)  # k x k

    sampled_cols = np.random.choice(m, size=int(rank / tol))
    U_sampled = U[:, sampled_cols]  # k x k/tol

    if C_init == False:
        cost_sampled = cost(X, Y[sampled_cols, :])
    else:
        cost_sampled = cost[:, sampled_cols]  # n x k/tol

    cost_sampled = (1 / np.sqrt(int(rank / tol))) * cost_sampled
    design = (1 / np.sqrt(int(rank / tol))) * np.dot(basis.T, U_sampled)  # k x k/tol

    # Normal-equation solution of the least-squares problem.
    pseudo_inverse = np.linalg.inv(np.dot(design, design.T))
    pseudo_inverse = np.dot(pseudo_inverse, design)  # k x k/tol
    coefficients = np.dot(pseudo_inverse, cost_sampled.T)  # k x n

    return np.dot(coefficients.T, basis.T)


# If C_init == True: cost is the Matrix
# If C_init == False: cost is the Function
def factorized_distance_cost(X, Y, rank, cost, C_init=False, tol=1e-3, seed=49):
    """Randomised rank-``rank`` factorisation of a cost matrix.

    ``cost`` is the matrix itself when ``C_init`` is true, otherwise a
    callable evaluated on subsets of the rows of ``X`` and ``Y``.  Rows and
    then columns are sampled (``int(rank / tol)`` of each, with replacement)
    from distributions built from one reference row and one reference
    column; the sampled block gives the right factor through an SVD and the
    left factor is fitted by ``Learning_linear_subspace``.

    The global NumPy RNG is re-seeded with ``seed``.  Returns ``(V, U)``
    with ``U`` holding one row per rank component.
    """
    np.random.seed(seed)
    from_function = C_init == False
    if from_function:
        n, m = np.shape(X)[0], np.shape(Y)[0]
    else:
        n, m = np.shape(cost)

    # Reference row / column indices (arrays of length one).
    i_ = np.random.randint(n, size=1)
    j_ = np.random.randint(m, size=1)

    if from_function:
        ref_row = cost(X[i_, :].reshape(1, -1), Y)
        ref_col = cost(X, Y[j_, :].reshape(1, -1))
    else:
        ref_row = cost[i_, :]
        ref_col = cost[:, j_]
    mean = np.mean(ref_row ** 2)

    # Row-sampling distribution.
    row_probs = ref_col ** 2 + ref_row[0, j_] ** 2 + mean
    row_probs = row_probs / np.sum(row_probs)  # vector of size n

    # Compute S: sampled rows, each rescaled by 1 / sqrt(count * probability).
    picked_rows = np.random.choice(n, size=int(rank / tol), p=row_probs.reshape(-1))
    if from_function:
        S = cost(X[picked_rows, :], Y)  # k/tol x m
    else:
        S = cost[picked_rows, :]
    S = S / np.sqrt(int(rank / tol) * row_probs[picked_rows])

    # Column-sampling distribution: share of each column in the squared
    # norm of S.
    total_sq = np.sum(S ** 2)
    col_probs = np.array(
        [np.sum(S[:, j] ** 2) / total_sq for j in range(m)], dtype=np.float64
    )

    # Compute W: sampled columns of S, rescaled the same way.
    picked_cols = np.random.choice(m, size=int(rank / tol), p=col_probs.reshape(-1))
    W = S[:, picked_cols]  # k/tol x k/tol
    W = (W.T / np.sqrt(int(rank / tol) * col_probs[picked_cols])).T

    # Compute U: leading left singular vectors of W, pushed through S and
    # normalised column by column.
    left_vectors, _, _ = np.linalg.svd(W)
    leading = left_vectors[:, :rank]  # k/tol x k
    projected = np.dot(W.T, leading)  # k/tol x k
    column_norms = np.sqrt(np.sum(projected ** 2, axis=0))

    U = np.dot(S.T, leading) / column_norms  # m x k

    # Compute V
    V = Learning_linear_subspace(X, Y, cost, U.T, C_init=C_init, tol=tol)

    return V, U.T
