
import torch
import FRLC.util as util
import FRLC.objective_grad as gd
import matplotlib.pyplot as plt


def FRLC_opt(C,
             a=None,
             b=None,
             A=None,
             B=None,
             tau_in=50,
             tau_out=50,
             gamma=70, 
             r = 10, 
             r2=None, 
             max_iter=200, 
             device='cpu', 
             dtype=torch.float64, 
             semiRelaxedLeft=False, 
             semiRelaxedRight=False, 
             Wasserstein=True, 
             printCost=False, 
             returnFull=False, 
             FGW=False, 
             alpha=0.0, 
             unbalanced=False, 
             initialization='Full',
             init_args = None,
             full_grad=True,
             convergence_criterion=True,
             tol=1e-5,
             min_iter = 25,
             min_iterGW = 500,
             max_iterGW = 1000,
             max_inneriters_balanced= 500,
             max_inneriters_relaxed= 500,
             diagonalize_return=False):
    
    '''
    FRLC: Factor Relaxation with Latent Coupling
    ------Parameters------
    C: torch.tensor (N1 x N2)
        A matrix of pairwise feature distances in space X and space Y (inter-space).
    a: torch.tensor (N1)
        A vector representing marginal one. Defaults to uniform (1/N1).
    b: torch.tensor (N2)
        A vector representing marginal two. Defaults to uniform (1/N2).
    A: torch.tensor (N1 x N1)
        A matrix of pairwise distances between points in metric space X.
    B: torch.tensor (N2 x N2)
        A matrix of pairwise distances between points in metric space Y.
    tau_in: float (> 0)
        A scalar which controls the regularity of the inner marginal update path.
    tau_out: float (> 0)
        A scalar which controls the regularization of the outer marginals a and b
        (e.g. for semi-relaxed or unbalanced OT).
    gamma: float (> 0)
        The mirror descent step-size, a scalar which controls the scaling of gradients
        before being exponentiated into Sinkhorn kernels.
    r: int (> 1)
        A non-negative integer rank, controlling the rank of the FRLC learned OT coupling
        (number of columns of Q).
    r2: int or None
        Number of columns of R; the latent coupling T is (r x r2). Defaults to r.
    max_iter: int
        The maximal number of iterations FRLC will run until convergence.
    device: str
        The device (i.e. 'cpu' or 'cuda') which FRLC runs on.
    dtype: dtype
        The datatype all tensors are stored on (naturally there is a space-accuracy
        tradeoff for low-rank between 32 and 64 bit).
    semiRelaxedLeft: bool
        True if running the left-marginal relaxed low-rank algorithm.
        Takes precedence over semiRelaxedRight and unbalanced.
    semiRelaxedRight: bool
        True if running the right-marginal relaxed low-rank algorithm.
        Takes precedence over unbalanced.
    Wasserstein: bool
        True if using the Wasserstein loss <C, P>_F as the objective cost,
        else runs GW if FGW is False and FGW if FGW is True.
        When Wasserstein is False, min_iterGW / max_iterGW replace min_iter / max_iter.
    printCost: bool
        True if recording the value of the objective cost at each iteration and
        plotting it (together with the latent coupling T) once the loop ends.
        This can be expensive for large datasets if C is not factored.
    returnFull: bool
        True if returning P_r = Q Lambda R.T, else returns iterates (Q, R, T).
    FGW: bool
        True if running the Fused-Gromov Wasserstein problem, and otherwise false.
    alpha: float
        A balance parameter between the Wasserstein term and
        the Gromov-Wasserstein term of the objective.
    unbalanced: bool
        True if running the unbalanced problem;
        if semiRelaxedLeft/Right and unbalanced False (default) then running the balanced problem.
    initialization: str, 'Full' or 'Rank-2'
        'Full' if sub-couplings initialized to be full-rank, if 'Rank-2' set to a rank-2 initialization.
        We advise setting this to be 'Full'. Any other value falls back to 'Full' with a printed notice.
    init_args: tuple of 3-tensors
        A tuple of (Q0, R0, T0) for tuple[i] of type tensor.
    full_grad: bool
        If True, evaluates gradient with rank-1 perturbations.
        Else if False, omits perturbation terms.
    convergence_criterion: bool
        If True, use the convergence criterion. Else if False, default to running up to max_iter.
    tol: float
        Tolerance used for establishing when convergence is reached.
    min_iter: int
        The minimum iterations for the algorithm to run for in the Wasserstein case.
    min_iterGW: int
        The minimum number of iterations to run for in the GW case.
    max_iterGW: int
        The maximum number of iterations to run for in the GW case.
    max_inneriters_balanced: int
        The maximum number of inner iterations for the Sinkhorn loop
        (initialization and the T update).
    max_inneriters_relaxed: int
        The maximum number of inner iterations for the Q and R updates.
    diagonalize_return: bool
        If True, diagonalize the LC-factorization to the form of Scetbon et al '21.
        Else if False, return the LC-factorization.
    ------Returns------
    (P, errs) if returnFull, else (Q, R, T, errs).
    errs is the list of per-iteration costs; it stays empty unless printCost is True.
    '''

    N1, N2 = C.size(dim=0), C.size(dim=1)
    k = 0
    
    one_N1 = torch.ones((N1), device=device, dtype=dtype)
    one_N2 = torch.ones((N2), device=device, dtype=dtype)
    
    if a is None:
        a = one_N1 / N1
    if b is None:
        b = one_N2 / N2
    if r2 is None:
        r2 = r

    one_r = torch.ones((r), device=device, dtype=dtype)
    one_r2 = torch.ones((r2), device=device, dtype=dtype)
    
    # Initialize inner marginals to uniform; 
    # generalized to be of differing dimensions to account for non-square latent-coupling.
    gQ = (1/r)*one_r
    gR = (1/r2)*one_r2
    
    if initialization == 'Full':
        full_rank = True
    elif initialization == 'Rank-2':
        full_rank = False
    else:
        full_rank = True
        print('Initialization must be either "Full" or "Rank-2", defaulting to "Full".')
        
    if init_args is None:
        # The Lambda returned here is discarded: it is recomputed below from the
        # stabilized factors.
        Q, R, T, _ = util.initialize_couplings(a, b, gQ, gR, \
                                               gamma, full_rank=full_rank, \
                                               device=device, dtype=dtype, \
                                               max_iter = max_inneriters_balanced)
    else:
        Q, R, T = init_args

    # Mix each factor with the product of its own marginals so that no entry is
    # exactly zero before torch.log is taken in the updates below.
    Q = stabilize_Q_init(Q, device=device, dtype=dtype)
    R = stabilize_Q_init(R, device=device, dtype=dtype)
    T = stabilize_Q_init(T, device=device, dtype=dtype)
    Lambda = torch.diag(1/ (Q.T @ one_N1)) @ T @ torch.diag(1/ (R.T @ one_N2))

    if Wasserstein is False:
        min_iter = min_iterGW
        max_iter = max_iterGW

    # ---- Preparing main loop ----
    errs = []
    gamma_k = gamma
    
    Q_prev, R_prev, T_prev = None, None, None
    
    # Initialize duals for warm-start across iterations
    dual_1Q, dual_2Q = None, None
    dual_1R, dual_2R = None, None
    dual_1T, dual_2T = None, None

    # Which outer marginal is relaxed, mirroring the precedence
    # semiRelaxedLeft > semiRelaxedRight > unbalanced > balanced.
    relax_Q = bool(semiRelaxedLeft or (unbalanced and not semiRelaxedRight))
    relax_R = bool((not semiRelaxedLeft) and (semiRelaxedRight or unbalanced))
    step_opts = dict(tau_in=tau_in, tau_out=tau_out, max_iter=max_inneriters_relaxed,
                     device=device, dtype=dtype)
    
    while (k < max_iter and (not convergence_criterion or \
                       (k < min_iter or util.Delta((Q, R, T), (Q_prev, R_prev, T_prev), gamma_k) > tol))):
        
        if convergence_criterion:
            # Set previous iterates to evaluate convergence at the next round
            Q_prev, R_prev, T_prev = Q, R, T
        
        if k % 25 == 0:
            print(f'Iteration: {k}')
        
        gradQ, gradR, gamma_k = gd.compute_grad_A(C, Q, R, Lambda, gamma, semiRelaxedLeft, \
                                               semiRelaxedRight, device, Wasserstein=Wasserstein, \
                                               A=A, B=B, FGW=FGW, alpha=alpha, \
                                                  unbalanced=unbalanced, full_grad=full_grad)

        # Both factor updates use the gradients computed above and their own duals;
        # the left-relaxed mode updates R before Q, every other mode Q before R.
        if semiRelaxedLeft:
            R, dual_1R, dual_2R = _frlc_outer_step(gradR, R, b, gR, gamma_k,
                                                   dual_1R, dual_2R, relax_R, **step_opts)
            Q, dual_1Q, dual_2Q = _frlc_outer_step(gradQ, Q, a, gQ, gamma_k,
                                                   dual_1Q, dual_2Q, relax_Q, **step_opts)
        else:
            Q, dual_1Q, dual_2Q = _frlc_outer_step(gradQ, Q, a, gQ, gamma_k,
                                                   dual_1Q, dual_2Q, relax_Q, **step_opts)
            R, dual_1R, dual_2R = _frlc_outer_step(gradR, R, b, gR, gamma_k,
                                                   dual_1R, dual_2R, relax_R, **step_opts)

        # Inner marginals are read off the freshly updated factors.
        gQ, gR = Q.T @ one_N1, R.T @ one_N2
        
        gradT, gamma_T = gd.compute_grad_B(C, Q, R, Lambda, gQ, gR, gamma, \
                                           device, Wasserstein=Wasserstein, \
                                           A=A, B=B, FGW=FGW, alpha=alpha)
        
        T, dual_1T, dual_2T = util.logSinkhorn(gradT - (gamma_T**-1)*torch.log(T), gQ, gR, gamma_T, max_iter = max_inneriters_balanced, \
                         device=device, dtype=dtype, balanced=True, unbalanced=False, \
                                dual_1 = dual_1T, dual_2 = dual_2T)
        
        # Inner latent transition-inverse matrix
        Lambda = torch.diag(1/gQ) @ T @ torch.diag(1/gR)
        
        if printCost:
            if Wasserstein:
                # <C, P>_F evaluated through the factors, without forming P.
                cost = torch.trace(( (Q.T @ C) @ R) @ Lambda.T)
            else:
                P = Q @ Lambda @ R.T
                M1 = Q.T @ A**2 @ Q   # (r  x r)
                M2 = R.T @ B**2 @ R   # (r2 x r2)
                # Summing all entries equals 1^T M 1 and stays valid when r2 != r.
                cost = M1.sum() + M2.sum() - 2*torch.trace((A @ P @ B).T @ P)
                if FGW:
                    cost = (1-alpha)*torch.sum(C * P) + alpha*cost
            errs.append(cost.detach().cpu())
            
        k+=1

    if printCost:
        # Plotting OT objective value across iterations.
        plt.plot(range(len(errs)), errs)
        plt.xlabel('Iterations')
        plt.ylabel('OT-Cost')
        plt.show()
        # Plotting latent coupling.
        plt.imshow(T.cpu())
        plt.show()
    
    if returnFull:
        P = Q @ Lambda @ R.T
        return P, errs

    if diagonalize_return:
        # Diagonalize return to factorization of Scetbon '21
        Q = Q @ torch.diag(1 / gQ) @ T
        gR = R.T @ one_N2
        T = torch.diag(gR)
        
    return Q, R, T, errs


def _frlc_outer_step(grad, factor, outer_marginal, inner_marginal, gamma_k,
                     dual_1, dual_2, relax_outer, tau_in, tau_out,
                     max_iter, device, dtype):
    '''
    One update of an outer factor (Q or R) through util.logSinkhorn.

    relax_outer selects the keyword set FRLC_opt passes for a relaxed outer
    marginal (unbalanced=True, tau=tau_out, tau2=tau_in); otherwise the call
    uses unbalanced=False, tau=tau_in and leaves tau2 at logSinkhorn's default.
    Returns whatever util.logSinkhorn returns: (factor, dual_1, dual_2).
    '''
    if relax_outer:
        penalties = dict(unbalanced=True, tau=tau_out, tau2=tau_in)
    else:
        penalties = dict(unbalanced=False, tau=tau_in)
    return util.logSinkhorn(grad - (gamma_k**-1)*torch.log(factor),
                            outer_marginal, inner_marginal, gamma_k,
                            max_iter=max_iter, device=device, dtype=dtype,
                            balanced=False, dual_1=dual_1, dual_2=dual_2,
                            **penalties)

'''
----------
Below:
Optimization of FRLC supposing the distance matrices C, A, B have been factorized
----------
'''

def FRLC_LR_opt(
    C_factors, A_factors, B_factors,
    a=None, b=None,
    tau_in=50, tau_out=50, gamma=90,
    r=10, r2=None, max_iter=200,
    device='cpu', dtype=torch.float64,
    printCost=False, returnFull=False, alpha=0.0,
    initialization='Full', init_args=None, full_grad=True,
    convergence_criterion=True, tol=5e-6, min_iter=25,
    max_inneriters_balanced=300, max_inneriters_relaxed=50,
    diagonalize_return=False,
    seed: int | None = None,
    gen: torch.Generator | None = None,
):
    '''
    FRLC with factored cost matrices.
    ------Parameters------
    C_factors: pair of tensors (C1, C2)
        Factors of the inter-space cost, used as C1 @ C2; C1 has N1 rows, C2 has N2 columns.
    A_factors, B_factors: pairs of tensors
        Factors of the intra-space cost matrices, used as A1 @ A2 and B1 @ B2.
    a, b: torch.tensor (N1), (N2)
        Outer marginals; default to uniform.
    tau_in: float
        Passed as `tau` to the Q and R logSinkhorn updates.
    tau_out: float
        Accepted for signature parity with FRLC_opt; not used by this routine.
    gamma: float
        Mirror-descent step-size handed to the gradient routines and the initializer.
    r, r2: int
        Number of columns of Q and R (r2 defaults to r).
    max_iter, min_iter: int
        Upper / lower bounds on the number of outer iterations.
    device, dtype:
        Device and dtype for the tensors created here and passed to helpers.
    printCost: bool
        If True, record Wasserstein, GW and total cost after every iteration,
        print the first/last values and plot the total cost.
    returnFull: bool
        If True return the dense coupling P = Q @ Lambda @ R.T instead of the factors.
    alpha: float
        Weight of the GW term; total cost is (1-alpha)*W + alpha*GW.
    initialization: str, 'Full' or 'Rank-2'
        Initialization mode; any other value falls back to 'Full' with a printed notice.
    init_args: tuple (Q0, R0, T0) or None
        Optional initial factors; any entry may be None, in which case that
        factor is taken from a fresh (stabilized) initialization.
    full_grad: bool
        Forwarded to gd.compute_grad_A_LR.
    convergence_criterion: bool
        If True, stop once util.Delta between consecutive iterates is <= tol
        (after at least min_iter iterations).
    tol: float
        Convergence tolerance.
    max_inneriters_balanced, max_inneriters_relaxed: int
        Inner-iteration caps for the initializer / T update and for the Q / R updates.
    diagonalize_return: bool
        If True, return the factorization with a diagonal inner matrix.
    seed: int or None
        Seed used to build a torch.Generator when `gen` is not supplied.
    gen: torch.Generator or None
        Generator forwarded to util.initialize_couplings; takes precedence over `seed`.
    ------Returns------
    (P, errs) if returnFull, else (Q, R, T, errs), where errs is a dict with keys
    'total_cost', 'W_cost', 'GW_cost' (lists, empty unless printCost is True).
    '''
    # ---------------- RNG ----------------
    # Build the generator inline: the helper previously called here is not
    # defined or imported in this module. A CUDA generator is used for CUDA
    # devices, a CPU generator otherwise.
    # NOTE(review): assumes util.initialize_couplings draws on `device` with
    # this generator - confirm against its implementation.
    if gen is None and seed is not None:
        gen_device = device if str(device).startswith('cuda') else 'cpu'
        gen = torch.Generator(device=gen_device)
        gen.manual_seed(seed)
    
    N1, N2 = C_factors[0].size(dim=0), C_factors[1].size(dim=1)
    k = 0

    one_N1 = torch.ones((N1,), device=device, dtype=dtype)
    one_N2 = torch.ones((N2,), device=device, dtype=dtype)

    if a is None: a = one_N1 / N1
    if b is None: b = one_N2 / N2
    if r2 is None: r2 = r

    one_r  = torch.ones((r,),  device=device, dtype=dtype)
    one_r2 = torch.ones((r2,), device=device, dtype=dtype)

    # Inner marginals start uniform; they may differ in length (non-square T).
    gQ = (1/r)  * one_r
    gR = (1/r2) * one_r2

    if initialization == 'Full':
        full_rank = True
    elif initialization == 'Rank-2':
        full_rank = False
    else:
        full_rank = True
        print('Initialization must be either "Full" or "Rank-2", defaulting to "Full".')

    if init_args is None:
        # The returned Lambda is discarded; it is recomputed below.
        Q, R, T, _ = util.initialize_couplings(
            a, b, gQ, gR, gamma,
            full_rank=full_rank, device=device, dtype=dtype,
            max_iter=max_inneriters_balanced,
            gen=gen
        )
        Q = stabilize_Q_init(Q, device=device, dtype=dtype)  # deterministic
        R = stabilize_Q_init(R, device=device, dtype=dtype)
        T = stabilize_Q_init(T, device=device, dtype=dtype)
    else:
        Q, R, T = init_args
        # Supplied factors dictate the inner marginals used for any missing factor.
        if Q is not None: gQ = Q.T @ one_N1
        if R is not None: gR = R.T @ one_N2
        if Q is None or R is None or T is None:
            _Q, _R, _T, _ = util.initialize_couplings(
                a, b, gQ, gR, gamma,
                full_rank=full_rank, device=device, dtype=dtype,
                max_iter=max_inneriters_balanced,
                gen=gen
            )
            _Q = stabilize_Q_init(_Q, device=device, dtype=dtype)
            _R = stabilize_Q_init(_R, device=device, dtype=dtype)
            _T = stabilize_Q_init(_T, device=device, dtype=dtype)
            if Q is None: Q = _Q
            if R is None: R = _R
            if T is None: T = _T

    Lambda = torch.diag(1/(Q.T @ one_N1)) @ T @ torch.diag(1/(R.T @ one_N2))

    errs = {'total_cost':[], 'W_cost':[], 'GW_cost': []}
    gamma_k = gamma
    Q_prev, R_prev, T_prev = None, None, None

    # Duals are carried across outer iterations to warm-start logSinkhorn.
    dual_1Q = dual_2Q = None
    dual_1R = dual_2R = None
    dual_1T = dual_2T = None

    while (k < max_iter and (not convergence_criterion or
           (k < min_iter or util.Delta((Q, R, T), (Q_prev, R_prev, T_prev), gamma_k) > tol))):

        if convergence_criterion:
            # Remember the current iterates for the next convergence check.
            Q_prev, R_prev, T_prev = Q, R, T

        if k % 25 == 0:
            print(f'Iteration: {k}')

        gradQ, gradR, gamma_k = gd.compute_grad_A_LR(
            C_factors, A_factors, B_factors, Q, R, Lambda, gamma, device,
            alpha=alpha, dtype=dtype, full_grad=full_grad
        )

        R, dual_1R, dual_2R = util.logSinkhorn(
            gradR - (gamma_k**-1)*torch.log(R), b, gR, gamma_k,
            max_iter=max_inneriters_relaxed,
            device=device, dtype=dtype, balanced=False, unbalanced=False,
            tau=tau_in, dual_1=dual_1R, dual_2=dual_2R
        )

        Q, dual_1Q, dual_2Q = util.logSinkhorn(
            gradQ - (gamma_k**-1)*torch.log(Q), a, gQ, gamma_k,
            max_iter=max_inneriters_relaxed,
            device=device, dtype=dtype, balanced=False, unbalanced=False,
            tau=tau_in, dual_1=dual_1Q, dual_2=dual_2Q
        )

        # Inner marginals are read off the freshly updated factors.
        gQ, gR = Q.T @ one_N1, R.T @ one_N2

        gradT, gamma_T = gd.compute_grad_B_LR(
            C_factors, A_factors, B_factors, Q, R, Lambda, gQ, gR, gamma, device,
            alpha=alpha, dtype=dtype
        )

        T, dual_1T, dual_2T = util.logSinkhorn(
            gradT - (gamma_T**-1)*torch.log(T), gQ, gR, gamma_T,
            max_iter=max_inneriters_balanced,
            device=device, dtype=dtype, balanced=True, unbalanced=False,
            dual_1=dual_1T, dual_2=dual_2T
        )

        # Inner latent transition-inverse matrix
        Lambda = torch.diag(1/gQ) @ T @ torch.diag(1/gR)
        k += 1

        if printCost:
            # Wasserstein term evaluated through the factors of C.
            primal_cost = torch.trace(((Q.T @ C_factors[0]) @ (C_factors[1] @ R)) @ Lambda.T)

            # Cross term of the GW cost.
            X = R @ ((Lambda.T @ ((Q.T @ A_factors[0]) @ (A_factors[1] @ Q)) @ Lambda) @ (R.T @ B_factors[0])) @ B_factors[1]
            GW_cost = -2 * torch.trace(X)
            del X  # release the intermediate before building the next ones

            # Quadratic terms, built from factors of the element-wise squared costs.
            A1_tild, A2_tild = util.hadamard_square_lr(A_factors[0], A_factors[1].T, device=device)
            GW_cost += torch.inner(A1_tild.T @ (Q @ one_r), A2_tild.T @ (Q @ one_r))
            del A1_tild, A2_tild

            B1_tild, B2_tild = util.hadamard_square_lr(B_factors[0], B_factors[1].T, device=device)
            GW_cost += torch.inner(B1_tild.T @ (R @ one_r2), B2_tild.T @ (R @ one_r2))
            del B1_tild, B2_tild

            errs['W_cost'].append(primal_cost.detach().cpu())
            errs['GW_cost'].append(GW_cost.detach().cpu())
            errs['total_cost'].append(((1-alpha)*primal_cost + alpha*GW_cost).detach().cpu())

    if printCost:
        # Nothing was recorded if the loop never ran (e.g. max_iter == 0).
        if errs['total_cost']:
            print(f"Initial Wasserstein cost: {errs['W_cost'][0]}, GW-cost: {errs['GW_cost'][0]}, Total cost: {errs['total_cost'][0]}")
            print(f"Final Wasserstein cost: {errs['W_cost'][-1]}, GW-cost: {errs['GW_cost'][-1]}, Total cost: {errs['total_cost'][-1]}")
        plt.plot(errs['total_cost']); plt.show()

    if returnFull:
        P = Q @ Lambda @ R.T
        return P, errs

    if diagonalize_return:
        # Fold T into Q so the returned inner matrix is diag(gR).
        Q = Q @ torch.diag(1 / gQ) @ T
        gR = R.T @ one_N2
        T = torch.diag(gR)
    return Q, R, T, errs

def stabilize_Q_init(Q, rand_perturb = False, 
                     lambda_factor = 0.1, max_inneriters_balanced= 300, 
                     device='cpu', dtype=torch.float64):
    '''
    Blend a 2-D coupling with the outer product of its own marginals.

    Returns (1 - lambda_factor) * Q + lambda_factor * (row_sums x col_sums),
    which fills in the zero entries of a one-hot style Q when both marginals
    are non-zero.
    The outer-product term is moved to `device` and cast to `dtype`; Q itself
    is used as given. `rand_perturb` and `max_inneriters_balanced` are accepted
    but not used.
    '''
    # Shape access kept so non-2-D input is rejected up front.
    n_rows, n_cols = Q.shape[0], Q.shape[1]
    row_mass = torch.sum(Q, axis = 1)
    col_mass = torch.sum(Q, axis = 0)

    # Independent (product) coupling carrying the same marginals as Q.
    product_coupling = torch.outer(row_mass, col_mass).to(device).type(dtype)

    keep_weight = 1 - lambda_factor
    return keep_weight * Q + lambda_factor * product_coupling

'''
def FRLC_LR_opt(C_factors, 
                A_factors, 
                B_factors, 
                a=None, 
                b=None, 
                tau_in = 50, 
                tau_out=50, 
                gamma=90, 
                r = 10, 
                r2=None, 
                max_iter=200, 
                device='cpu', 
                dtype=torch.float64,
                printCost=False, 
                returnFull=False, 
                alpha=0.0, 
                initialization='Full', 
                init_args = None, 
                full_grad=True, 
                convergence_criterion=True, 
                tol=5e-6, 
                min_iter = 25, 
                max_inneriters_balanced= 300, 
                max_inneriters_relaxed=50, 
                diagonalize_return=False):
    
    N1, N2 = C_factors[0].size(dim=0), C_factors[1].size(dim=1)
    k = 0
    stationarity_gap = torch.inf
    
    one_N1 = torch.ones((N1), device=device, dtype=dtype)
    one_N2 = torch.ones((N2), device=device, dtype=dtype)
    
    if a is None:
        a = one_N1 / N1
    if b is None:
        b = one_N2 / N2
    if r2 is None:
        r2 = r

    one_r = torch.ones((r), device=device, dtype=dtype)
    one_r2 = torch.ones((r2), device=device, dtype=dtype)
    
    # Initialize inner marginals to uniform; 
    # generalized to be of differing dimensions to account for non-square latent-coupling.
    gQ = (1/r)*one_r
    gR = (1/r2)*one_r2
    
    full_rank = True if initialization == 'Full' else False
    
    if initialization == 'Full':
        full_rank = True
    elif initialization == 'Rank-2':
        full_rank = False
    else:
        full_rank = True
        print('Initialization must be either "Full" or "Rank-2", defaulting to "Full".')
        
    if init_args is None:
        Q, R, T, Lambda = util.initialize_couplings(a, b, gQ, gR, \
                                                    gamma, full_rank=full_rank, \
                                                device=device, dtype=dtype, \
                                                    max_iter = max_inneriters_balanced)
        Q, R, T = stabilize_Q_init(Q, device=device, dtype=dtype), stabilize_Q_init(R, device=device, dtype=dtype), stabilize_Q_init(T, device=device, dtype=dtype)
        Lambda = torch.diag(1/ (Q.T @ one_N1)) @ T @ torch.diag(1/ (R.T @ one_N2))
    else:
        # Initialize to given factors
        Q, R, T = init_args
        if Q is not None:
            gQ = Q.T @ one_N1
        if R is not None:
            gR =  R.T @ one_N2
        if Q is None or R is None or T is None:
            _Q, _R, _T, Lambda = util.initialize_couplings(a, b, gQ, gR, \
                                                    gamma, full_rank=full_rank, \
                                                device=device, dtype=dtype, \
                                                    max_iter = max_inneriters_balanced)
            _Q, _R, _T = stabilize_Q_init(_Q, device=device, dtype=dtype), stabilize_Q_init(_R, device=device, dtype=dtype), stabilize_Q_init(_T, device=device, dtype=dtype)
            if Q is None:
                Q = _Q
            if R is None:
                R = _R
            if T is None:
                T = _T
        
        Lambda = torch.diag(1/ (Q.T @ one_N1)) @ T @ torch.diag(1/ (R.T @ one_N2))
    
    errs = {'total_cost':[], 'W_cost':[], 'GW_cost': []}
    grad = torch.inf
    gamma_k = gamma
    Q_prev, R_prev, T_prev = None, None, None

    # Initialize duals for warm-start across iterations
    dual_1Q, dual_2Q = None, None
    dual_1R, dual_2R = None, None
    dual_1T, dual_2T = None, None
    
    while (k < max_iter and (not convergence_criterion or \
                       (k < min_iter or util.Delta((Q, R, T), (Q_prev, R_prev, T_prev), gamma_k) > tol))):
        
        if convergence_criterion:
            # Set previous iterates to evaluate convergence at the next round
            Q_prev, R_prev, T_prev = Q, R, T
        
        if k % 25 == 0:
            print(f'Iteration: {k}')
        
        gradQ, gradR, gamma_k = gd.compute_grad_A_LR(C_factors, A_factors, B_factors, Q, R, Lambda, gamma, device, \
                                   alpha=alpha, dtype=dtype, full_grad=full_grad)
        
        ### Semi-relaxed updates ###
        R, dual_1R, dual_2R = util.logSinkhorn(gradR - (gamma_k**-1)*torch.log(R), b, gR, gamma_k, max_iter = max_inneriters_relaxed, \
                         device=device, dtype=dtype, balanced=False, unbalanced=False, tau=tau_in, dual_1 = dual_1R, dual_2 = dual_2R)
        
        Q, dual_1Q, dual_2Q = util.logSinkhorn(gradQ - (gamma_k**-1)*torch.log(Q), a, gQ, gamma_k, max_iter = max_inneriters_relaxed, \
                         device=device, dtype=dtype, balanced=False, unbalanced=False, tau=tau_in, dual_1 = dual_1Q, dual_2 = dual_2Q)
        
        gQ, gR = Q.T @ one_N1, R.T @ one_N2
        
        gradT, gamma_T = gd.compute_grad_B_LR(C_factors, A_factors, B_factors, Q, R, Lambda, gQ, gR, gamma, device, \
                                       alpha=alpha, dtype=dtype)
        
        T, dual_1T, dual_2T = util.logSinkhorn(gradT - (gamma_T**-1)*torch.log(T), gQ, gR, gamma_T, max_iter = max_inneriters_balanced, \
                         device=device, dtype=dtype, balanced=True, unbalanced=False, dual_1 = dual_1T, dual_2 = dual_2T)
        
        # Inner latent transition-inverse matrix
        Lambda = torch.diag(1/gQ) @ T @ torch.diag(1/gR)
        k+=1
        
        if printCost:
            primal_cost = torch.trace(((Q.T @ C_factors[0]) @ (C_factors[1] @ R)) @ Lambda.T)
            X = R @ ((Lambda.T @ ((Q.T @ A_factors[0]) @ (A_factors[1] @ Q)) @ Lambda) @ (R.T @ B_factors[0])) @ B_factors[1]
            GW_cost = - 2 * torch.trace(X) # add these: one_r.T @ M1 @ one_r + one_r.T @ M2 @ one_r
            del X
            A1_tild, A2_tild = util.hadamard_square_lr(A_factors[0], A_factors[1].T, device=device)
            GW_cost += torch.inner(A1_tild.T @ (Q @ one_r), A2_tild.T @ (Q @ one_r))
            del A1_tild, A2_tild
            B1_tild, B2_tild = util.hadamard_square_lr(B_factors[0], B_factors[1].T, device=device)
            GW_cost += torch.inner(B1_tild.T @ (R @ one_r2), B2_tild.T @ (R @ one_r2))
            del B1_tild, B2_tild
            
            errs['W_cost'].append(primal_cost.cpu())
            errs['GW_cost'].append((GW_cost).cpu())
            errs['total_cost'].append(((1-alpha)*primal_cost + alpha*GW_cost).cpu())
        
    if printCost:
        print(f"Initial Wasserstein cost: {errs['W_cost'][0]}, GW-cost: {errs['GW_cost'][0]}, Total cost: {errs['total_cost'][0]}")
        print(f"Final Wasserstein cost: {errs['W_cost'][-1]}, GW-cost: {errs['GW_cost'][-1]}, Total cost: {errs['total_cost'][-1]}")
        plt.plot(errs['total_cost'])
        plt.show()
    
    if returnFull:
        P = Q @ Lambda @ R.T
        return P, errs
    else:
        if diagonalize_return:
            Q = Q @ torch.diag(1 / gQ) @ T
            gR = R.T @ one_N2
            T = torch.diag(gR)
        return Q, R, T, errs

'''
