from torch.optim.optimizer import Optimizer
import torch
import math
import torch.nn as nn
import numpy as np

class AdamW(Optimizer):
    """Implements the Adam algorithm with the decoupled weight decay fix.

    Parameters whose second dimension has size 1 are skipped by ``step``
    (the original code does this to leave ``lora_S`` untouched).

    Parameters:
        params (iterable): parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate. Default 1e-3.
        betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.98)
        eps (float): Adams epsilon. Default: 1e-6
        weight_decay (float): Weight decay. Default: 0.0
        correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.

    Raises:
        ValueError: if ``lr``, either beta, or ``eps`` is outside its valid range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.0, correct_bias=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        for beta in (betas[0], betas[1]):
            if not 0.0 <= beta < 1.0:
                raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def reset_state(self):
        """Reset the step counter and both moment buffers of every parameter to zero."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state["exp_avg"] = torch.zeros_like(p.data)
                state["exp_avg_sq"] = torch.zeros_like(p.data)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Skip the parameter lora_S, we don't need to update it.
                # Be careful when setting the lora_dim to 1.
                # The dim() guard keeps 0-D/1-D parameters (which have no
                # second axis) from raising IndexError; they are updated normally.
                if p.data.dim() > 1 and p.data.shape[1] == 1:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if 'correct_bias' in group and group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                # Keyword form; the positional-scalar overload is deprecated.
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

        return loss

class SGDr(Optimizer):
    """Gradient descent on factor triples with Gram-matrix-scaled gradients.

    Each parameter group is read as consecutive triples ``(p1, S, p2)``; the
    middle parameter is never updated. The gradient of ``p1`` is multiplied
    from the left by ``inverse(p2_half.T @ p2_half + reg * I)`` and the
    gradient of ``p2`` from the right by
    ``inverse(p1_half @ p1_half.T + reg * I)``, separately for the two halves
    of each factor (``p1`` is split at row ``rank``, ``p2`` at half its rows).

    Parameters:
        params (iterable): parameters, ordered as triples.
        lr (float): learning rate.
        weight_decay (float): decoupled weight decay applied to ``p1`` and ``p2``.
        rank (int): row count of each half of ``p1`` and size of the identity
            used for regularization. Default 4.
        reg (float): coefficient of the identity added before inversion. Default 1e-6.
    """

    def __init__(self, params, lr, weight_decay, rank=4, reg=1e-6):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        print(f'{self.reg=}')

    def _regularized_inverse(self, gram):
        """Return ``inverse(gram + reg * I)``; raises RuntimeError on failure."""
        eye = torch.eye(self.rank, device=gram.device, dtype=gram.dtype)
        return torch.inverse(gram + self.reg * eye)

    def _scale_left(self, factor, grad):
        """Return ``inverse(factor.T @ factor + reg * I) @ grad``, or ``grad`` if that fails."""
        try:
            return self._regularized_inverse(factor.T @ factor) @ grad
        except RuntimeError:
            # Best effort: fall back to the unscaled gradient.
            return grad

    def _scale_right(self, factor, grad):
        """Return ``grad @ inverse(factor @ factor.T + reg * I)``, or ``grad`` if that fails."""
        try:
            return grad @ self._regularized_inverse(factor @ factor.T)
        except RuntimeError:
            # Best effort: fall back to the unscaled gradient.
            return grad

    def step(self, closure=None):
        """Perform one update of every ``(p1, S, p2)`` triple.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        rank = self.rank
        for group in self.param_groups:
            params = group["params"]
            lr = group["lr"]
            # p1: U, middle: S, p2: Vh; S is never updated.
            for start in range(0, len(params) - 2, 3):
                p1, p2 = params[start], params[start + 2]
                if p1.grad is None or p2.grad is None:
                    # Nothing to apply for this triple.
                    continue
                dim_1 = p2.data.shape[0] // 2
                g1, g2 = p1.grad.data, p2.grad.data

                # Both scaled gradients are computed from the pre-update factors.
                grad1_scaled = torch.cat([
                    self._scale_left(p2.data[:dim_1, :], g1[:rank, :]),
                    self._scale_left(p2.data[dim_1:, :], g1[rank:, :]),
                ])
                grad2_scaled = torch.cat([
                    self._scale_right(p1.data[:rank, :], g2[:dim_1, :]),
                    self._scale_right(p1.data[rank:, :], g2[dim_1:, :]),
                ])

                p1.data.add_(grad1_scaled, alpha=-lr)
                p2.data.add_(grad2_scaled, alpha=-lr)

                if group["weight_decay"] > 0.0:
                    p1.data.add_(p1.data, alpha=-lr * group["weight_decay"])
                    p2.data.add_(p2.data, alpha=-lr * group["weight_decay"])

        return loss



class AdamWr(Optimizer):
    """AdamW applied to Gram-matrix-scaled gradients of factor triples.

    Each parameter group is read as consecutive triples ``(p1, S, p2)``; the
    middle parameter is never updated. The gradient of each factor is scaled
    by the regularized inverse Gram matrix of the other factor (per half), and
    the scaled gradient is then fed to a standard Adam update with decoupled
    weight decay.

    Parameters:
        params (iterable): parameters, ordered as triples.
        lr (float): learning rate. Default 1e-3.
        betas (tuple of 2 floats): Adam beta parameters. Default (0.9, 0.98).
        eps (float): Adam epsilon. Default 1e-6.
        weight_decay (float): decoupled weight decay. Default 0.0.
        correct_bias (bool): apply Adam bias correction. Default True.
        rank (int): row count of each half of ``p1`` and size of the identity
            used for regularization. Default 2.
        reg (float): coefficient of the identity added before inversion. Default 1e-6.

    Raises:
        ValueError: if ``lr``, either beta, or ``eps`` is outside its valid range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.0, correct_bias=True, rank=2, reg=1e-6):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        for beta in (betas[0], betas[1]):
            if not 0.0 <= beta < 1.0:
                raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        print(f'{self.reg=}')

    def reset_state(self):
        """Reset the step counter and both moment buffers of every parameter to zero."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state["exp_avg"] = torch.zeros_like(p.data)
                state["exp_avg_sq"] = torch.zeros_like(p.data)

    def _regularized_inverse(self, gram):
        """Return ``inverse(gram + reg * I)``, or an identity matrix if that fails."""
        try:
            eye = torch.eye(self.rank, device=gram.device, dtype=gram.dtype)
            return torch.inverse(gram + self.reg * eye)
        except RuntimeError:
            # Best effort: an identity preconditioner leaves the gradient unscaled.
            return torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)

    def _adam_update(self, param, scaled_grad, group):
        """Apply one Adam step plus decoupled weight decay to ``param`` in place."""
        state = self.state[param]
        if len(state) == 0:
            state["step"] = 0
            state["exp_avg"] = torch.zeros_like(param.data)
            state["exp_avg_sq"] = torch.zeros_like(param.data)
        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        beta1, beta2 = group["betas"]
        state["step"] += 1

        exp_avg.mul_(beta1).add_(scaled_grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(scaled_grad, scaled_grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])

        step_size = group["lr"]
        if 'correct_bias' in group and group["correct_bias"]:  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

        # Keyword form; the positional-scalar overload is deprecated.
        param.data.addcdiv_(exp_avg, denom, value=-step_size)
        if group["weight_decay"] > 0.0:
            param.data.add_(param.data, alpha=-group["lr"] * group["weight_decay"])

    def step(self, closure=None):
        """Perform one update of every ``(p1, S, p2)`` triple.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        rank = self.rank
        for group in self.param_groups:
            params = group["params"]
            # p1: U, middle: S, p2: Vh; S is never updated.
            for start in range(0, len(params) - 2, 3):
                p1, p2 = params[start], params[start + 2]
                if p1.grad is None or p2.grad is None:
                    # Nothing to apply for this triple.
                    continue
                dim_1 = p2.data.shape[0] // 2

                # Scale p1's gradient with the Gram matrices of p2's halves.
                g1 = p1.grad.data
                b0, b1 = p2.data[:dim_1, :], p2.data[dim_1:, :]
                grad1_scaled = torch.cat([
                    self._regularized_inverse(b0.T @ b0) @ g1[:rank, :],
                    self._regularized_inverse(b1.T @ b1) @ g1[rank:, :],
                ])
                assert grad1_scaled.shape == g1.shape
                self._adam_update(p1, grad1_scaled, group)

                # NOTE(review): as in the original code (where the slices were
                # views of p1.data taken before its in-place update), the Gram
                # matrices below are built from the *already updated* p1.
                # Confirm whether pre-update values were intended.
                a0, a1 = p1.data[:rank, :], p1.data[rank:, :]
                g2 = p2.grad.data
                grad2_scaled = torch.cat([
                    g2[:dim_1, :] @ self._regularized_inverse(a0 @ a0.T),
                    g2[dim_1:, :] @ self._regularized_inverse(a1 @ a1.T),
                ])
                assert grad2_scaled.shape == g2.shape
                self._adam_update(p2, grad2_scaled, group)

        return loss

class plainRGD(Optimizer):
    """Gradient step on ``(U, S, Vh)`` factor triples followed by re-truncation to rank ``rank``.

    Each parameter group is read as consecutive triples ``(p1, p2, p3)`` with
    p1-U: [n x enable_lora, r], p2-S: [r x enable_lora], p3-Vh: [r x enable_lora, m].
    The gradient is not taken from ``p.grad``: the group may carry a
    ``'param_to_module'`` dict mapping ``id(p1)`` to an object exposing
    ``W_new_grad`` and ``lora_ind``; ``W_new_grad.T[lora_ind]`` is used as the
    gradient. Every factor is processed as two independent halves.

    Parameters:
        params (iterable): parameters or parameter-group dicts, ordered as triples.
        lr (float): learning rate. Default 1e-4.
        weight_decay (float): decoupled weight decay applied to U and Vh. Default 0.0.
        rank (int): rank ``r`` of each half. Default 4.
        reg (float): stored on the instance; not used by the visible code. Default 1e-6.
    """

    def __init__(self, params, lr=1e-4, weight_decay=0.0, rank=4, reg=1e-6):
        defaults = dict(lr=lr, weight_decay=weight_decay, rank=rank, reg=reg)
        # Zero-argument super(): the previous explicit class name no longer existed.
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        print(f'{self.reg=}')

    def compute_inv_PSD(self, mat):
        """Invert ``mat``; if it has an all-zero row, warn and return an identity matrix instead."""
        # compute inverse of a positive semidefinite matrix
        has_zero_row = torch.all(mat == 0, dim=1).any()
        if has_zero_row:
            print(f"Warning: matrix {mat.shape} is not invertible, return identity matrix")
            return torch.eye(mat.shape[0]).to(mat.device)
        return torch.inverse(mat)

    def _weight_grad(self, param_to_module, p1):
        """Return the gradient rows registered for ``p1``, or ``None`` (after printing why) if unavailable."""
        module = param_to_module.get(id(p1))
        if module is None or not hasattr(module, 'W_new_grad'):
            print("No parent module found for parameters")
            return None
        if module.W_new_grad is None:
            print(f"W_new_grad is None for {module}")
            return None
        return module.W_new_grad.to(p1.data.device).T[module.lora_ind]

    def _retract(self, U, S, Vh, Z, dim0, m):
        """Combine one half ``(U, S, Vh)`` with the step ``Z`` and return the rank-``rank`` factors.

        Builds a 2r x 2r core matrix from K0/K1/K2, takes its SVD and keeps the
        leading ``rank`` components.
        """
        rank = self.rank
        device = U.device

        # compute some matrices for later use
        UU_inv = self.compute_inv_PSD(U.T @ U)
        VV_inv = self.compute_inv_PSD(Vh @ Vh.T)
        UZ = U.T @ Z
        ZV = Z @ Vh.T

        # compute matrix_K0
        K0 = torch.diag(S) + UU_inv @ UZ @ Vh.T + \
            (U.T - UU_inv @ U.T) @ ZV @ VV_inv

        # compute matrix_K1
        Y1 = UU_inv @ UZ @ (torch.eye(m).to(device) - Vh.T @ Vh)
        Q1, K1 = torch.qr(Y1.T, some=True)

        # compute matrix_K2
        Y2 = (torch.eye(dim0).to(device) - U @ U.T) @ ZV @ VV_inv
        Q2, K2 = torch.qr(Y2, some=True)

        # SVD of 2r x 2r matrix
        M = torch.zeros(2 * rank, 2 * rank).to(device)
        M[:rank, :rank] = K0
        M[:rank, rank:] = K1.T
        M[rank:, :rank] = K2
        # torch.svd returns V (not Vh), hence the transpose below.
        U_M, S_M, V_M = torch.svd(M, some=True)

        new_U = torch.cat([U, Q2], dim=1) @ U_M
        new_Vh = V_M.T @ torch.cat([Vh, Q1.T], dim=0)
        return new_U[:, :rank], S_M[:rank], new_Vh[:rank, :]

    def step(self, closure=None):
        """Perform one update of every ``(U, S, Vh)`` triple.

        Triples for which no gradient can be looked up are reported on stdout
        and skipped.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        rank = self.rank
        for group in self.param_groups:
            param_to_module = group.get('param_to_module', {})
            params = group["params"]
            lr = group["lr"]
            for start in range(0, len(params) - 2, 3):
                p1, p2, p3 = params[start], params[start + 1], params[start + 2]
                G_W = self._weight_grad(param_to_module, p1)
                if G_W is None:
                    # No gradient for this triple: skip it rather than reuse a
                    # gradient left over from a previous triple.
                    continue

                n = p1.data.shape[0]
                dim0 = n // 2
                m = p3.data.shape[1]

                U_0, S_0, Vh_0 = self._retract(
                    p1.data[0:dim0, :], p2.data[0:rank], p3.data[0:rank, :],
                    -lr * G_W[0:dim0, :], dim0, m)
                U_1, S_1, Vh_1 = self._retract(
                    p1.data[dim0:, :], p2.data[rank:], p3.data[rank:, :],
                    -lr * G_W[dim0:, :], dim0, m)

                # Update p1, p2, p3
                p1.data = torch.cat([U_0, U_1], dim=0)
                p2.data = torch.cat([S_0, S_1], dim=0)
                p3.data = torch.cat([Vh_0, Vh_1], dim=0)

                if group["weight_decay"] > 0.0:
                    p1.data.add_(p1.data, alpha=-lr * group["weight_decay"])
                    p3.data.add_(p3.data, alpha=-lr * group["weight_decay"])

        return loss
    
class RAdaGrad(Optimizer):
    """Preconditioned gradient step on ``(U, S, Vh)`` triples followed by re-truncation to rank ``rank``.

    Each parameter group is read as consecutive triples ``(p1, p2, p3)`` with
    p1-U: [n x enable_lora, r], p2-S: [r x enable_lora], p3-Vh: [r x enable_lora, m].
    The gradient is ``W_new_grad.T[lora_ind]`` of the object registered for
    ``id(p1)`` in the group's ``'param_to_module'`` dict. Per half, running
    averages (decay ``betas[1]``) of the row-wise and column-wise sums of
    squared gradients are kept; their fourth roots act as diagonal left/right
    preconditioners.

    Parameters:
        params (iterable): parameters or parameter-group dicts, ordered as triples.
        lr (float): learning rate. Default 1e-4.
        betas (tuple of 2 floats): only ``betas[1]`` is used, as the decay of
            the squared-gradient averages. Default (0.9, 0.98).
        eps (float): added to the preconditioner diagonals before inversion. Default 1e-3.
        weight_decay (float): decoupled weight decay applied to U and Vh. Default 0.0.
        rank (int): rank ``r`` of each half. Default 4.
        reg (float): stored on the instance; not used by the visible code. Default 1e-6.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.98), eps=1e-3, weight_decay=0.0, rank=4, reg=1e-6):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, rank=rank, reg=reg)
        # Zero-argument super(): the previous explicit class name no longer existed.
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        print(f'{self.reg=}')

    def compute_inv_PSD(self, mat):
        """Invert ``mat``; if it has an all-zero row, warn and return an identity matrix instead."""
        # compute inverse of a positive semidefinite matrix
        has_zero_row = torch.all(mat == 0, dim=1).any()
        if has_zero_row:
            print(f"Warning: matrix {mat.shape} is not invertible, return identity matrix")
            return torch.eye(mat.shape[0]).to(mat.device)
        return torch.inverse(mat)

    def _weight_grad(self, param_to_module, p1):
        """Return the gradient rows registered for ``p1``, or ``None`` (after printing why) if unavailable."""
        module = param_to_module.get(id(p1))
        if module is None or not hasattr(module, 'W_new_grad'):
            print("No parent module found for parameters")
            return None
        if module.W_new_grad is None:
            print(f"W_new_grad is None for {module}")
            return None
        return module.W_new_grad.to(p1.data.device).T[module.lora_ind]

    def _retract(self, U, S, Vh, Z, pre_L, pre_R, dim0, m):
        """Combine one half ``(U, S, Vh)`` with the step ``Z`` and return the rank-``rank`` factors.

        ``pre_L`` / ``pre_R`` are the diagonal left/right preconditioner
        vectors. Builds a 2r x 2r core matrix from K0/K1/K2, takes its SVD and
        keeps the leading ``rank`` components.
        """
        rank = self.rank
        device = U.device
        L = torch.diag(pre_L)
        R = torch.diag(pre_R)

        # compute some matrices for later use
        ULU_inv = self.compute_inv_PSD(U.T @ L @ U)
        VRV_inv = self.compute_inv_PSD(Vh @ R @ Vh.T)
        ULZ = U.T @ L @ Z
        ZRV = Z @ R @ Vh.T

        # compute matrix_K0
        K0 = torch.diag(S) + ULU_inv @ ULZ @ Vh.T + \
            (U.T - ULU_inv @ U.T @ L) @ ZRV @ VRV_inv

        # compute matrix_K1
        Y1 = ULU_inv @ ULZ @ (torch.eye(m).to(device) - Vh.T @ Vh)
        Q1, K1 = torch.qr(Y1.T, some=True)

        # compute matrix_K2
        Y2 = (torch.eye(dim0).to(device) - U @ U.T) @ ZRV @ VRV_inv
        Q2, K2 = torch.qr(Y2, some=True)

        # SVD of 2r x 2r matrix
        M = torch.zeros(2 * rank, 2 * rank).to(device)
        M[:rank, :rank] = K0
        M[:rank, rank:] = K1.T
        M[rank:, :rank] = K2
        # torch.svd returns V (unlike torch.linalg.svd), hence the transpose below.
        U_M, S_M, V_M = torch.svd(M, some=True)

        new_U = torch.cat([U, Q2], dim=1) @ U_M
        new_Vh = V_M.T @ torch.cat([Vh, Q1.T], dim=0)
        return new_U[:, :rank], S_M[:rank], new_Vh[:rank, :]

    def step(self, closure=None):
        """Perform one update of every ``(U, S, Vh)`` triple.

        Triples for which no gradient can be looked up are reported on stdout
        and skipped.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        rank = self.rank
        for group in self.param_groups:
            param_to_module = group.get('param_to_module', {})
            params = group["params"]
            lr = group["lr"]
            eps = group["eps"]
            _, beta2 = group["betas"]
            for start in range(0, len(params) - 2, 3):
                p1, p2, p3 = params[start], params[start + 1], params[start + 2]
                G_W = self._weight_grad(param_to_module, p1)
                if G_W is None:
                    # No gradient for this triple: skip it rather than reuse a
                    # gradient left over from a previous triple.
                    continue

                n = p1.data.shape[0]
                dim0 = n // 2
                m = p3.data.shape[1]

                # All state for a triple lives under its first parameter.
                state = self.state[p1]
                if len(state) == 0:
                    state["pre_L_0"] = torch.zeros(dim0).to(p1.device)
                    state["pre_R_0"] = torch.zeros(m).to(p1.device)
                    state["pre_L_1"] = torch.zeros(dim0).to(p1.device)
                    state["pre_R_1"] = torch.zeros(m).to(p1.device)

                halves = (
                    ("0", slice(0, dim0), slice(0, rank)),
                    ("1", slice(dim0, None), slice(rank, None)),
                )
                updated = []
                for tag, outer_rows, inner_rows in halves:
                    G = G_W[outer_rows, :]
                    acc_L = state["pre_L_" + tag]
                    acc_R = state["pre_R_" + tag]
                    # Running averages of row-wise / column-wise squared gradient sums.
                    acc_L.mul_(beta2).add_(torch.einsum('ij,ij->i', G, G), alpha=1 - beta2)
                    acc_R.mul_(beta2).add_(torch.einsum('ij,ij->j', G, G), alpha=1 - beta2)

                    pre_L = acc_L ** (0.25)
                    pre_R = acc_R ** (0.25)
                    pre_L_inv = torch.diag((pre_L + eps) ** (-1.0))
                    pre_R_inv = torch.diag((pre_R + eps) ** (-1.0))
                    pre_grad = -lr * (pre_L_inv @ G @ pre_R_inv)

                    updated.append(self._retract(
                        p1.data[outer_rows, :], p2.data[inner_rows], p3.data[inner_rows, :],
                        pre_grad, pre_L, pre_R, dim0, m))

                (U_0, S_0, Vh_0), (U_1, S_1, Vh_1) = updated

                # Update p1, p2, p3
                p1.data = torch.cat([U_0, U_1], dim=0)
                p2.data = torch.cat([S_0, S_1], dim=0)
                p3.data = torch.cat([Vh_0, Vh_1], dim=0)

                if group["weight_decay"] > 0.0:
                    p1.data.add_(p1.data, alpha=-lr * group["weight_decay"])
                    p3.data.add_(p3.data, alpha=-lr * group["weight_decay"])

        return loss
                

class RAdamW(Optimizer):
    """Momentum + preconditioned step on ``(U, S, Vh)`` triples followed by re-truncation to rank ``rank``.

    Each parameter group is read as consecutive triples ``(p1, p2, p3)`` with
    p1-U: [n x enable_lora, r], p2-S: [r x enable_lora], p3-Vh: [r x enable_lora, m].
    The gradient is ``W_new_grad.T[lora_ind]`` of the object registered for
    ``id(p1)`` in the group's ``'param_to_module'`` dict. An exponential
    moving average of that gradient (decay ``betas[0]``) is preconditioned
    per half with the fourth roots of running row-wise / column-wise
    squared-gradient averages (decay ``betas[1]``).

    Parameters:
        params (iterable): parameters or parameter-group dicts, ordered as triples.
        lr (float): learning rate. Default 1e-4.
        betas (tuple of 2 floats): decay rates of the gradient average and of
            the squared-gradient averages. Default (0.9, 0.999).
        eps (float): added to the preconditioner diagonals before inversion. Default 1e-3.
        weight_decay (float): decoupled weight decay applied to U and Vh. Default 0.0.
        correct_bias (bool): apply Adam-style bias correction to the step size. Default True.
        rank (int): rank ``r`` of each half. Default 4.
        reg (float): stored on the instance; not used by the visible code. Default 1e-6.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-3, weight_decay=0.0, correct_bias=True, rank=4, reg=1e-6):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias, rank=rank, reg=reg)
        # Zero-argument super(): the previous explicit class name no longer existed.
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        print(f'{self.reg=}')

    @staticmethod
    def _init_state(state, n, dim0, m, device):
        """Fill ``state`` with a zero step counter and zero buffers shaped as ``step`` expects."""
        state["step"] = 0
        state["pre_L_0"] = torch.zeros(dim0).to(device)
        state["pre_R_0"] = torch.zeros(m).to(device)
        state["pre_L_1"] = torch.zeros(dim0).to(device)
        state["pre_R_1"] = torch.zeros(m).to(device)
        state["exp_avg"] = torch.zeros(n, m).to(device)

    def reset_state(self):
        """Zero the optimizer state of every triple.

        State is kept only under the first parameter of each triple and is
        sized from U's row count and Vh's column count, matching ``step``.
        """
        for group in self.param_groups:
            params = group['params']
            for start in range(0, len(params) - 2, 3):
                p1, p3 = params[start], params[start + 2]
                n = p1.data.shape[0]
                m = p3.data.shape[1]
                self._init_state(self.state[p1], n, n // 2, m, p1.device)

    def compute_inv_PSD(self, mat):
        """Invert ``mat``; if it has an all-zero row, warn and return an identity matrix instead."""
        # compute inverse of a positive semidefinite matrix
        has_zero_row = torch.all(mat == 0, dim=1).any()
        if has_zero_row:
            print(f"Warning: matrix {mat.shape} is not invertible, return identity matrix")
            return torch.eye(mat.shape[0]).to(mat.device)
        return torch.inverse(mat)

    def _weight_grad(self, param_to_module, p1):
        """Return the gradient rows registered for ``p1``, or ``None`` (after printing why) if unavailable."""
        module = param_to_module.get(id(p1))
        if module is None or not hasattr(module, 'W_new_grad'):
            print("No parent module found for parameters")
            return None
        if module.W_new_grad is None:
            print(f"W_new_grad is None for {module}")
            return None
        return module.W_new_grad.to(p1.data.device).T[module.lora_ind]

    def _retract(self, U, S, Vh, Z, pre_L, pre_R, dim0, m):
        """Combine one half ``(U, S, Vh)`` with the step ``Z`` and return the rank-``rank`` factors.

        ``pre_L`` / ``pre_R`` are the diagonal left/right preconditioner
        vectors. Builds a 2r x 2r core matrix from K0/K1/K2, takes its SVD and
        keeps the leading ``rank`` components.
        """
        rank = self.rank
        device = U.device
        L = torch.diag(pre_L)
        R = torch.diag(pre_R)

        # compute some matrices for later use
        ULU_inv = self.compute_inv_PSD(U.T @ L @ U)
        VRV_inv = self.compute_inv_PSD(Vh @ R @ Vh.T)
        ULZ = U.T @ L @ Z
        ZRV = Z @ R @ Vh.T

        # compute matrix_K0
        K0 = torch.diag(S) + ULU_inv @ ULZ @ Vh.T + \
            (U.T - ULU_inv @ U.T @ L) @ ZRV @ VRV_inv

        # compute matrix_K1
        Y1 = ULU_inv @ ULZ @ (torch.eye(m).to(device) - Vh.T @ Vh)
        Q1, K1 = torch.qr(Y1.T, some=True)

        # compute matrix_K2
        Y2 = (torch.eye(dim0).to(device) - U @ U.T) @ ZRV @ VRV_inv
        Q2, K2 = torch.qr(Y2, some=True)

        # SVD of 2r x 2r matrix
        M = torch.zeros(2 * rank, 2 * rank).to(device)
        M[:rank, :rank] = K0
        M[:rank, rank:] = K1.T
        M[rank:, :rank] = K2
        # torch.svd returns V (not Vh), hence the transpose below.
        U_M, S_M, V_M = torch.svd(M, some=True)

        # update and truncate U, S, Vh to rank
        new_U = torch.cat([U, Q2], dim=1) @ U_M
        new_Vh = V_M.T @ torch.cat([Vh, Q1.T], dim=0)
        return new_U[:, :rank], S_M[:rank], new_Vh[:rank, :]

    def step(self, closure=None):
        """Perform one update of every ``(U, S, Vh)`` triple.

        Triples for which no gradient can be looked up are reported on stdout
        and skipped.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        rank = self.rank
        for group in self.param_groups:
            param_to_module = group.get('param_to_module', {})
            params = group["params"]
            eps = group["eps"]
            beta1, beta2 = group["betas"]
            for start in range(0, len(params) - 2, 3):
                p1, p2, p3 = params[start], params[start + 1], params[start + 2]
                G_W = self._weight_grad(param_to_module, p1)
                if G_W is None:
                    # No gradient for this triple: skip it rather than reuse a
                    # gradient left over from a previous triple.
                    continue

                n = p1.data.shape[0]
                dim0 = n // 2
                m = p3.data.shape[1]

                # All state for a triple lives under its first parameter.
                state = self.state[p1]
                if len(state) == 0:
                    self._init_state(state, n, dim0, m, p1.device)

                state["step"] += 1
                exp_avg = state["exp_avg"]
                exp_avg.mul_(beta1).add_(G_W, alpha=1.0 - beta1)

                step_size = group["lr"]
                if 'correct_bias' in group and group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                halves = (
                    ("0", slice(0, dim0), slice(0, rank)),
                    ("1", slice(dim0, None), slice(rank, None)),
                )
                updated = []
                for tag, outer_rows, inner_rows in halves:
                    G = G_W[outer_rows, :]
                    acc_L = state["pre_L_" + tag]
                    acc_R = state["pre_R_" + tag]
                    # Running averages of row-wise / column-wise squared gradient sums.
                    acc_L.mul_(beta2).add_(torch.einsum('ij,ij->i', G, G), alpha=1 - beta2)
                    acc_R.mul_(beta2).add_(torch.einsum('ij,ij->j', G, G), alpha=1 - beta2)

                    pre_L = acc_L ** (0.25)
                    pre_R = acc_R ** (0.25)
                    pre_L_inv = torch.diag((pre_L + eps) ** (-1.0))
                    pre_R_inv = torch.diag((pre_R + eps) ** (-1.0))
                    # The preconditioners act on the gradient average, not the raw gradient.
                    pre_grad = -step_size * (pre_L_inv @ exp_avg[outer_rows, :] @ pre_R_inv)

                    updated.append(self._retract(
                        p1.data[outer_rows, :], p2.data[inner_rows], p3.data[inner_rows, :],
                        pre_grad, pre_L, pre_R, dim0, m))

                (U_0, S_0, Vh_0), (U_1, S_1, Vh_1) = updated

                p1.data = torch.cat([U_0, U_1], dim=0)
                p2.data = torch.cat([S_0, S_1], dim=0)
                p3.data = torch.cat([Vh_0, Vh_1], dim=0)

                if group["weight_decay"] > 0.0:
                    p1.data.add_(p1.data, alpha=-group["lr"] * group["weight_decay"])
                    p3.data.add_(p3.data, alpha=-group["lr"] * group["weight_decay"])

        return loss
    