from torch.optim.optimizer import Optimizer, required
import torch
import math
import copy
import time

def solve_sylvester(A, C):
    """Solve the symmetric Sylvester equation ``A @ K + K @ A = C`` for ``K``.

    ``A`` is diagonalized as ``A = U diag(R) U^{-1}`` with ``torch.linalg.eig``.
    In that eigenbasis the equation decouples element-wise into
    ``(R_i + R_j) * Y_ij = F_ij`` with ``F = U^{-1} C U``, and ``K = U Y U^{-1}``.
    This is the eigendecomposition variant of the Bartels–Stewart idea; the
    classical algorithm uses a Schur decomposition instead.

    Args:
        A: square matrix, or a batch of square matrices in the trailing two
            dims. It must be diagonalizable with no pair of eigenvalues summing
            to zero; otherwise the element-wise division yields inf/nan.
        C: right-hand side with the same trailing shape as ``A``.

    Returns:
        ``K``. It is real (``.real`` of the complex solve) when every entry of
        ``A`` and ``C`` is real; otherwise it is the complex solution.
    """
    eigvals, U = torch.linalg.eig(A)
    # F = U^{-1} C U, computed with a solve instead of an explicit inverse.
    F = torch.linalg.solve(U, (C + 0j) @ U)
    # Denominators R_i + R_j, broadcast over any batch dimensions.
    W = eigvals[..., :, None] + eigvals[..., None, :]
    Y = F / W
    K = U @ Y @ torch.linalg.inv(U)
    # Drop the round-off imaginary part only when the whole problem is real.
    # Every entry is inspected, not just the first one.
    if bool(torch.isreal(A).all()) and bool(torch.isreal(C).all()):
        K = K.real
    return K

class Lift_Adam(Optimizer):
    """Adam-style optimizer for low-rank factorized weights ``W = M @ N.T``.

    ``params`` is consumed in consecutive pairs ``(p1, p2)`` where ``p2 = M``
    has shape (n, r) and ``p1 = N.T`` has shape (r, m). A trailing unpaired
    parameter is ignored. For each pair the optimizer does two things:

    1. It runs ``robust_init + 1`` plain Adam warm-up steps on ``p1`` and
       ``p2`` separately. Each step is followed by a fixed decoupled decay of
       ``lr * 0.01``.
    2. It then switches to lifted steps. The factor gradients are combined
       into ``Rie_grad = P_U G + G P_V - P_U G P_V``, where ``P_U`` / ``P_V``
       project onto the column spaces of ``M`` / ``N`` and ``G`` is the
       gradient w.r.t. ``W``. This assumes the factor gradients come from
       ``W = M @ N.T``. ``Rie_grad`` is preconditioned with Adam moments and
       mapped back onto ``(M, N)`` through a Sylvester solve for the gauge
       term ``K``.

    Args:
        params: parameters ordered ``N.T, M, N.T, M, ...``.
        lr: learning rate.
        weight_decay: accepted for API compatibility; stored as 0.0 and unused.
        reg_k: if > 0, ``lr * reg_k * I`` is added to ``K`` before lifting.
        Rie_inner: scale the second-moment increment by the ratio of the lifted
            norm of ``Rie_grad`` to its Frobenius norm.
        betas, eps, correct_bias: as in Adam.
        debug_mode: stored in the param group; not read by ``step``.
        low_memory: keep the first moment in factored form (r x r, n x r and
            r x m buffers) instead of a dense n x m buffer.
        robust_init: number of extra Adam warm-up steps before lifting.
    """

    # Decoupled decay (per unit of lr) applied after every warm-up Adam step.
    _ROBUST_INIT_DECAY = 0.01

    def __init__(self, params, lr, weight_decay, reg_k=0.0, Rie_inner=False, betas=(0.9, 0.999), eps=1e-6,
                 correct_bias=True, debug_mode=False, low_memory=True, robust_init=10):
        # weight_decay is accepted for API compatibility only: it is stored as
        # 0.0 and never applied by this optimizer.
        defaults = dict(lr=lr, weight_decay=0.0, reg_k=reg_k, low_memory=low_memory, Rie_inner=Rie_inner,
                        betas=betas, correct_bias=correct_bias, eps=eps, debug_mode=debug_mode,
                        robust_init=robust_init)
        super().__init__(params, defaults)

    def reset_state(self, p):
        """Drop the warm-up Adam buffers of ``p``. Missing buffers are ignored."""
        state = self.state[p]
        state.pop('exp_avg_0', None)
        state.pop('exp_avg_sq_0', None)

    @staticmethod
    def _qr_pinv(X):
        """Return ``(Q, R^+, R^+ @ R^+.T)`` for the reduced QR ``X = Q R``.

        ``R^+ @ R^+.T`` equals ``(X.T @ X)^{-1}`` when ``X`` has full column rank.

        Raises:
            RuntimeError: if the QR or the pseudo-inverse fails. The original
                error is chained as the cause.
        """
        try:
            Q, R = torch.linalg.qr(X)
            R_pinv = torch.linalg.pinv(R)
        except RuntimeError as err:
            raise RuntimeError("pinv or qr cannot be computed.") from err
        return Q, R_pinv, R_pinv @ R_pinv.T

    @staticmethod
    def _lift(vec, M, N, inv_M, inv_N, K):
        """Map an n x m direction ``vec`` to factor updates ``(dM, dN)``.

        ``K`` is the gauge term, normally the solution of
        ``A K + K A = M.T @ vec @ N``; the caller may have regularized it.
        """
        lift_M = vec @ N @ inv_N - M @ (N.T@N) @ K @ inv_N
        lift_N = vec.T @ M @ inv_M - N @ (M.T@M) @ K.T @ inv_M
        return lift_M, lift_N

    def _rie_inner_ratio(self, Rie_grad, M, N, A, inv_M, inv_N):
        """Return (lifted norm of Rie_grad) / (Frobenius norm of Rie_grad)."""
        C = M.T@Rie_grad@N
        K = solve_sylvester(A, C).clone().detach()
        lift_Rie_M, lift_Rie_N = self._lift(Rie_grad, M, N, inv_M, inv_N, K)
        Rie_inner_sq = torch.trace(inv_M @ lift_Rie_M.T @ lift_Rie_M + inv_N @ lift_Rie_N.T @ lift_Rie_N).sqrt()
        inner_sq = torch.trace(Rie_grad.T @ Rie_grad).sqrt()
        return Rie_inner_sq / inner_sq

    def _robust_adam_step(self, group, p):
        """Perform one plain Adam warm-up step on ``p``, then the fixed decay."""
        state = self.state[p]
        if state["step"] == 0:
            print("Robust Initialization Start")
            state["exp_avg_0"] = torch.zeros_like(p.data)
            state["exp_avg_sq_0"] = torch.zeros_like(p.data)
        exp_avg, exp_avg_sq = state["exp_avg_0"], state["exp_avg_sq_0"]
        beta1, beta2 = group["betas"]
        state["step"] += 1
        exp_avg.mul_(beta1).add_(p.grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(p.grad, p.grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])
        step_size = group["lr"]
        if group.get("correct_bias"):  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
        # Keyword form of addcdiv_; the (value, t1, t2) overload is deprecated.
        p.data.addcdiv_(exp_avg, denom, value=-step_size)
        p.data.add_(p.data, alpha=-group["lr"] * self._ROBUST_INIT_DECAY)

    def _lifted_step(self, group, p1, p2):
        """Perform one lifted Riemannian-Adam step on the pair ``(p1=N.T, p2=M)``."""
        state = self.state[p2]
        if 'exp_avg_0' in state:
            # First lifted step: the warm-up Adam buffers are no longer needed.
            self.reset_state(p2)
            self.reset_state(p1)
        M = p2.data.clone()       # n x r
        N = p1.data.clone().t()   # m x r
        r = p2.size()[1]
        n = p2.size()[0]
        m = p1.size()[1]
        # Map +/-inf gradient entries to NaN as well, so every non-finite entry is NaN.
        grad_p1 = torch.nan_to_num(p1.grad, nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
        grad_p2 = torch.nan_to_num(p2.grad, nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
        A = (M.T@M)@(N.T@N)
        if "exp_avg_sq" not in state:
            # Moments for the lifted phase start from zero.
            if group["low_memory"]:
                # Author's note: factored momentum gave a low loss (3.41) but a
                # worse generality score (32.2).
                state["exp_avg_G1"] = torch.zeros(r, r).to(p1.device)
                state["exp_avg_G2"] = torch.zeros(n, r).to(p1.device)
                state["exp_avg_G3"] = torch.zeros(r, m).to(p1.device)
            else:
                state["exp_avg"] = torch.zeros(n, m).to(p1.device)
            state["exp_avg_sq"] = torch.zeros(n, m).to(p1.device)

        U_N, R_Np, inv_N = self._qr_pinv(N)   # inv_N ~ (N^T N)^{-1}
        U_M, R_Mp, inv_M = self._qr_pinv(M)   # inv_M ~ (M^T M)^{-1}
        # p1.grad = M^T G = R_M^T U_M^T G and p2.grad = G N = G U_N R_N, hence
        # Rie_M = U_M U_M^T G and Rie_grad = P_U G + G P_V - P_U G P_V.
        Rie_M = U_M @ R_Mp.T @ grad_p1
        Rie_grad = Rie_M + grad_p2 @ R_Np @ U_N.T - Rie_M @ U_N @ U_N.T
        beta1, beta2 = group["betas"]
        state["step"] += 1

        # First moment.
        if group["low_memory"]:
            G2 = grad_p2 @ R_Np - U_M @ (U_M.T @ (grad_p2 @ R_Np))     # n x r
            G3 = R_Mp.T @ grad_p1 - R_Mp.T @ grad_p1 @ U_N @ U_N.T     # r x m
            G1 = U_M.T @ Rie_grad @ U_N                                # r x r
            exp_avg_G1, exp_avg_G2, exp_avg_G3 = state["exp_avg_G1"], state["exp_avg_G2"], state["exp_avg_G3"]
            exp_avg_G1.mul_(beta1).add_(G1, alpha=1.0 - beta1)
            exp_avg_G2.mul_(beta1).add_(G2, alpha=1.0 - beta1)
            exp_avg_G3.mul_(beta1).add_(G3, alpha=1.0 - beta1)
            # Dense n x m first moment, rebuilt from the factored buffers.
            exp_avg = U_M @ exp_avg_G1 @ U_N.T + exp_avg_G2 @ U_N.T + U_M @ exp_avg_G3
        else:
            exp_avg = state["exp_avg"]
            exp_avg.mul_(beta1).add_(Rie_grad, alpha=1.0 - beta1)

        # Second moment.
        exp_avg_sq = state["exp_avg_sq"]
        if group["Rie_inner"]:
            scale = self._rie_inner_ratio(Rie_grad, M, N, A, inv_M, inv_N)
            exp_avg_sq.mul_(beta2).addcmul_(Rie_grad, scale * Rie_grad, value=1.0 - beta2)
        else:
            exp_avg_sq.mul_(beta2).addcmul_(Rie_grad, Rie_grad, value=1.0 - beta2)

        denom = exp_avg_sq.sqrt().add_(group["eps"])
        step_size = group["lr"]
        if group.get("correct_bias"):  # No bias correction for Bert
            # NOTE(review): step was already incremented past robust_init + 1,
            # so the first lifted step uses exponent 2 although the moments were
            # updated once. Confirm this is intended.
            bias_correction1 = 1.0 - beta1 ** (state["step"] - group["robust_init"])
            bias_correction2 = 1.0 - beta2 ** (state["step"] - group["robust_init"])
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

        lifted_vec = exp_avg.div(denom)
        lifted_vec.mul_(-step_size)
        C = M.T@lifted_vec@N
        K = solve_sylvester(A, C)
        if torch.norm(A@K+K@A-C) > group["eps"]:
            print("Errors in the Sylvester solution!")
        if group["reg_k"] > 0.0:
            K.add_(torch.eye(r).to(p1.device), alpha=group["lr"] * group["reg_k"])

        lift_M, lift_N = self._lift(lifted_vec, M, N, inv_M, inv_N, K)
        p1.data.add_(lift_N.T, alpha=1)
        p2.data.add_(lift_M, alpha=1)

    def step(self, closure=None):
        """Perform one optimization step for every ``(N.T, M)`` parameter pair.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is called before the update.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            params = group["params"]
            for p1, p2 in zip(params[0::2], params[1::2]):
                state = self.state[p2]
                if len(state) == 0:
                    # Both factors start their step counters together; previously
                    # this was skipped for robust_init < 1, causing a KeyError.
                    state["step"] = 0
                    self.state[p1]["step"] = 0
                if state["step"] < group["robust_init"] + 1:
                    for p in (p1, p2):
                        self._robust_adam_step(group, p)
                else:
                    self._lifted_step(group, p1, p2)
        return loss

class Lift_RACS(Optimizer):
    """Lifted optimizer with RACS-style row/column scaling for ``W = M @ N.T``.

    ``params`` is consumed in consecutive pairs ``(p1, p2)`` with ``p2 = M``
    (n x r) and ``p1 = N.T`` (r x m). A trailing unpaired parameter is
    ignored. Each pair first runs ``robust_init + 1`` plain Adam warm-up steps,
    each followed by a fixed decoupled decay of ``lr * 0.01``. Every later
    step does the following:

    1. It combines the factor gradients into
       ``Rie_grad = P_U G + G P_V - P_U G P_V``. This assumes the factor
       gradients come from ``W = M @ N.T``.
    2. It estimates a row scaling ``q`` (length n) and a column scaling ``s``
       (length m) with a few fixed-point iterations, keeps EMAs of both, and
       forms ``diag(q)^{-1/2} Rie_grad diag(s)^{-1/2}``.
    3. It lifts the scaled step back onto ``(M, N)`` via a Sylvester solve.
    4. It limits the step norm to grow by at most a factor 1.01 per step.

    Args:
        params: parameters ordered ``N.T, M, N.T, M, ...``.
        lr: learning rate. The lifted step is additionally scaled by 0.05.
        weight_decay: stored in the param group; not read by ``step``.
        init_q: initial value of every entry of ``q`` in the fixed-point iteration.
        Rie_inner: measure the step norm with the lifted metric
            ``trace(inv_M dM^T dM + inv_N dN^T dN)`` instead of the Frobenius norm.
        betas: ``betas[0]`` is the EMA factor for q, s and for the warm-up first
            moment. ``betas[1]`` is the warm-up second-moment factor.
        eps: Adam epsilon; also the Sylvester residual warning threshold.
        correct_bias: bias correction in the warm-up Adam steps.
        reg: stored on ``self.reg`` and printed; not read by ``step``.
        robust_init: number of extra Adam warm-up steps before lifting.
    """

    # Maximum growth factor of the lifted step norm between consecutive steps.
    _GAMMA = 1.01
    # Extra factor applied to lr in the lifted update.
    _SCALE_ALPHA = 0.05
    # Fixed-point iterations used to estimate the row/column scalings.
    _SCALING_ITERS = 5
    # Decoupled decay (per unit of lr) applied after every warm-up Adam step.
    _ROBUST_INIT_DECAY = 0.01

    def __init__(self, params, lr, weight_decay, init_q = 1, Rie_inner=True, betas=(0.9, 0.98), eps=1e-6, correct_bias=True, reg=0, robust_init=100):
        defaults = dict(lr=lr, weight_decay=weight_decay, init_q=init_q, Rie_inner=Rie_inner, betas=betas,
                        correct_bias=correct_bias, eps=eps, robust_init=robust_init)
        super().__init__(params, defaults)
        self.reg = reg
        print(f'{self.reg=}')

    def reset_state(self, p):
        """Drop the warm-up Adam buffers of ``p``. Missing buffers are ignored."""
        state = self.state[p]
        state.pop('exp_avg_0', None)
        state.pop('exp_avg_sq_0', None)

    @staticmethod
    def _qr_pinv(X):
        """Return ``(Q, R^+, R^+ @ R^+.T)`` for the reduced QR ``X = Q R``.

        ``R^+ @ R^+.T`` equals ``(X.T @ X)^{-1}`` when ``X`` has full column rank.

        Raises:
            RuntimeError: if the QR or the pseudo-inverse fails. The original
                error is chained as the cause.
        """
        try:
            Q, R = torch.linalg.qr(X)
            R_pinv = torch.linalg.pinv(R)
        except RuntimeError as err:
            raise RuntimeError("pinv or qr cannot be computed.") from err
        return Q, R_pinv, R_pinv @ R_pinv.T

    @staticmethod
    def _lift(vec, M, N, inv_M, inv_N, K):
        """Map an n x m direction ``vec`` to factor updates ``(dM, dN)`` using gauge ``K``."""
        lift_M = vec @ N @ inv_N - M @ (N.T@N) @ K @ inv_N
        lift_N = vec.T @ M @ inv_M - N @ (M.T@M) @ K.T @ inv_M
        return lift_M, lift_N

    @classmethod
    def _row_col_scalings(cls, Rie_grad, q):
        """Estimate the row (q) and column (s) scalings of ``Rie_grad`` by fixed-point iteration.

        The iteration alternates two updates, with ``G∘G`` the element-wise square:
        ``s = (G∘G)^T q / <q, q>`` and ``q = (G∘G) s / <s, s>``.
        """
        for _ in range(cls._SCALING_ITERS):
            s = torch.einsum('ij,ji->i', torch.einsum('ij,j->ij', Rie_grad.T, q), Rie_grad)
            # torch.dot replaces the deprecated `.T` on a 1-D tensor (same kernel).
            s.div_(torch.dot(q, q))
            q = torch.einsum('ij,ji->i', torch.einsum('ij,j->ij', Rie_grad, s), Rie_grad.T)
            q.div_(torch.dot(s, s))
        return q, s

    def _robust_adam_step(self, group, p):
        """Perform one plain Adam warm-up step on ``p``, then the fixed decay."""
        state = self.state[p]
        if state["step"] == 0:
            print("Robust Initialization Start")
            state["exp_avg_0"] = torch.zeros_like(p.data)
            state["exp_avg_sq_0"] = torch.zeros_like(p.data)
        exp_avg, exp_avg_sq = state["exp_avg_0"], state["exp_avg_sq_0"]
        beta1, beta2 = group["betas"]
        state["step"] += 1
        exp_avg.mul_(beta1).add_(p.grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(p.grad, p.grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])
        step_size = group["lr"]
        if group.get("correct_bias"):  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
        # Keyword form of addcdiv_; the (value, t1, t2) overload is deprecated.
        p.data.addcdiv_(exp_avg, denom, value=-step_size)
        p.data.add_(p.data, alpha=-group["lr"] * self._ROBUST_INIT_DECAY)

    def _lifted_step(self, group, p1, p2):
        """Perform one lifted, row/column-scaled step on the pair ``(p1=N.T, p2=M)``."""
        if 'exp_avg_0' in self.state[p2]:
            # First lifted step: the warm-up Adam buffers are no longer needed.
            self.reset_state(p2)
            self.reset_state(p1)
        state = self.state[p2]
        M = p2.data.clone()       # n x r
        N = p1.data.clone().t()   # m x r
        n = p2.size()[0]
        m = p1.size()[1]
        if "exp_q" not in state:
            state["exp_q"] = torch.zeros(n).to(p1.device)
            state["exp_s"] = torch.zeros(m).to(p1.device)

        U_N, R_Np, inv_N = self._qr_pinv(N)   # inv_N ~ (N^T N)^{-1}
        U_M, R_Mp, inv_M = self._qr_pinv(M)   # inv_M ~ (M^T M)^{-1}
        Rie_M = U_M @ R_Mp.T @ p1.grad
        Rie_grad = Rie_M + p2.grad @ R_Np @ U_N.T - Rie_M @ U_N @ U_N.T

        beta1 = group["betas"][0]
        q0 = torch.ones(n).to(p1.device) * group["init_q"]
        q, s = self._row_col_scalings(Rie_grad, q0)
        exp_q, exp_s = state["exp_q"], state["exp_s"]
        exp_q.mul_(beta1).add_(q, alpha=1.0 - beta1)
        exp_s.mul_(beta1).add_(s, alpha=1.0 - beta1)
        # diag(exp_q)^{-1/2} @ Rie_grad @ diag(exp_s)^{-1/2}, without forming diagonals.
        G_ = torch.einsum('i,ij->ij', exp_q.sqrt()**-1, torch.einsum('ij,j->ij', Rie_grad, exp_s.sqrt()**-1))
        lifted_vec = G_.mul_(-group["lr"] * self._SCALE_ALPHA)

        # Gauge term K solves A K + K A = M^T lifted_vec N (Bartels–Stewart style).
        A = (M.T@M)@(N.T@N)
        C = M.T@lifted_vec@N
        K = solve_sylvester(A, C)
        if torch.norm(A@K+K@A-C) > group["eps"]:
            print("Errors in the Sylvester solution!")
        lift_M, lift_N = self._lift(lifted_vec, M, N, inv_M, inv_N, K)

        if group["Rie_inner"]:
            norm_G = torch.trace(inv_M @ lift_M.T @ lift_M + inv_N @ lift_N.T @ lift_N).sqrt()
        else:
            norm_G = torch.trace(lifted_vec.T @ lifted_vec).sqrt()
        # The first lifted step is unconstrained. Later steps shrink eta so the
        # step norm exceeds the previous one by at most a factor _GAMMA.
        if "phi" in state:
            eta = self._GAMMA / max(self._GAMMA, norm_G / state["phi"])
        else:
            eta = 1
        state["phi"] = norm_G * eta
        state["step"] += 1
        p1.data.add_(lift_N.T, alpha=eta)
        p2.data.add_(lift_M, alpha=eta)

    def step(self, closure=None):
        """Perform one optimization step for every ``(N.T, M)`` parameter pair.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is called before the update.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            params = group["params"]
            for p1, p2 in zip(params[0::2], params[1::2]):
                state = self.state[p2]
                if len(state) == 0:
                    # Both factors start their step counters together; previously
                    # this was skipped for robust_init < 1, causing a KeyError.
                    state["step"] = 0
                    self.state[p1]["step"] = 0
                if state["step"] < group["robust_init"] + 1:
                    for p in (p1, p2):
                        self._robust_adam_step(group, p)
                else:
                    self._lifted_step(group, p1, p2)
        return loss

class RGD_Framework(Optimizer):
    """Riemannian (optionally preconditioned) gradient descent on rank-r factorized weights.

    ``params`` is consumed in consecutive pairs ``(p1, p2)`` with ``p2`` of
    shape (n, r) and ``p1`` of shape (r, m), so ``p2 @ p1`` is the n x m weight.
    The full-weight gradient is not read from ``.grad``. A forward hook must
    store its transpose (m x n) on ``p2.W_new_grad``; ``step`` consumes and
    deletes it. Pairs whose ``p1.grad`` is ``None`` are skipped.

    Each step builds ``W = p2 @ p1 - lr * pre_grad`` and applies decoupled
    weight decay. It then retracts ``W`` to rank r using the tracked bases
    ``U`` / ``Vh`` and the left/right preconditioners, and writes
    ``p2 = U_r diag(S_r)`` and ``p1 = Vh_r``.

    ``opt_type`` selects ``pre_grad``:
        ``"RAdaGrad"``: raw gradient scaled left/right by the accumulated
            diagonal row/column second moments raised to the power -1/4.
        ``"RAdam"``: the same scaling applied to a bias-corrected first moment.
        ``"Lift RGD"``: not implemented (raises).
        anything else (default ``"RGD"``): identity preconditioning, so
            ``pre_grad`` is the gradient itself.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.0, correct_bias=True, reg=0, opt_type="RGD"):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias,
                        reg=reg, opt_type=opt_type)
        super().__init__(params, defaults)
        self.reg = reg
        print(f'{self.reg=}')

    def reset_state(self):
        """Restart moment accumulation while keeping the tracked factorization.

        For every initialized state:
        - the step counter is reset;
        - the first moment ``exp_avg`` is zeroed with its shape preserved;
        - ``Lt`` / ``Rt`` are restored to ``L0`` / ``R0``.

        ``U`` / ``S`` / ``Vh`` are kept because they must stay consistent with
        the current parameter values. Uninitialized states are left empty so
        that ``step`` initializes them lazily. Previously such states were
        filled partially, which made ``step`` fail.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if not state:
                    continue
                state['step'] = 0
                state["exp_avg"] = torch.zeros_like(state.get("exp_avg", p.data))
                state["exp_avg_sq"] = torch.zeros_like(p.data)
                if "L0" in state:
                    state["Lt"] = state["L0"].clone()
                    state["Rt"] = state["R0"].clone()

    @staticmethod
    def _pop_full_grad(p2):
        """Return the full-weight gradient (n x m, float32) stored by the hook and delete it from ``p2``."""
        if not hasattr(p2, 'W_new_grad'):
            raise RuntimeError("Hook is not corrected implemented in forward() method!")
        grad_W = p2.W_new_grad.to(p2.data.device).t()
        grad_W = grad_W.float()
        del p2.W_new_grad
        torch.cuda.empty_cache()
        return grad_W

    @staticmethod
    def _init_state(state, group, n, m, r, device, dtype):
        """Populate a fresh per-pair state.

        The regularized preconditioner accumulators start at ``reg * I``. The
        first moment starts at zero. The bases are identity columns/rows.
        """
        state["step"] = 0
        state["L0"] = (group["reg"]) * torch.eye(n=n, device=device)
        state["R0"] = (group["reg"]) * torch.eye(n=m, device=device)
        state["Lt"] = (group["reg"]) * torch.eye(n=n, device=device)
        state["Rt"] = (group["reg"]) * torch.eye(n=m, device=device)
        state["pos_l"] = torch.zeros(n, n, device=device)
        state["pos_r"] = torch.zeros(m, m, device=device)
        state["neg_l"] = torch.zeros(n, n, device=device)
        state["neg_r"] = torch.zeros(m, m, device=device)
        state["exp_avg"] = torch.zeros(n, m, dtype=dtype, device=device)
        # clone() so the state does not keep the full n x n / m x m identity storage alive.
        state["U"] = torch.eye(n=n, device=device)[:, :r].clone()
        state["S"] = None
        state["Vh"] = torch.eye(n=m, device=device)[:r, :].clone()

    @staticmethod
    def _set_diag_preconditioners(state, compute_Lt, compute_Rt):
        """Store ``pos = diag(L)^(1/4)`` and ``neg = pos^(-1)`` for the left and right sides."""
        state["pos_l"] = torch.diag((torch.diag(compute_Lt, 0) ** (+1/4)))
        state["neg_l"] = torch.diag(torch.diag(state["pos_l"], 0) ** (-1))
        state["pos_r"] = torch.diag((torch.diag(compute_Rt, 0) ** (+1/4)))
        state["neg_r"] = torch.diag(torch.diag(state["pos_r"], 0) ** (-1))

    def _precondition(self, group, state, grad_W, n, m, r, device):
        """Update the accumulators for ``group['opt_type']`` and return ``pre_grad`` (n x m)."""
        beta1, beta2 = group["betas"]
        opt_type = group["opt_type"]
        if opt_type == "RAdaGrad":
            # beta2 decays the accumulators to keep them from overflowing.
            # RAdaGrad keeps no first moment: the raw gradient is preconditioned.
            skip = 1  # update Lt/Rt every `skip` steps (1 = every step)
            if group.get("correct_bias"):  # No bias correction for Bert
                bias_correction2 = 1.0 - beta2 ** (state["step"] + 1)
            else:
                bias_correction2 = 1.0
            Lt, Rt = state["Lt"], state["Rt"]
            if state["step"] % skip == 0:
                Lt.mul_(beta2).add_(torch.diag(torch.diag(grad_W.matmul(grad_W.t()), 0)), alpha=(1-beta2))
                Rt.mul_(beta2).add_(torch.diag(torch.diag(grad_W.t().matmul(grad_W), 0)), alpha=(1-beta2))
            # Correct the accumulators towards the expected value of L0 + G G^T.
            compute_Lt = Lt + state["L0"] * (bias_correction2 - beta2 ** (state["step"] + 1))
            compute_Rt = Rt + state["R0"] * (bias_correction2 - beta2 ** (state["step"] + 1))
            self._set_diag_preconditioners(state, compute_Lt, compute_Rt)
            # The update uses group["lr"] directly (no bias-corrected step size).
            return state["neg_l"]@grad_W@state["neg_r"]

        if opt_type == "RAdam":
            if group.get("correct_bias"):  # No bias correction for Bert
                bias_correction1 = 1.0 - beta1 ** (state["step"] + 1)
                bias_correction2 = 1.0 - beta2 ** (state["step"] + 1)
            else:
                bias_correction1 = 1.0
                bias_correction2 = 1.0
            # First-moment accumulation.
            temp = grad_W.detach().clone()
            exp_avg = state["exp_avg"]
            exp_avg.mul_(beta1).add_(temp, alpha=1.0 - beta1)
            # Accumulate Lt and Rt so their magnitude matches AdamW's second moment.
            Lt, Rt = state["Lt"], state["Rt"]
            Lt.mul_(beta2).add_(torch.diag(torch.diag(temp.matmul(temp.t()), 0)), alpha=(1-beta2))
            Rt.mul_(beta2).add_(torch.diag(torch.diag(temp.t().matmul(temp), 0)), alpha=(1-beta2))
            # Correct the accumulators towards the expected value of L0 + G G^T.
            compute_Lt = Lt + state["L0"] * (bias_correction2 - beta2 ** (state["step"] + 1))
            compute_Lt.div_(bias_correction2)
            compute_Rt = Rt + state["R0"] * (bias_correction2 - beta2 ** (state["step"] + 1))
            compute_Rt.div_(bias_correction2)
            compute_exp = exp_avg / bias_correction1
            self._set_diag_preconditioners(state, compute_Lt, compute_Rt)
            return state["neg_l"]@compute_exp@state["neg_r"]

        if opt_type == "Lift RGD":
            if r != 1:
                raise ValueError("Lift RGD currently only support rank == 1 !")
            # The previous draft used N = p2 (instead of p1) and never produced a
            # preconditioned gradient, so it always failed later with an opaque error.
            raise NotImplementedError("opt_type='Lift RGD' is not implemented in RGD_Framework.")

        # No preconditioning: L and R (and their inverses) are identities, so the
        # gradient is used unchanged. Previously pre_grad was never set here.
        state["pos_l"] = torch.eye(n=n, device=device)
        state["neg_l"] = state["pos_l"]
        state["neg_r"] = torch.eye(n=m, device=device)
        state["pos_r"] = state["neg_r"]
        return grad_W

    @staticmethod
    def _retract(state, W_data, n, m, r, p1, p2):
        """Retract ``W_data`` (n x m) to rank r and write the new factors in place.

        The steps assume ``U`` / ``Vh`` are orthonormal:
        - ``W_data`` is split, with weights ``pos_l`` / ``pos_r``, into a core
          component ``k0`` and components orthogonal to ``Vh`` (``y1h``) and
          to ``U`` (``y2``).
        - The QR factors of ``y1h`` and ``y2`` form a 2r x 2r core matrix.
        - The truncated SVD of that core gives the new bases and singular
          values.
        """
        device = p1.device
        U_r = state["U"][:, 0:r]
        Vh_r = state["Vh"][0:r, :]
        pos_l, pos_r = state["pos_l"], state["pos_r"]
        M_1 = U_r.t()@pos_l@U_r
        M_1_inv = torch.inverse(M_1)
        y1h_part = M_1_inv@U_r.t()@pos_l@W_data
        y1h = y1h_part@(torch.eye(n=m, device=device) - Vh_r.t()@Vh_r)
        M_2 = Vh_r@pos_r@Vh_r.t()
        M_2_inv = torch.inverse(M_2)
        y2_part = W_data@pos_r@Vh_r.t()@M_2_inv
        y2 = (torch.eye(n=n, device=device) - U_r@U_r.t())@y2_part
        k0 = y1h_part@Vh_r.t() + U_r.t()@y2_part - y1h_part@pos_r@Vh_r.t()@M_2_inv
        # Two QR decompositions replace an SVD of the full n x m matrix.
        q1, k1 = torch.linalg.qr(y1h.t())
        q2, k2 = torch.linalg.qr(y2)
        core = torch.cat((torch.cat((k0, k2), 0),
                          torch.cat((k1.t(), torch.zeros(k0.size()[0], k0.size()[0], device=device)), 0)), 1)
        u_m, s, vh_m = torch.linalg.svd(core)
        u = torch.cat((U_r, q2), 1)@u_m
        vh = vh_m@torch.cat((Vh_r, q1.t()), 0)
        if state["S"] is None:
            state["S"] = s
        else:
            state["S"][:r].copy_(s[:r])
        # In-place updates keep the state tensors (and their storage) stable.
        state["U"][:, :r].copy_(u[:, :r])
        state["Vh"][:r, :].copy_(vh[:r, :])
        p2.data.copy_(u[:, :r]@torch.diag(s[:r]))
        p1.data.copy_(vh[:r, :])

    def step(self, closure=None):
        """Perform one optimization step for every ``(p1, p2)`` pair.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The closure's return value, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if ``p2.W_new_grad`` was not set by the forward hook.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            params = group["params"]
            for p1, p2 in zip(params[0::2], params[1::2]):
                if p1.grad is None:
                    continue
                grad_W = self._pop_full_grad(p2)
                # All optimizer state for the pair is stored in state[p2].
                state = self.state[p2]
                n = p2.size()[0]
                m = p1.size()[1]
                r = p2.size()[1]
                if len(state) == 0:
                    self._init_state(state, group, n, m, r, p1.device, p2.dtype)
                pre_grad = self._precondition(group, state, grad_W, n, m, r, p1.device)
                state["step"] += 1
                # no_grad() is an important trick to avoid memory leaks through
                # the parameter products below.
                with torch.no_grad():
                    W_data = p2@p1
                    W_data.add_(-pre_grad, alpha=group["lr"])
                    # Decoupled weight decay: W <- W * (1 - lr * weight_decay).
                    W_data.add_(W_data, alpha=-group["lr"] * group["weight_decay"])
                    del pre_grad, grad_W
                    self._retract(state, W_data, n, m, r, p1, p2)
                    del W_data
        return loss