from torch.optim.optimizer import Optimizer
import torch
import math

def solve_sylvester(A, C):
    """Solve the Sylvester-type equation ``A @ K + K @ A = C`` for ``K``.

    The solve diagonalises ``A`` with ``torch.linalg.eig`` (this is an
    eigendecomposition-based solve, not a Schur-based Bartels-Stewart
    iteration).  With ``A = U diag(w) U^-1`` the equation becomes
    ``Y[i, j] * (w[i] + w[j]) = (U^-1 C U)[i, j]``, which is solved
    element-wise and transformed back.

    Args:
        A: square matrix ``(..., k, k)``.  It must be diagonalisable; any
            eigenvalue pair with ``w[i] + w[j] == 0`` produces a division by
            zero (inf/nan entries in the result).
        C: right-hand side ``(..., k, k)`` with a dtype compatible with ``A``.

    Returns:
        ``K`` with the same trailing shape as ``C``.  A real tensor is
        returned when every entry of ``A`` and ``C`` is real; otherwise the
        complex solution is returned.
    """
    eigvals, U = torch.linalg.eig(A)
    # Right-hand side expressed in the eigenbasis of A: F = U^-1 C U.
    F = torch.linalg.solve(U, (C + 0j) @ U)
    # Pairwise eigenvalue sums w[i] + w[j]; broadcasting keeps batch dims.
    W = eigvals[..., :, None] + eigvals[..., None, :]
    Y = F / W
    K = U @ Y @ torch.linalg.inv(U)
    # Only drop the imaginary part when *all* inputs are real.  Looking at a
    # single element would silently truncate genuinely complex solutions.
    inputs_are_real = all(
        (not t.is_complex()) or bool(torch.isreal(t).all()) for t in (A, C)
    )
    if inputs_are_real:
        K = K.real
    return K

class Lift_Adam(Optimizer):
    """Adam-style optimizer for paired low-rank factors with a lifted update.

    Parameters of every group are consumed in consecutive pairs ``(p1, p2)``.
    ``p1`` is treated as a stacked ``(2r, m)`` factor and ``p2`` as a stacked
    ``(2n, r)`` factor; the first/second halves of the two tensors form two
    independent ``(M, N)`` factor pairs (``M`` is ``n x r``, ``N`` is
    ``m x r``) that are updated separately and written back together.

    Each pair first runs a warm-up phase ("robust initialization") of
    ``robust_init + 1`` plain Adam steps applied to the raw factors.  After
    that, every step builds an ``n x m`` gradient from the factor gradients
    via thin QR factorisations, preconditions it with Adam moments, and maps
    the result back onto the two factors through a Sylvester solve.

    Args:
        params: iterable of parameters (or param groups), ordered so that
            consecutive entries form ``(p1, p2)`` pairs.
        lr: learning rate.
        weight_decay: stored in the group defaults but never read by
            ``step``; the warm-up phase applies a fixed ``0.01`` decay.
        rank: stored in the defaults and on ``self.rank``; not read by ``step``.
        reg_k: if positive, ``lr * reg_k * I`` is added to the Sylvester
            solution before the factors are updated.
        Rie_inner: if true, the second moment is scaled by the ratio between
            the norm of the lifted gradient and the Frobenius norm of the
            assembled gradient.
        betas: Adam ``(beta1, beta2)`` coefficients.
        eps: added to the second-moment denominator; ``1 / eps`` also caps
            the norm of the inverse Gram matrices.
        correct_bias: apply Adam bias correction.
        debug_mode: stored in the defaults; not read by ``step``.
        low_memory: keep the first moment as three small factors
            (``r x r``, ``n x r``, ``r x m``) instead of one ``n x m`` tensor.
        robust_init: length parameter of the warm-up phase (see above).
    """

    def __init__(self, params, lr, weight_decay=0.0, rank=4, reg_k=0.0,
                 Rie_inner=False, betas=(0.9, 0.999), eps=1e-6,
                 correct_bias=True, debug_mode=False, low_memory=True,
                 robust_init=100):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            reg_k=reg_k,
            rank=rank,
            low_memory=low_memory,
            Rie_inner=Rie_inner,
            betas=betas,
            correct_bias=correct_bias,
            eps=eps,
            robust_init=robust_init,
            debug_mode=debug_mode,
        )
        super().__init__(params, defaults)
        self.rank = rank

    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            params = group["params"]
            for p1, p2 in zip(params[0::2], params[1::2]):
                # A pair can only be updated when both factors have gradients.
                if p1.grad is None or p2.grad is None:
                    continue
                state1, state2 = self.state[p1], self.state[p2]
                # Always create the counters so that robust_init < 1 does not
                # fail with a KeyError on the first step.
                state1.setdefault("step", 0)
                state2.setdefault("step", 0)
                if state1["step"] >= group["robust_init"] + 1:
                    self._lifted_step(group, p1, p2)
                else:
                    for p in (p1, p2):
                        self._warmup_step(group, p)
        return loss

    def _warmup_step(self, group, p):
        """Apply one plain Adam step (plus a fixed decay) to a single factor."""
        state = self.state[p]
        if state["step"] == 0:
            print("Robust Initialization Start")
            state["exp_avg_0"] = torch.zeros_like(p.data)
            state["exp_avg_sq_0"] = torch.zeros_like(p.data)
        exp_avg, exp_avg_sq = state["exp_avg_0"], state["exp_avg_sq_0"]
        beta1, beta2 = group["betas"]
        state["step"] += 1
        exp_avg.mul_(beta1).add_(p.grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(p.grad, p.grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])
        step_size = group["lr"]
        if group.get("correct_bias"):  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
        p.data.addcdiv_(exp_avg, denom, value=-step_size)
        # NOTE(review): the decay factor is hard-coded to 0.01 here;
        # group["weight_decay"] is not consulted.
        p.data.add_(p.data, alpha=-group["lr"] * 0.01)

    def _lifted_step(self, group, p1, p2):
        """Update both halves of a ``(p1, p2)`` pair with the lifted rule."""
        r = p2.data.shape[1]
        n = p2.data.shape[0] // 2
        M = p2.data.clone()      # (2n, r)
        N = p1.data.clone().t()  # (m, 2r)
        lift_M1, lift_N1 = self._half_update(
            group, self.state[p1], M[:n], N[:, :r], p1.grad[:r], p2.grad[:n])
        lift_M2, lift_N2 = self._half_update(
            group, self.state[p2], M[n:], N[:, r:], p1.grad[r:], p2.grad[n:])
        # Both halves are computed from the pre-update factors and applied
        # together.
        p1.data.add_(torch.cat([lift_N1.T, lift_N2.T]))  # (2r, m)
        p2.data.add_(torch.cat([lift_M1, lift_M2]))      # (2n, r)

    def _qr_pinv(self, X, eps):
        """Return ``(Q, pinv(R), pinv(R) @ pinv(R).T)`` for the thin QR of ``X``.

        The last item plays the role of ``(X.T @ X)^-1``; its norm is capped
        at ``1 / eps`` to keep ill-conditioned factors from blowing up.
        """
        try:
            Q, R = torch.linalg.qr(X)
            R_pinv = torch.linalg.pinv(R)
            gram_inv = R_pinv @ R_pinv.T
            gram_norm = torch.norm(gram_inv)
            if gram_norm > eps ** -1:
                gram_inv.div_(gram_norm * eps)
        except Exception as err:
            raise RuntimeError("pinv or qr cannot be computed.") from err
        return Q, R_pinv, gram_inv

    def _half_update(self, group, state, M, N, grad_p1, grad_p2):
        """Compute the factor increments for one ``(M, N)`` half.

        Args:
            group: the parameter group (hyper-parameters).
            state: optimizer state of this half (``self.state[p1]`` for the
                first half, ``self.state[p2]`` for the second).
            M: ``n x r`` factor.
            N: ``m x r`` factor.
            grad_p1: ``r x m`` slice of ``p1.grad`` belonging to this half.
            grad_p2: ``n x r`` slice of ``p2.grad`` belonging to this half.

        Returns:
            ``(lift_M, lift_N)`` with shapes ``n x r`` and ``m x r``.
        """
        n, r = M.shape
        m = N.shape[0]
        beta1, beta2 = group["betas"]
        warmup = group["robust_init"]
        low_memory = group["low_memory"]
        buf = dict(dtype=M.dtype, device=M.device)

        # Infinite gradient entries are mapped to NaN (NaN stays NaN).
        grad_p1 = torch.nan_to_num(grad_p1, nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
        grad_p2 = torch.nan_to_num(grad_p2, nan=torch.nan, posinf=torch.nan, neginf=torch.nan)

        MtM = M.T @ M
        NtN = N.T @ N
        A = MtM @ NtN

        # First lifted step after the warm-up: allocate fresh moments.
        if state["step"] == warmup + 1:
            if low_memory:
                state["exp_avg_G1"] = torch.zeros(r, r, **buf)
                state["exp_avg_G2"] = torch.zeros(n, r, **buf)
                state["exp_avg_G3"] = torch.zeros(r, m, **buf)
            else:
                state["exp_avg"] = torch.zeros(n, m, **buf)
            state["exp_avg_sq"] = torch.zeros(n, m, **buf)

        U_N, R_Np, inv_N = self._qr_pinv(N, group["eps"])
        U_M, R_Mp, inv_M = self._qr_pinv(M, group["eps"])

        # Assemble the n x m gradient from the factor gradients, using
        # p1.grad = M.T @ grad_W and p2.grad = grad_W @ N.
        grad_p2_R = grad_p2 @ R_Np
        Rie_M = U_M @ R_Mp.T @ grad_p1
        Rie_grad = Rie_M + grad_p2_R @ U_N.T - Rie_M @ U_N @ U_N.T
        state["step"] += 1

        # First moment.
        if low_memory:
            R_grad_p1 = R_Mp.T @ grad_p1
            G1 = U_M.T @ Rie_grad @ U_N                        # r x r
            G2 = grad_p2_R - U_M @ (U_M.T @ grad_p2_R)         # n x r
            G3 = R_grad_p1 - R_grad_p1 @ U_N @ U_N.T           # r x m
            avg_G1 = state["exp_avg_G1"].mul_(beta1).add_(G1, alpha=1.0 - beta1)
            avg_G2 = state["exp_avg_G2"].mul_(beta1).add_(G2, alpha=1.0 - beta1)
            avg_G3 = state["exp_avg_G3"].mul_(beta1).add_(G3, alpha=1.0 - beta1)
            exp_avg = U_M @ avg_G1 @ U_N.T + avg_G2 @ U_N.T + U_M @ avg_G3
        else:
            exp_avg = state["exp_avg"]
            exp_avg.mul_(beta1).add_(Rie_grad, alpha=1.0 - beta1)

        # Second moment (identical for both memory modes).
        exp_avg_sq = state["exp_avg_sq"]
        if group["Rie_inner"]:
            C = M.T @ Rie_grad @ N
            K = solve_sylvester(A, C).clone().detach()
            lift_Rie_M = Rie_grad @ N @ inv_N - M @ NtN @ K @ inv_N
            lift_Rie_N = Rie_grad.T @ M @ inv_M - N @ MtM @ K.T @ inv_M
            Rie_inner_norm = torch.trace(
                inv_M @ lift_Rie_M.T @ lift_Rie_M + inv_N @ lift_Rie_N.T @ lift_Rie_N
            ).sqrt()
            # Frobenius norm; avoids materialising the m x m product G.T @ G.
            inner_norm = torch.norm(Rie_grad)
            scaled_grad = (Rie_inner_norm / inner_norm) * Rie_grad
        else:
            scaled_grad = Rie_grad
        exp_avg_sq.mul_(beta2).addcmul_(Rie_grad, scaled_grad, value=1.0 - beta2)

        denom = exp_avg_sq.sqrt().add_(group["eps"])
        step_size = group["lr"]
        if group.get("correct_bias"):  # No bias correction for Bert
            # NOTE(review): the exponent is step - robust_init, which is 2 on
            # the first lifted step; kept as in the original implementation.
            t = state["step"] - warmup
            bias_correction1 = 1.0 - beta1 ** t
            bias_correction2 = 1.0 - beta2 ** t
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
        lifted_vec = exp_avg.div(denom).mul_(-step_size)

        # Map the n x m step back onto the factors via a Sylvester solve.
        C = M.T @ lifted_vec @ N
        K = solve_sylvester(A, C)
        if group["reg_k"] > 0.0:
            K.add_(torch.eye(r, dtype=K.dtype, device=K.device),
                   alpha=group["lr"] * group["reg_k"])
        lift_M = lifted_vec @ N @ inv_N - M @ NtN @ K @ inv_N
        lift_N = lifted_vec.T @ M @ inv_M - N @ MtM @ K.T @ inv_M
        return lift_M, lift_N

class Lift_RACS(Optimizer):
    """Row/column-scaled optimizer for paired low-rank factors with a lifted update.

    Parameters of every group are consumed in consecutive pairs ``(p1, p2)``.
    ``p1`` is treated as a stacked ``(2r, m)`` factor and ``p2`` as a stacked
    ``(2n, r)`` factor; the first/second halves of the two tensors form two
    independent ``(M, N)`` factor pairs (``M`` is ``n x r``, ``N`` is
    ``m x r``) that are updated separately and written back together.

    Each pair first runs a warm-up phase of ``robust_init + 1`` AdamW-style
    steps on the raw factors.  Afterwards every step assembles an ``n x m``
    gradient from the factor gradients, rescales its rows and columns with
    running estimates ``exp_q`` / ``exp_s`` obtained from a short fixed-point
    iteration, maps the step back onto the factors through a Sylvester solve
    and limits the growth of the step norm relative to the previous step.

    Args:
        params: iterable of parameters (or param groups), ordered so that
            consecutive entries form ``(p1, p2)`` pairs.
        lr: learning rate.
        weight_decay: decoupled weight decay, applied during warm-up only.
        init_q: initial value of the row-scaling vector in the fixed-point
            iteration.
        Rie_inner: if true, the step norm is measured with the inverse Gram
            matrices of the factors; otherwise the Frobenius norm is used.
        rank: stored in the group defaults; not read by ``step``.
        betas: ``betas[0]`` drives all running averages; ``betas[1]`` is only
            used by the warm-up second moment.
        eps: added to the warm-up second-moment denominator.
        correct_bias: apply Adam bias correction during warm-up.
        reg: stored on ``self.reg`` and printed; not read by ``step``.
        robust_init: length parameter of the warm-up phase (see above).
    """

    # Maximum allowed growth factor of the step norm between two steps.
    _GAMMA = 1.01
    # Fixed multiplier applied to the learning rate in the lifted phase.
    _SCALE_ALPHA = 0.05
    # Number of fixed-point iterations used to estimate the scalings.
    _NUM_SCALING_ITERS = 5

    def __init__(self, params, lr, weight_decay=0.01, init_q=1, Rie_inner=True,
                 rank=4, betas=(0.9, 0.98), eps=1e-6, correct_bias=True,
                 reg=0, robust_init=100):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            init_q=init_q,
            rank=rank,
            Rie_inner=Rie_inner,
            betas=betas,
            correct_bias=correct_bias,
            eps=eps,
            robust_init=robust_init,
        )
        super().__init__(params, defaults)
        self.reg = reg
        print(f'{self.reg=}')

    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            params = group["params"]
            for p1, p2 in zip(params[0::2], params[1::2]):
                # A pair can only be updated when both factors have gradients.
                if p1.grad is None or p2.grad is None:
                    continue
                state1, state2 = self.state[p1], self.state[p2]
                # Always create the counters so that robust_init < 1 does not
                # fail with a KeyError on the first step.
                state1.setdefault("step", 0)
                state2.setdefault("step", 0)
                if state1["step"] >= group["robust_init"] + 1:
                    self._lifted_step(group, p1, p2)
                else:
                    n = p2.data.shape[0] // 2
                    m = p1.data.shape[1]
                    for p in (p1, p2):
                        self._warmup_step(group, p, n, m)
        return loss

    def _warmup_step(self, group, p, n, m):
        """Apply one AdamW-style step to a single factor during warm-up."""
        state = self.state[p]
        if state["step"] == 0:
            state["exp_avg"] = torch.zeros_like(p.data)
            state["exp_avg_sq"] = torch.zeros_like(p.data)
            # exp_q / exp_s are re-created when the lifted phase starts and
            # phi is overwritten by the first lifted step; they are kept here
            # so the state layout matches the original implementation.
            state["exp_q"] = torch.zeros(n, dtype=p.dtype, device=p.device)
            state["exp_s"] = torch.zeros(m, dtype=p.dtype, device=p.device)
            state["phi"] = 1
        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        beta1, beta2 = group["betas"]
        state["step"] += 1
        exp_avg.mul_(beta1).add_(p.grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(p.grad, p.grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])
        step_size = group["lr"]
        if group.get("correct_bias"):  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
        p.data.addcdiv_(exp_avg, denom, value=-step_size)
        if group["weight_decay"] > 0.0:
            p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

    def _lifted_step(self, group, p1, p2):
        """Update both halves of a ``(p1, p2)`` pair with the lifted rule."""
        r = p2.data.shape[1]
        n = p2.data.shape[0] // 2
        M = p2.data.clone()      # (2n, r)
        N = p1.data.clone().t()  # (m, 2r)
        lift_M1, lift_N1 = self._half_update(
            group, self.state[p1], M[:n], N[:, :r], p1.grad[:r], p2.grad[:n])
        lift_M2, lift_N2 = self._half_update(
            group, self.state[p2], M[n:], N[:, r:], p1.grad[r:], p2.grad[n:])
        # Both halves are computed from the pre-update factors and applied
        # together.
        p1.data.add_(torch.cat([lift_N1.T, lift_N2.T]))  # (2r, m)
        p2.data.add_(torch.cat([lift_M1, lift_M2]))      # (2n, r)

    def _qr_pinv(self, X):
        """Return ``(Q, pinv(R), pinv(R) @ pinv(R).T)`` for the thin QR of ``X``."""
        try:
            Q, R = torch.linalg.qr(X)
            R_pinv = torch.linalg.pinv(R)
            gram_inv = R_pinv @ R_pinv.T
        except Exception as err:
            raise RuntimeError("pinv or qr cannot be computed.") from err
        return Q, R_pinv, gram_inv

    def _half_update(self, group, state, M, N, grad_p1, grad_p2):
        """Compute the factor increments for one ``(M, N)`` half.

        Args:
            group: the parameter group (hyper-parameters).
            state: optimizer state of this half (``self.state[p1]`` for the
                first half, ``self.state[p2]`` for the second).
            M: ``n x r`` factor.
            N: ``m x r`` factor.
            grad_p1: ``r x m`` slice of ``p1.grad`` belonging to this half.
            grad_p2: ``n x r`` slice of ``p2.grad`` belonging to this half.

        Returns:
            ``(lift_M, lift_N)`` with shapes ``n x r`` and ``m x r``.
        """
        n = M.shape[0]
        m = N.shape[0]
        beta1, _ = group["betas"]
        gamma = self._GAMMA
        buf = dict(dtype=M.dtype, device=M.device)

        # Infinite gradient entries are mapped to NaN (NaN stays NaN).
        grad_p1 = torch.nan_to_num(grad_p1, nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
        grad_p2 = torch.nan_to_num(grad_p2, nan=torch.nan, posinf=torch.nan, neginf=torch.nan)

        # First lifted step after the warm-up: start the running scalings
        # from zero.  The step counter is deliberately NOT reset here;
        # resetting it would send the next call back to the warm-up branch.
        first_lifted_step = state["step"] == group["robust_init"] + 1
        if first_lifted_step:
            state["exp_q"] = torch.zeros(n, **buf)
            state["exp_s"] = torch.zeros(m, **buf)

        U_N, R_Np, inv_N = self._qr_pinv(N)
        U_M, R_Mp, inv_M = self._qr_pinv(M)

        # Assemble the n x m gradient from the factor gradients.
        Rie_M = U_M @ R_Mp.T @ grad_p1
        Rie_grad = Rie_M + grad_p2 @ R_Np @ U_N.T - Rie_M @ U_N @ U_N.T
        exp_q, exp_s = state["exp_q"], state["exp_s"]
        state["step"] += 1

        # Alternating fixed-point iteration for the row (q) and column (s)
        # scalings: s = diag(G.T diag(q) G) / (q.q), q = diag(G diag(s) G.T) / (s.s).
        q = torch.ones(n, **buf) * group["init_q"]
        for _ in range(self._NUM_SCALING_ITERS):
            s = torch.einsum('ij,ji->i', torch.einsum('ij,j->ij', Rie_grad.T, q), Rie_grad)
            s.div_(q @ q)
            q = torch.einsum('ij,ji->i', torch.einsum('ij,j->ij', Rie_grad, s), Rie_grad.T)
            q.div_(s @ s)
        exp_q.mul_(beta1).add_(q, alpha=1.0 - beta1)
        exp_s.mul_(beta1).add_(s, alpha=1.0 - beta1)

        # G_ = diag(exp_q)^-1/2 @ Rie_grad @ diag(exp_s)^-1/2
        G_ = torch.einsum('i,ij->ij', exp_q.sqrt() ** -1,
                          torch.einsum('ij,j->ij', Rie_grad, exp_s.sqrt() ** -1))
        lifted_vec = G_.mul_(-group["lr"] * self._SCALE_ALPHA)

        # Map the n x m step back onto the factors via a Sylvester solve.
        MtM = M.T @ M
        NtN = N.T @ N
        A = MtM @ NtN
        C = M.T @ lifted_vec @ N
        K = solve_sylvester(A, C)
        lift_M = lifted_vec @ N @ inv_N - M @ NtN @ K @ inv_N
        lift_N = lifted_vec.T @ M @ inv_M - N @ MtM @ K.T @ inv_M

        if group["Rie_inner"]:
            norm_G = torch.trace(inv_M @ lift_M.T @ lift_M + inv_N @ lift_N.T @ lift_N).sqrt()
        else:
            # Frobenius norm; avoids materialising the m x m product.
            norm_G = torch.norm(lifted_vec)

        # Limit the growth of the step norm to a factor gamma relative to the
        # previous step; the first lifted step is taken unclipped.
        if first_lifted_step:
            eta = 1
        else:
            eta = gamma / max(gamma, norm_G / state["phi"])
        state["phi"] = norm_G * eta
        lift_M.mul_(eta)
        lift_N.mul_(eta)
        return lift_M, lift_N
