from torch.optim.optimizer import Optimizer
import torch
import math
import gc
import sys
import torch.nn as nn
import numpy as np


class AdamW(Optimizer):
    """Implements the Adam algorithm with decoupled weight decay (AdamW).

    Parameters:
        params (iterable): parameters to optimize, or dicts defining parameter groups.
        lr (float): learning rate. Default 1e-3.
        betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.98)
        eps (float): Adam's epsilon. Default: 1e-6
        weight_decay (float): decoupled weight decay. Default: 0.0
        correct_bias (bool): can be set to False to avoid correcting bias in Adam
            (e.g. like in the BERT TF repository). Default True.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.0, correct_bias=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def reset_state(self):
        """Reset the step counter and both moment estimates of every parameter."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state["exp_avg"] = torch.zeros_like(p.data)
                state["exp_avg_sq"] = torch.zeros_like(p.data)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter carries a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                # Lazy state initialization on the first step for this parameter.
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1

                # Decay the first and second moment running averages in place.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group.get("correct_bias", False):  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                # Tensor-first signature; the (Number, Tensor, Tensor) overload is deprecated.
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                # Instead the weights are decayed directly (decoupled weight decay),
                # after the Adam update.
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

        return loss
    
class AdamWr(Optimizer):
    """Adam with preconditioned gradients for paired LoRA factors.

    Parameters of each group are consumed in consecutive pairs ``(p1, p2)``:
    ``(params[0], params[1])``, ``(params[2], params[3])``, ...
    ``p1`` is split row-wise into ``p1[:rank]`` and ``p1[rank:]``. ``p2`` is split
    into its two halves of ``p2.shape[0] // 2`` rows. For each half ``k``:

    * p1's gradient block is left-multiplied by ``(B_k^T B_k + reg*I)^-1``;
    * p2's gradient block is right-multiplied by ``(A_k A_k^T + reg*I)^-1``;

    where ``A_k``/``B_k`` are the matching blocks of p1/p2. Both preconditioners
    are built from the factors *before* this step's update. If an inversion
    fails, the identity is used instead. The preconditioned gradients then go
    through a standard (optionally bias-corrected) Adam step with decoupled
    weight decay.
    NOTE(review): assumes params alternate A, B factors -- confirm against the caller.

    Parameters:
        lr, betas, eps, weight_decay, correct_bias: Adam hyper-parameters.
        rank (int): number of rows in each half of ``p1``.
        reg (float): ridge term added before every inversion.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.0, correct_bias=True, rank=2, reg=1e-6):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        print(f'{self.reg=}')

    def reset_state(self):
        """Reset the step counter and both moment estimates of every parameter."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state["exp_avg"] = torch.zeros_like(p.data)
                state["exp_avg_sq"] = torch.zeros_like(p.data)

    def _reg_inverse(self, mat):
        """Return ``(mat + reg*I)^-1``, or the identity if the inversion fails."""
        eye = torch.eye(mat.shape[0], device=mat.device, dtype=mat.dtype)
        try:
            return torch.inverse(mat + self.reg * eye)
        except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
            return eye

    def _adam_update(self, p, scaled_grad, group):
        """Apply one Adam step with decoupled weight decay to ``p``, driven by ``scaled_grad``."""
        state = self.state[p]
        if len(state) == 0:
            state["step"] = 0
            state["exp_avg"] = torch.zeros_like(p.data)
            state["exp_avg_sq"] = torch.zeros_like(p.data)
        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        beta1, beta2 = group["betas"]
        state["step"] += 1

        exp_avg.mul_(beta1).add_(scaled_grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(scaled_grad, scaled_grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])

        step_size = group["lr"]
        if group.get("correct_bias", False):  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

        p.data.addcdiv_(exp_avg, denom, value=-step_size)
        if group["weight_decay"] > 0.0:
            p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

    def step(self, closure=None):
        """Performs a single optimization step over every (p1, p2) pair.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The closure's loss, or None.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            params = group["params"]
            for p1, p2 in zip(params[0::2], params[1::2]):
                if p1.grad is None or p2.grad is None:
                    continue
                rank = self.rank
                dim_1 = p2.data.shape[0] // 2
                grad1, grad2 = p1.grad.data, p2.grad.data
                a0, a1 = p1.data[:rank, :], p1.data[rank:, :]
                b0, b1 = p2.data[:dim_1, :], p2.data[dim_1:, :]

                # Both scaled gradients are computed before either factor is
                # modified: a0/a1/b0/b1 are views, so updating p1 first would
                # otherwise leak the new A into p2's preconditioner.
                grad1_scaled = torch.cat([
                    self._reg_inverse(b0.T @ b0) @ grad1[:rank, :],
                    self._reg_inverse(b1.T @ b1) @ grad1[rank:, :],
                ])
                grad2_scaled = torch.cat([
                    grad2[:dim_1, :] @ self._reg_inverse(a0 @ a0.T),
                    grad2[dim_1:, :] @ self._reg_inverse(a1 @ a1.T),
                ])
                assert grad1_scaled.shape == grad1.shape
                assert grad2_scaled.shape == grad2.shape

                self._adam_update(p1, grad1_scaled, group)
                self._adam_update(p2, grad2_scaled, group)

        return loss
    
class LoRAPro_AdamW(Optimizer):
    """LoRA-Pro style AdamW over paired LoRA factors.

    Parameters of each group are consumed in consecutive pairs ``(p1, p2)``:
    ``p1`` is the A factor (r x n) and ``p2`` the B factor (m x r). Each factor is
    split into two halves that are processed independently: ``rank`` rows of A,
    and ``p2.shape[0] // 2`` rows of B. The LoRA gradients are mapped to an
    equivalent full-weight gradient, Adam moments are kept for it, and the Adam
    direction is projected back to updates of A and B.
    NOTE(review): assumes params alternate A, B -- confirm against the caller.

    Parameters:
        lr (float): learning rate. Default 1e-3.
        betas (tuple of 2 floats): Adam betas. Default (0.9, 0.999).
        eps (float): Adam epsilon. Default 1e-6.
        weight_decay (float): weight decay applied to the product B @ A.
        correct_bias (bool): whether to apply Adam bias correction. Default True.
        rank (int): number of rows in each half of the A factor.
        reg (float): ridge term added before every r x r inversion.
        alpha (float): LoRA alpha; ``self.scaling = alpha / rank``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True, rank=4, reg=1e-8, alpha=32):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        self.eps = eps
        self.scaling = alpha / rank
        print(f"self.reg = {self.reg}")
        print(f"scaling = {self.scaling}")

    def compute_eig(self, matrix):
        """Compute the (generally complex) eigen-decomposition of a square matrix.

        Uses ``torch.linalg.eig``; the legacy ``torch.eig`` has been removed from
        PyTorch. Conjugate pairs are returned directly, so no manual re-assembly of
        real/imaginary columns is needed.

        Args:
            matrix: square tensor of shape (n, n).

        Returns:
            eigenvalues: complex tensor of shape (n,).
            eigenvectors: complex tensor of shape (n, n); column i is the
                eigenvector for ``eigenvalues[i]``.
        """
        eigenvalues, eigenvectors = torch.linalg.eig(matrix)
        return eigenvalues, eigenvectors

    def solve_sylvester(self, A, B, C, X=None):
        """Solve ``A X + X B = C`` for ``X`` by eigen-decomposition (NumPy, on CPU).

        Based on https://stackoverflow.com/questions/73713072/solving-sylvester-equations-in-pytorch
        With ``A = U diag(R) U^-1`` and ``-B = V diag(S) V^-1``, substituting
        ``Y = U^-1 X V`` gives ``(R_i - S_j) Y_ij = F_ij`` with ``F = U^-1 C V``.

        Args:
            A: (n, n) tensor.
            B: (m, m) tensor.
            C: (n, m) tensor.
            X: unused; kept for backward compatibility.

        Returns:
            X on ``A.device``. It is real if none of A, B, C is complex.

        Raises:
            numpy.linalg.LinAlgError: if the decompositions fail even after
                adding a small diagonal jitter to ``A``.
        """
        neg_B = -B
        m = neg_B.shape[-1]
        n = A.shape[-1]
        try:
            R, U = np.linalg.eig(A.detach().cpu().numpy())
        except np.linalg.LinAlgError:
            # Retry once with a small diagonal jitter built on A's device/dtype.
            jitter = 1e-6 * torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
            R, U = np.linalg.eig((A + jitter).detach().cpu().numpy())

        S, V = np.linalg.eig(neg_B.detach().cpu().numpy())

        F = np.linalg.solve(U, (C.detach().cpu().numpy() + 0j) @ V)
        W = R[..., :, None] - S[..., None, :]
        Y = F / W
        X = U[..., :n, :n] @ Y[..., :n, :m] @ np.linalg.inv(V)[..., :m, :m]
        if any(torch.is_complex(t) for t in (A, B, C)):
            return torch.from_numpy(X).to(A.device)
        return torch.from_numpy(np.ascontiguousarray(X.real)).to(A.device)

    def step(self, closure=None):
        """Performs a single LoRA-Pro style optimization step.

        Adam moments for the equivalent full-weight gradient ``gradW`` are stored
        in ``self.state[p1]``. From the second step on, the A/B updates include the
        solution ``X`` of a Sylvester equation.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The closure's loss, or None.
        """
        loss = None
        if closure is not None:
            loss = closure()

        inv_s2 = 1 / self.scaling ** 2
        for group in self.param_groups:
            params = group['params']
            for p1, p2 in zip(params[0::2], params[1::2]):
                if p1.grad is None or p2.grad is None:
                    continue
                # B/p2: m x r, A/p1: r x n
                dim_1 = p2.data.shape[0] // 2
                MatrixA_0 = p1.data[0:self.rank, :]
                MatrixA_1 = p1.data[self.rank:, :]
                MatrixB_0 = p2.data[0:dim_1, :]
                MatrixB_1 = p2.data[dim_1:, :]
                grad_loraA_0 = p1.grad.data[0:self.rank, :]
                grad_loraA_1 = p1.grad.data[self.rank:, :]
                grad_loraB_0 = p2.grad.data[0:dim_1, :]
                grad_loraB_1 = p2.grad.data[dim_1:, :]
                eye = torch.eye(self.rank, device=p1.data.device, dtype=p1.data.dtype)

                beta1, beta2 = group["betas"]

                state = self.state[p1]
                if len(state) == 0:
                    state["step"] = 0
                state["step"] += 1
                first_step = state["step"] == 1

                # (AA^T)^-1 is needed on every step.
                AA_T_0_inv = torch.inverse(MatrixA_0 @ MatrixA_0.T + self.reg * eye)
                AA_T_1_inv = torch.inverse(MatrixA_1 @ MatrixA_1.T + self.reg * eye)
                if first_step:
                    # (B^TB)^-1 is not formed on the first step.
                    # NOTE(review): presumably because B starts at zero, which makes B^TB singular -- confirm.
                    gradA_0 = grad_loraA_0
                    gradA_1 = grad_loraA_1
                    gradB_0 = inv_s2 * grad_loraB_0 @ AA_T_0_inv
                    gradB_1 = inv_s2 * grad_loraB_1 @ AA_T_1_inv
                else:
                    B_TB_0_inv = torch.inverse(MatrixB_0.T @ MatrixB_0 + self.reg * eye)
                    B_TB_1_inv = torch.inverse(MatrixB_1.T @ MatrixB_1 + self.reg * eye)

                    # gradA = (B^TB)^-1 @ grad_loraA
                    # gradB = [I - B(B^TB)^-1 B^T] @ grad_loraB @ (AA^T)^-1
                    gradA_0 = inv_s2 * B_TB_0_inv @ grad_loraA_0
                    gradA_1 = inv_s2 * B_TB_1_inv @ grad_loraA_1

                    gradB_0 = inv_s2 * grad_loraB_0 @ AA_T_0_inv - inv_s2 * MatrixB_0 @ B_TB_0_inv @ MatrixB_0.T @ grad_loraB_0 @ AA_T_0_inv
                    gradB_1 = inv_s2 * grad_loraB_1 @ AA_T_1_inv - inv_s2 * MatrixB_1 @ B_TB_1_inv @ MatrixB_1.T @ grad_loraB_1 @ AA_T_1_inv

                # gradW = scaling * (gradB @ A + B @ gradA)
                gradW_0 = self.scaling * (gradB_0 @ MatrixA_0 + MatrixB_0 @ gradA_0)
                gradW_1 = self.scaling * (gradB_1 @ MatrixA_1 + MatrixB_1 @ gradA_1)
                gradW = torch.cat([gradW_0, gradW_1])

                # Adam moments of gradW (created on the first step).
                if first_step:
                    state["exp_avg"] = (1 - beta1) * gradW
                    state["exp_avg_sq"] = (1 - beta2) * gradW * gradW
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                else:
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                    exp_avg.mul_(beta1).add_(gradW, alpha=1.0 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(gradW, gradW, value=1.0 - beta2)

                # Without bias correction the divisors are 1 (previously these
                # names were left unbound and raised UnboundLocalError).
                bias_correction1 = 1.0
                bias_correction2 = 1.0
                if group.get("correct_bias", False):
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]

                # m_t / (1-beta1^t) / (sqrt[v_t / (1-beta2^t)] + eps)
                denom = torch.div(exp_avg_sq, bias_correction2).sqrt().add_(group["eps"])
                gradW_tilde = torch.div(exp_avg, bias_correction1 * denom)

                # Map the Adam direction back to the factors:
                # gradA_tilde = B.T @ gradW_tilde, gradB_tilde = gradW_tilde @ A.T
                gradA_tilde_0 = self.scaling * MatrixB_0.T @ gradW_tilde[0:dim_1, :]
                gradA_tilde_1 = self.scaling * MatrixB_1.T @ gradW_tilde[dim_1:, :]

                gradB_tilde_0 = self.scaling * gradW_tilde[0:dim_1, :] @ MatrixA_0.T
                gradB_tilde_1 = self.scaling * gradW_tilde[dim_1:, :] @ MatrixA_1.T

                if first_step:
                    delta_A_0 = gradA_tilde_0
                    delta_A_1 = gradA_tilde_1
                    delta_B_0 = inv_s2 * gradB_tilde_0 @ AA_T_0_inv
                    delta_B_1 = inv_s2 * gradB_tilde_1 @ AA_T_1_inv
                else:
                    # Solve (B^TB) X + X (AA^T) = -(1/s^2) (B^TB)^-1 gradA_tilde A^T
                    MatrixX_0 = self.solve_sylvester(MatrixB_0.T @ MatrixB_0, MatrixA_0 @ MatrixA_0.T, -inv_s2 * B_TB_0_inv @ gradA_tilde_0 @ MatrixA_0.T)
                    MatrixX_1 = self.solve_sylvester(MatrixB_1.T @ MatrixB_1, MatrixA_1 @ MatrixA_1.T, -inv_s2 * B_TB_1_inv @ gradA_tilde_1 @ MatrixA_1.T)

                    # delta_A = (B^TB)^-1 @ gradA_tilde + XA
                    delta_A_0 = inv_s2 * B_TB_0_inv @ gradA_tilde_0 + MatrixX_0 @ MatrixA_0
                    delta_A_1 = inv_s2 * B_TB_1_inv @ gradA_tilde_1 + MatrixX_1 @ MatrixA_1

                    # delta_B = [I - B(B^T B)^-1 B^T] gradB_tilde (AA^T)^-1 - BX
                    delta_B_0 = inv_s2 * gradB_tilde_0 @ AA_T_0_inv - \
                        inv_s2 * MatrixB_0 @ B_TB_0_inv @ MatrixB_0.T @ gradB_tilde_0 @ AA_T_0_inv - MatrixB_0 @ MatrixX_0
                    delta_B_1 = inv_s2 * gradB_tilde_1 @ AA_T_1_inv - \
                        inv_s2 * MatrixB_1 @ B_TB_1_inv @ MatrixB_1.T @ gradB_tilde_1 @ AA_T_1_inv - MatrixB_1 @ MatrixX_1

                delta_A = torch.cat([delta_A_0, delta_A_1])
                assert delta_A.shape == p1.grad.data.shape
                delta_B = torch.cat([delta_B_0, delta_B_1])
                assert delta_B.shape == p2.grad.data.shape

                # Each factor is scaled by sqrt(1 - lr*wd), so the product B @ A
                # is scaled by (1 - lr*wd). The scaling is done in place.
                if group["weight_decay"] > 0.0:
                    decay = math.sqrt(1 - group["lr"] * group["weight_decay"])
                    p1.data.mul_(decay)
                    p2.data.mul_(decay)

                p1.data.add_(delta_A, alpha=-group["lr"])
                p2.data.add_(delta_B, alpha=-group["lr"])

        return loss
    
class SoLoRA(Optimizer):
    """Preconditioned optimizer for paired LoRA factors driven by an external gradient.

    Parameters of each group are consumed in consecutive pairs ``(p1, p2)``. The
    row/column split follows the indexing below: ``p1`` is split at ``rank`` rows,
    and ``p2`` into two halves of ``p2.shape[0] // 2`` rows.

    The gradient is not read from ``p.grad``. Instead, the param group must carry
    ``'param_to_module'``, a dict mapping ``id(p1)`` to an object exposing
    ``W_new_grad`` and ``lora_ind``. The gradient used is
    ``W_new_grad.T[lora_ind]``. Pairs without such an entry are skipped (with a
    printed message).

    State per pair (stored under ``p1``): a first-moment EMA of that gradient, plus
    diagonal second-moment EMAs of its row (``pre_L_*``) and column (``pre_R_*``)
    sums of squares.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.0, correct_bias=True, rank=4, reg=1e-6):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg

    def _reg_inverse(self, mat, fallback):
        """Return ``(mat + reg*I)^-1``; if inversion fails, return ``diag(fallback[:rank])``."""
        eye = torch.eye(self.rank, device=mat.device, dtype=mat.dtype)
        try:
            return torch.inverse(mat + self.reg * eye)
        except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
            return torch.diag(fallback[0:self.rank])

    def step(self, closure=None):
        """Performs a single optimization step over every (p1, p2) pair.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The closure's loss, or None.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            param_to_module = group.get('param_to_module', {})
            params = group['params']
            for p1, p2 in zip(params[0::2], params[1::2]):
                # The externally supplied gradient is required; skip the pair
                # without it (this previously printed and then crashed).
                module = param_to_module.get(id(p1))
                if module is None or not hasattr(module, 'W_new_grad'):
                    print("No parent module found for parameters")
                    continue
                if module.W_new_grad is None:
                    print(f"W_new_grad is None for {module}")
                    continue

                dim_1 = p2.data.shape[0] // 2
                # Snapshot A: p1 is updated in place before B's update below, and
                # B's update must use the same A as the ARA^T inverses.
                MatrixA_0 = p1.data[0:self.rank, :].clone()
                MatrixA_1 = p1.data[self.rank:, :].clone()
                MatrixB_0 = p2.data[0:dim_1, :]
                MatrixB_1 = p2.data[dim_1:, :]

                G_W = module.W_new_grad.to(p1.data.device).T[module.lora_ind]
                G_W_0 = G_W[0:dim_1, :]
                G_W_1 = G_W[dim_1:, :]

                state = self.state[p1]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(G_W)
                    like_p = dict(device=p1.data.device, dtype=p1.data.dtype)
                    state["pre_L_0"] = torch.zeros(dim_1, **like_p)
                    state["pre_L_1"] = torch.zeros(dim_1, **like_p)
                    state["pre_R_0"] = torch.zeros(MatrixA_0.shape[1], **like_p)
                    state["pre_R_1"] = torch.zeros(MatrixA_1.shape[1], **like_p)

                exp_avg = state["exp_avg"]
                beta1, beta2 = group["betas"]
                state["step"] += 1

                pre_L_0 = state["pre_L_0"]
                pre_L_1 = state["pre_L_1"]
                pre_R_0 = state["pre_R_0"]
                pre_R_1 = state["pre_R_1"]

                exp_avg.mul_(beta1).add_(G_W, alpha=1.0 - beta1)

                # EMAs of per-row (L) and per-column (R) sums of squared gradients.
                pre_L_0.mul_(beta2).add_(torch.einsum('ij,ij->i', G_W_0, G_W_0), alpha=1 - beta2)
                pre_L_1.mul_(beta2).add_(torch.einsum('ij,ij->i', G_W_1, G_W_1), alpha=1 - beta2)
                pre_R_0.mul_(beta2).add_(torch.einsum('ij,ij->j', G_W_0, G_W_0), alpha=1 - beta2)
                pre_R_1.mul_(beta2).add_(torch.einsum('ij,ij->j', G_W_1, G_W_1), alpha=1 - beta2)

                traceL_0 = torch.sum(pre_L_0)
                traceL_1 = torch.sum(pre_L_1)

                B_TLB_0_inv = self._reg_inverse(MatrixB_0.T @ torch.diag(pre_L_0 ** 0.5) @ MatrixB_0, pre_L_0 ** (-0.5))
                B_TLB_1_inv = self._reg_inverse(MatrixB_1.T @ torch.diag(pre_L_1 ** 0.5) @ MatrixB_1, pre_L_1 ** (-0.5))
                ARA_T_0_inv = self._reg_inverse(MatrixA_0 @ torch.diag(pre_R_0 ** 0.5) @ MatrixA_0.T, pre_R_0 ** (-0.5))
                ARA_T_1_inv = self._reg_inverse(MatrixA_1 @ torch.diag(pre_R_1 ** 0.5) @ MatrixA_1.T, pre_R_1 ** (-0.5))

                # Update A/p1:
                # delta_A = part @ diag(pre_R^-0.5) - 0.5 * part @ A^T (A diag(pre_R^0.5) A^T)^-1 A
                # NOTE(review): earlier comments quoted exponents -0.25/0.25; the code uses -0.5/0.5 -- confirm.
                delta_A_part_0 = B_TLB_0_inv @ MatrixB_0.T @ exp_avg[0:dim_1, :] * (traceL_0 ** 0.5)
                delta_A_part_1 = B_TLB_1_inv @ MatrixB_1.T @ exp_avg[dim_1:, :] * (traceL_1 ** 0.5)

                delta_A_0 = delta_A_part_0 @ torch.diag(pre_R_0 ** (-0.5)) - 0.5 * delta_A_part_0 @ MatrixA_0.T @ ARA_T_0_inv @ MatrixA_0
                delta_A_1 = delta_A_part_1 @ torch.diag(pre_R_1 ** (-0.5)) - 0.5 * delta_A_part_1 @ MatrixA_1.T @ ARA_T_1_inv @ MatrixA_1

                delta_A = torch.cat([delta_A_0, delta_A_1])
                assert delta_A.shape == p1.data.shape

                step_size = group["lr"]
                if group.get("correct_bias", False):  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p1.data.add_(delta_A, alpha=-step_size)
                if group["weight_decay"] > 0.0:
                    p1.data.add_(p1.data, alpha=-group["lr"] * group["weight_decay"])

                # Update B/p2 (uses the pre-update A snapshot):
                # delta_B = diag(pre_L^-0.5) @ part - 0.5 * B (B^T diag(pre_L^0.5) B)^-1 B^T @ part
                delta_B_part_0 = exp_avg[0:dim_1, :] @ MatrixA_0.T @ ARA_T_0_inv * (traceL_0 ** 0.5)
                delta_B_part_1 = exp_avg[dim_1:, :] @ MatrixA_1.T @ ARA_T_1_inv * (traceL_1 ** 0.5)

                delta_B_0 = torch.diag(pre_L_0 ** (-0.5)) @ delta_B_part_0 - 0.5 * MatrixB_0 @ B_TLB_0_inv @ MatrixB_0.T @ delta_B_part_0
                delta_B_1 = torch.diag(pre_L_1 ** (-0.5)) @ delta_B_part_1 - 0.5 * MatrixB_1 @ B_TLB_1_inv @ MatrixB_1.T @ delta_B_part_1

                delta_B = torch.cat([delta_B_0, delta_B_1])
                assert delta_B.shape == p2.data.shape

                p2.data.add_(delta_B, alpha=-step_size)
                if group["weight_decay"] > 0.0:
                    p2.data.add_(p2.data, alpha=-group["lr"] * group["weight_decay"])

        return loss

class SGDr(Optimizer):
    """SGD with preconditioned gradients for paired LoRA factors.

    Parameters of each group are consumed in consecutive pairs ``(p1, p2)``.
    ``p1`` is split at ``rank`` rows; ``p2`` into two halves of
    ``p2.shape[0] // 2`` rows. For each half ``k``:

    * p1's gradient block is left-multiplied by ``(B_k^T B_k + reg*I)^-1``;
    * p2's gradient block is right-multiplied by ``(A_k A_k^T + reg*I)^-1``.

    Both preconditioners use the factors before this step's update. If an
    inversion fails, the gradient block is used unscaled.

    Parameters:
        lr (float): learning rate.
        weight_decay (float): decay applied after the gradient step.
        rank (int): number of rows in each half of ``p1``.
        reg (float): ridge term added before every inversion.
    """

    def __init__(self, params, lr, weight_decay, rank=4, reg=1e-6):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        print(f'{self.reg=}')

    def _reg_inverse(self, mat):
        """Return ``(mat + reg*I)^-1``, or the identity if the inversion fails."""
        eye = torch.eye(mat.shape[0], device=mat.device, dtype=mat.dtype)
        try:
            return torch.inverse(mat + self.reg * eye)
        except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
            return eye

    def step(self, closure=None):
        """Performs a single optimization step over every (p1, p2) pair.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns the
                loss. Accepted for compatibility with the ``Optimizer.step`` API.

        Returns:
            The closure's loss, or None.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            params = group["params"]
            for p1, p2 in zip(params[0::2], params[1::2]):
                if p1.grad is None or p2.grad is None:
                    continue
                rank = self.rank
                dim_1 = p2.data.shape[0] // 2
                grad1, grad2 = p1.grad.data, p2.grad.data
                a0, a1 = p1.data[:rank, :], p1.data[rank:, :]
                b0, b1 = p2.data[:dim_1, :], p2.data[dim_1:, :]

                # Both scaled gradients are formed before either factor is updated.
                grad1_scaled = torch.cat([
                    self._reg_inverse(b0.T @ b0) @ grad1[:rank, :],
                    self._reg_inverse(b1.T @ b1) @ grad1[rank:, :],
                ])
                grad2_scaled = torch.cat([
                    grad2[:dim_1, :] @ self._reg_inverse(a0 @ a0.T),
                    grad2[dim_1:, :] @ self._reg_inverse(a1 @ a1.T),
                ])

                p1.data.add_(grad1_scaled, alpha=-group['lr'])
                p2.data.add_(grad2_scaled, alpha=-group['lr'])

                if group["weight_decay"] > 0.0:
                    p1.data.add_(p1.data, alpha=-group["lr"] * group["weight_decay"])
                    p2.data.add_(p2.data, alpha=-group["lr"] * group["weight_decay"])
        return loss

class LoRAPro_SGD(Optimizer):
    """LoRA-Pro style SGD for stacked pairs of LoRA factors.

    Parameters of each group are consumed pairwise as ``(p1, p2)``
    (``params[0], params[1]``, then ``params[2], params[3]``, ...).
    ``p1`` stacks two ``A`` factors (rows ``[0, rank)`` and ``[rank, 2*rank)``,
    each ``rank x n``) and ``p2`` stacks two ``B`` factors (rows split at
    ``p2.shape[0] // 2``, each ``m x rank``). Each (A, B) half is updated
    independently.

    Args:
        params: iterable of parameters or param groups, ordered A, B, A, B, ...
        lr (float): learning rate; must be >= 0.
        weight_decay (float): each factor is scaled by
            ``sqrt(1 - lr * weight_decay)`` before its update, so the decay
            part of the product ``B @ A`` is scaled by ``1 - lr * weight_decay``.
        rank (int): LoRA rank of each half.
        reg (float): ridge term added to the ``rank x rank`` Gram matrices
            before inversion.
        alpha: accepted but currently unused; ``self.scaling`` is fixed to 1.
    """

    def __init__(self, params, lr=1e-3, weight_decay=0.0, rank=4, reg=1e-6, alpha=32):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        self.scaling = 1
        print(f"scaling = {self.scaling}")

    def compute_eig(self, matrix):
        """Compute the (complex) eigen-decomposition of a square matrix.

        Uses ``torch.linalg.eig``, which already returns complex eigenvalues
        and complex eigenvectors (one per column). This replaces the removed
        ``torch.eig`` and its manual reconstruction of conjugate pairs.

        Args:
            matrix: square ``(n, n)`` tensor.

        Returns:
            eigenvalues: complex tensor of shape ``(n,)``.
            eigenvectors: complex tensor of shape ``(n, n)``; column ``i``
                is the eigenvector for ``eigenvalues[i]``.
        """
        eigenvalues, eigenvectors = torch.linalg.eig(matrix)
        return eigenvalues, eigenvectors

    def solve_sylvester(self, A, B, C, X=None):
        """Solve ``A @ X + X @ B = C`` for ``X`` via eigen-decompositions.

        Adapted from
        https://stackoverflow.com/questions/73713072/solving-sylvester-equations-in-pytorch
        The decompositions run in NumPy on the CPU; the result is moved back
        to ``A.device``.

        Args:
            A: ``(n, n)`` tensor.
            B: ``(m, m)`` tensor.
            C: ``(n, m)`` tensor.
            X: unused; kept for backward compatibility.

        Returns:
            ``(n, m)`` tensor on ``A.device``. If none of ``A``, ``B``, ``C``
            is complex, the real part is returned in ``A``'s dtype; otherwise
            the complex solution is returned.
        """
        B = -B
        m = B.shape[-1]
        n = A.shape[-1]
        try:
            R, U = np.linalg.eig(A.cpu().numpy())
        except np.linalg.LinAlgError:
            print(A)
            # Retry with a small diagonal shift built on A's own device/dtype
            # (a CPU identity would not broadcast against a CUDA tensor).
            shifted = A + 1e-6 * torch.eye(n, device=A.device, dtype=A.dtype)
            R, U = np.linalg.eig(shifted.cpu().numpy())

        S, V = np.linalg.eig(B.cpu().numpy())

        # With A = U diag(R) U^-1 and -B = V diag(S) V^-1, Y = U^-1 X V
        # satisfies Y_ij * (R_i - S_j) = F_ij where F = U^-1 C V.
        F = np.linalg.solve(U, (C.cpu().numpy() + 0j) @ V)
        W = R[..., :, None] - S[..., None, :]
        Y = F / W
        X = U[..., :n, :n] @ Y[..., :n, :m] @ np.linalg.inv(V)[..., :m, :m]
        if any(torch.is_complex(t) for t in (A, B, C)):
            return torch.from_numpy(X).to(A.device)
        return torch.from_numpy(np.ascontiguousarray(X.real)).to(device=A.device, dtype=A.dtype)

    def _half_deltas(self, A, B, grad_A, grad_B, first_step):
        """Return ``(delta_A, delta_B)`` for one LoRA half (``A``: r x n, ``B``: m x r)."""
        inv_s2 = 1 / self.scaling ** 2
        eye = torch.eye(self.rank).to(A.device)
        AA_T = A @ A.T
        AA_T_inv = torch.inverse(AA_T + self.reg * eye)
        if first_step:
            # First step: A follows its raw gradient, B is preconditioned by (AA^T)^-1.
            return grad_A, inv_s2 * grad_B @ AA_T_inv

        B_TB = B.T @ B
        B_TB_inv = torch.inverse(B_TB + self.reg * eye)
        # X solves  B^TB X + X AA^T = -(1/s^2) (B^TB)^-1 gradA A^T
        X = self.solve_sylvester(B_TB, AA_T, -inv_s2 * B_TB_inv @ grad_A @ A.T)

        # delta_A = (B^TB)^-1 @ gradA / s^2 + X A
        delta_A = inv_s2 * B_TB_inv @ grad_A + X @ A
        # delta_B = [I - B(B^TB)^-1 B^T] gradB (AA^T)^-1 / s^2 - B X
        delta_B = (inv_s2 * grad_B @ AA_T_inv
                   - inv_s2 * B @ B_TB_inv @ B.T @ grad_B @ AA_T_inv
                   - B @ X)
        return delta_A, delta_B

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one LoRA-Pro SGD step on every stacked (A, B) pair.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
                # B/p2: m x r, A/p1: r x n (each stacked twice)
                if p1.grad is None or p2.grad is None:
                    continue
                r = self.rank
                dim_1 = p2.data.shape[0] // 2

                state = self.state[p1]
                if len(state) == 0:
                    state["step"] = 0
                state["step"] += 1
                first_step = state["step"] == 1

                delta_A_0, delta_B_0 = self._half_deltas(
                    p1.data[:r], p2.data[:dim_1],
                    p1.grad.data[:r], p2.grad.data[:dim_1], first_step)
                delta_A_1, delta_B_1 = self._half_deltas(
                    p1.data[r:], p2.data[dim_1:],
                    p1.grad.data[r:], p2.grad.data[dim_1:], first_step)

                delta_A = torch.cat([delta_A_0, delta_A_1])
                delta_B = torch.cat([delta_B_0, delta_B_1])
                assert delta_A.shape == p1.grad.data.shape
                assert delta_B.shape == p2.grad.data.shape

                if weight_decay > 0.0:
                    # sqrt so that the product B @ A decays by (1 - lr * wd);
                    # scale in place rather than rebinding .data.
                    shrink = math.sqrt(1 - lr * weight_decay)
                    p1.data.mul_(shrink)
                    p2.data.mul_(shrink)

                p1.data.add_(delta_A, alpha=-lr)
                p2.data.add_(delta_B, alpha=-lr)

        return loss
    
class SoLoRA_SGD(Optimizer):
    """SGD for stacked LoRA factor pairs, preconditioned by full-weight gradient statistics.

    Parameters of each group are consumed pairwise as ``(p1, p2)``
    (``params[0], params[1]``, then ``params[2], params[3]``, ...).
    ``p1`` is split by rows at ``rank`` into two ``A`` halves and ``p2`` is
    split by rows at ``p2.shape[0] // 2`` into two ``B`` halves.

    The preconditioners are exponential moving averages of the row-wise
    (``pre_L``) and column-wise (``pre_R``) squared norms of a full-weight
    gradient ``G_W``. ``G_W`` is read from a module found through the param
    group's ``'param_to_module'`` dict (keyed by ``id(p1)``); the module must
    expose ``W_new_grad`` and ``lora_ind``, and ``G_W = W_new_grad.T[lora_ind]``.

    Args:
        params: iterable of parameters or param groups, ordered A, B, A, B, ...
        lr (float): learning rate.
        weight_decay (float): applied after the update as ``p <- p - lr * wd * p``.
        rank (int): LoRA rank of each half.
        precond_beta (float): EMA coefficient for the preconditioner statistics.
    """

    def __init__(self, params, lr, weight_decay, rank=4, precond_beta=0.98):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.rank = rank
        self.precond_beta = precond_beta

    def compute_inv_PSD(self, mat):
        """Invert a square matrix, falling back to identity if it has an all-zero row.

        Intended for positive semidefinite Gram-like matrices: an all-zero row
        (e.g. from an all-zero factor) makes the matrix singular, in which case
        a warning is printed and the identity (same dtype/device) is returned.
        Other singular matrices still make ``torch.inverse`` raise.
        """
        zero_rows = torch.all(mat == 0, dim=1).any()
        if not zero_rows:
            return torch.inverse(mat)
        print(f"Warning: matrix {mat.shape} is not invertible, return identity matrix")
        return torch.eye(mat.shape[0], dtype=mat.dtype, device=mat.device)

    @staticmethod
    def _full_weight_grad(param_to_module, p):
        """Return ``W_new_grad.T[lora_ind]`` from the module registered for ``p``.

        Raises:
            RuntimeError: if no module is registered for ``p`` or its
                ``W_new_grad`` is missing/None (previously this surfaced as an
                unrelated ``UnboundLocalError``/``AttributeError``).
        """
        module = param_to_module.get(id(p))
        if module is None or not hasattr(module, "W_new_grad"):
            raise RuntimeError(
                "SoLoRA_SGD: no parent module with 'W_new_grad' found for parameters; "
                "provide it via the param group's 'param_to_module' dict keyed by id(param)"
            )
        if module.W_new_grad is None:
            raise RuntimeError(f"SoLoRA_SGD: W_new_grad is None for {module}")
        return module.W_new_grad.to(p.data.device).T[module.lora_ind]

    def _half_update(self, A, B, grad_A, grad_B, G, pre_L, pre_R):
        """Update the EMA statistics in place and return ``(delta_A, delta_B)`` for one half.

        ``A``/``grad_A`` are ``r x n``, ``B``/``grad_B`` are ``m x r``, ``G`` is
        ``m x n``, ``pre_L`` has length ``m`` and ``pre_R`` has length ``n``.
        Diagonal scalings are applied by broadcasting, which is exactly
        equivalent to multiplying by ``torch.diag(...)`` without building the
        dense ``m x m`` / ``n x n`` matrices.
        """
        beta = self.precond_beta
        pre_L.mul_(beta).add_(torch.einsum('ij,ij->i', G, G), alpha=1.0 - beta)
        pre_R.mul_(beta).add_(torch.einsum('ij,ij->j', G, G), alpha=1.0 - beta)

        trace_sqrt = torch.sum(pre_L) ** 0.5

        # (B^T diag(pre_L^{1/2}) B)^{-1} and (A diag(pre_R^{1/2}) A^T)^{-1}
        B_TLB_inv = self.compute_inv_PSD((B.T * pre_L ** 0.5) @ B)
        ARA_T_inv = torch.inverse((A * pre_R ** 0.5) @ A.T)

        # delta_A = P_A diag(pre_R^{-1/2}) - 0.5 * P_A A^T (A diag(pre_R^{1/2}) A^T)^{-1} A
        delta_A_part = B_TLB_inv @ grad_A * trace_sqrt
        delta_A = delta_A_part * pre_R ** (-0.5) - 0.5 * delta_A_part @ A.T @ ARA_T_inv @ A

        # delta_B = diag(pre_L^{-1/2}) P_B - 0.5 * B (B^T diag(pre_L^{1/2}) B)^{-1} B^T P_B
        delta_B_part = grad_B @ ARA_T_inv * trace_sqrt
        delta_B = (pre_L ** (-0.5)).unsqueeze(1) * delta_B_part \
            - 0.5 * B @ B_TLB_inv @ B.T @ delta_B_part
        return delta_A, delta_B

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one preconditioned SGD step on every stacked (A, B) pair.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The closure's loss, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if the full-weight gradient for a pair is unavailable.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            param_to_module = group.get('param_to_module', {})
            lr = group['lr']
            weight_decay = group['weight_decay']
            for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
                if p1.grad is None or p2.grad is None:
                    continue
                r = self.rank
                dim_1 = p2.data.shape[0] // 2
                device = p1.data.device

                G_W = self._full_weight_grad(param_to_module, p1)

                state = self.state[p1]
                if len(state) == 0:
                    n_cols = p1.data.shape[1]
                    state["pre_L_0"] = torch.zeros(dim_1, device=device)
                    state["pre_L_1"] = torch.zeros(dim_1, device=device)
                    state["pre_R_0"] = torch.zeros(n_cols, device=device)
                    state["pre_R_1"] = torch.zeros(n_cols, device=device)

                delta_A_0, delta_B_0 = self._half_update(
                    p1.data[:r], p2.data[:dim_1], p1.grad.data[:r], p2.grad.data[:dim_1],
                    G_W[:dim_1], state["pre_L_0"], state["pre_R_0"])
                delta_A_1, delta_B_1 = self._half_update(
                    p1.data[r:], p2.data[dim_1:], p1.grad.data[r:], p2.grad.data[dim_1:],
                    G_W[dim_1:], state["pre_L_1"], state["pre_R_1"])

                delta_A = torch.cat([delta_A_0, delta_A_1])
                delta_B = torch.cat([delta_B_0, delta_B_1])
                assert delta_A.shape == p1.grad.data.shape
                assert delta_B.shape == p2.grad.data.shape

                p1.data.add_(delta_A, alpha=-lr)
                p2.data.add_(delta_B, alpha=-lr)

                if weight_decay > 0.0:
                    p1.data.add_(p1.data, alpha=-lr * weight_decay)
                    p2.data.add_(p2.data, alpha=-lr * weight_decay)

        return loss
