from torch.optim.optimizer import Optimizer, required
import torch
import math
import copy
import time
    
class SoLoRA_SGD(Optimizer):
    """SGD-style optimizer for LoRA factor pairs with diagonal preconditioners.

    Each param group is consumed pairwise: entries 0/1, 2/3, ... are treated
    as ``(A, B)`` with ``A`` of shape ``r x n`` and ``B`` of shape ``m x r``.
    Before every ``step`` the caller must store an ``m x n`` tensor in
    ``self.external_grads`` under the key ``id(B)`` (presumably the gradient
    with respect to the full weight -- confirm against the training loop).

    Args:
        params: Iterable of parameters (or param-group dicts), ordered A, B, A, B, ...
        lr: Learning rate.
        weight_decay: Decoupled weight-decay coefficient (applied after the update).
        rank: Stored on the instance; not read by ``step``.
        betas: Only ``betas[1]`` is used, as the decay rate of the preconditioners.
    """

    def __init__(self, params, lr, weight_decay, rank=4, betas=(0.9, 0.98)):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.rank = rank
        self.precond_beta = betas[1]
        # Filled in by the training loop: {id(B): m x n tensor}.
        self.external_grads = {}
        self.reg = 1e-6

    def compute_inv_PSD(self, mat):
        """Invert ``mat``, falling back to the identity if it has an all-zero row.

        Args:
            mat: Square 2-D tensor.

        Returns:
            ``torch.inverse(mat)``, or an identity matrix of the same size,
            device and dtype when any row of ``mat`` is entirely zero.
        """
        has_zero_row = torch.all(mat == 0, dim=1).any()
        if not has_zero_row:
            return torch.inverse(mat)
        print(f"Warning: matrix {mat.shape} is not invertible, return identity matrix")
        return torch.eye(mat.shape[0], device=mat.device, dtype=mat.dtype)

    def step(self, closure=None):
        """Perform one optimization step over all (A, B) pairs.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: If no tensor was registered in ``external_grads``
                for a pair that has gradients.
        """
        loss = None
        if closure is not None:
            loss = closure()

        beta = self.precond_beta
        for group in self.param_groups:
            params = group["params"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            # Consecutive, non-overlapping pairs: (0, 1), (2, 3), ...
            for p1, p2 in zip(params[::2], params[1::2]):
                # p1: matrix A, r x n, p2: matrix B, m x r
                if p1.grad is None or p2.grad is None:
                    continue
                gradA = p1.grad.data
                gradB = p2.grad.data
                MatrixA = p1.data
                MatrixB = p2.data

                # .get() so that a missing entry reaches the explicit error
                # below instead of surfacing as a KeyError.
                G_W = self.external_grads.get(id(p2))
                if G_W is None:
                    raise RuntimeError("Missing gradient: G_W is None.")

                state = self.state[p1]
                if len(state) == 0:
                    # Start from zeros; the clamps below keep every entry >= 1e-6.
                    state["pre_L"] = torch.zeros(MatrixB.shape[0], device=MatrixA.device)
                    state["pre_R"] = torch.zeros(MatrixA.shape[1], device=MatrixA.device)

                pre_L = state["pre_L"]
                pre_R = state["pre_R"]

                # Row-wise / column-wise sums of squares of G_W.
                pre_L_update = torch.einsum('ij,ij->i', G_W, G_W)
                pre_R_update = torch.einsum('ij,ij->j', G_W, G_W)

                # Clip extreme values to prevent inf
                pre_L_update = torch.clamp(pre_L_update, min=1e-6, max=1e6)
                pre_R_update = torch.clamp(pre_R_update, min=1e-6, max=1e6)

                # Exponential moving averages of the preconditioners.
                pre_L.mul_(beta).add_(pre_L_update, alpha=1.0 - beta)
                pre_R.mul_(beta).add_(pre_R_update, alpha=1.0 - beta)

                # Clamp pre_L and pre_R to prevent inf/nan
                pre_L.clamp_(min=1e-6, max=1e6)
                pre_R.clamp_(min=1e-6, max=1e6)

                traceL = torch.sum(pre_L)
                traceL.clamp_(min=1e-6, max=1e6)
                trace_root = traceL ** 0.5

                # (B^T diag(pre_L^{1/2}) B)^{-1} and (A diag(pre_R^{1/2}) A^T)^{-1}
                B_TLB_inv = self.compute_inv_PSD(MatrixB.T @ torch.diag(pre_L ** 0.5) @ MatrixB)
                ARA_T_inv = torch.inverse(MatrixA @ torch.diag(pre_R ** 0.5) @ MatrixA.T)

                # Update direction for A/p1:
                # part @ diag(pre_R^{-1/2}) - 0.5 * part @ A^T @ ARA_T_inv @ A
                delta_A_part = B_TLB_inv @ gradA * trace_root
                delta_A = (delta_A_part @ torch.diag(pre_R ** (-0.5))
                           - 0.5 * delta_A_part @ MatrixA.T @ ARA_T_inv @ MatrixA)

                assert delta_A.shape == p1.grad.data.shape

                # Update direction for B/p2:
                # diag(pre_L^{-1/2}) @ part - 0.5 * B @ B_TLB_inv @ B^T @ part
                delta_B_part = gradB @ ARA_T_inv * trace_root
                delta_B = (torch.diag(pre_L ** (-0.5)) @ delta_B_part
                           - 0.5 * MatrixB @ B_TLB_inv @ MatrixB.T @ delta_B_part)

                assert delta_B.shape == p2.grad.data.shape

                # Both directions were computed from the pre-update factors.
                p1.data.add_(delta_A, alpha=-lr)
                p2.data.add_(delta_B, alpha=-lr)

                if weight_decay > 0.0:
                    p1.data.add_(p1.data, alpha=-lr * weight_decay)
                    p2.data.add_(p2.data, alpha=-lr * weight_decay)

        return loss

class LoRAPro_SGD(Optimizer):
    """SGD variant that updates LoRA factor pairs via a Sylvester-equation correction.

    Each param group is consumed pairwise: entries 0/1, 2/3, ... are treated
    as ``(A, B)`` with ``A`` of shape ``r x n`` and ``B`` of shape ``m x r``.

    Args:
        params: Iterable of parameters (or param-group dicts), ordered A, B, A, B, ...
        lr: Learning rate; must be non-negative.
        weight_decay: Weight-decay coefficient.
        rank: Stored on the instance; the regulariser size is taken from the
            factors themselves.
        reg: Multiple of the identity added before inverting ``A A^T`` / ``B^T B``.
        alpha: Accepted but not used by this class (``scaling`` is fixed to 1).
        betas: Accepted but not used by this class.

    Raises:
        ValueError: If ``lr`` is negative.
    """

    def __init__(self, params, lr=1e-3, weight_decay=0.0, rank=4, reg=1e-6, alpha=4, betas=(0.9, 0.98)):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        self.scaling = 1
        print(f"scaling = {self.scaling}")

    def _regularized_inverse(self, gram):
        """Return ``(gram + reg * I)^{-1}`` with ``I`` sized from ``gram`` itself."""
        eye = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
        return torch.inverse(gram + self.reg * eye)

    def solve_sylvester(self, A, B, C, X=None):
        """Solve ``A X + X B = C`` for ``X`` via eigendecompositions of ``A`` and ``-B``.

        From the answer here:
        https://stackoverflow.com/questions/73713072/solving-sylvester-equations-in-pytorch

        Args:
            A: Square matrix multiplying ``X`` from the left.
            B: Square matrix multiplying ``X`` from the right.
            C: Right-hand side.
            X: Unused; kept for interface compatibility.

        Returns:
            The solution ``X``; its real part when ``A``, ``B`` and ``C`` are real.
        """
        B = -B
        m = B.shape[-1]
        n = A.shape[-1]
        try:
            R, U = torch.linalg.eig(A)
        except RuntimeError:  # torch.linalg.LinAlgError derives from RuntimeError
            print(A)
            # Retry on a slightly shifted matrix, built on A's device/dtype.
            jitter = 1e-6 * torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
            R, U = torch.linalg.eig(A + jitter)

        S, V = torch.linalg.eig(B)

        # In the eigenbases the equation decouples: Y_ij = F_ij / (R_i - S_j).
        F = torch.linalg.solve(U, (C + 0j) @ V)
        W = R[..., :, None] - S[..., None, :]
        Y = F / W
        X = U[..., :n, :n] @ Y[..., :n, :m] @ torch.linalg.inv(V)[..., :m, :m]
        inputs_real = all(torch.isreal(x.flatten()[0]) for x in [A, B, C])
        return X.real if inputs_real else X

    def step(self, closure=None):
        """Perform one optimization step over all (A, B) pairs.

        Pairs in which either factor has no gradient are skipped.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        inv_scale_sq = 1 / self.scaling ** 2
        for group in self.param_groups:
            params = group["params"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            # Consecutive, non-overlapping pairs: (0, 1), (2, 3), ...
            for p1, p2 in zip(params[::2], params[1::2]):
                # B/p2: m x r, A/p1: r x n
                if p1.grad is None or p2.grad is None:
                    continue
                MatrixA = p1.data
                MatrixB = p2.data
                grad_loraA = p1.grad.data
                grad_loraB = p2.grad.data

                state = self.state[p1]
                if len(state) == 0:
                    state["step"] = 0
                state["step"] += 1

                AA_T_inv = self._regularized_inverse(MatrixA @ MatrixA.T)
                if state["step"] == 1:
                    # First step: no B-side preconditioning and no Sylvester term.
                    delta_A = grad_loraA
                    delta_B = inv_scale_sq * grad_loraB @ AA_T_inv
                else:
                    B_TB_inv = self._regularized_inverse(MatrixB.T @ MatrixB)

                    # solve (B^T B) X + X (A A^T) = -(1/s^2) (B^T B)^-1 gradA A^T
                    MatrixX = self.solve_sylvester(
                        MatrixB.T @ MatrixB,
                        MatrixA @ MatrixA.T,
                        -inv_scale_sq * B_TB_inv @ grad_loraA @ MatrixA.T,
                    )

                    # delta_A = (B^T B)^-1 gradA + X A
                    delta_A = inv_scale_sq * B_TB_inv @ grad_loraA + MatrixX @ MatrixA

                    # delta_B = [I - B (B^T B)^-1 B^T] gradB (A A^T)^-1 - B X
                    delta_B = (inv_scale_sq * grad_loraB @ AA_T_inv
                               - inv_scale_sq * MatrixB @ B_TB_inv @ MatrixB.T @ grad_loraB @ AA_T_inv
                               - MatrixB @ MatrixX)

                assert delta_A.shape == p1.grad.data.shape
                assert delta_B.shape == p2.grad.data.shape

                # Weight decay shrinks each factor by sqrt(1 - lr * wd) before its update.
                if weight_decay > 0.0:
                    p1.data = p1.data * math.sqrt(1 - lr * weight_decay)
                p1.data.add_(delta_A, alpha=-lr)

                if weight_decay > 0.0:
                    p2.data = p2.data * math.sqrt(1 - lr * weight_decay)
                p2.data.add_(delta_B, alpha=-lr)

        return loss
    
class SGDr(Optimizer):
    """SGD on factor pairs where each gradient is rescaled by the other factor's Gram inverse.

    Each param group is consumed pairwise: entries 0/1, 2/3, ... form
    ``(p1, p2)``. The gradient of ``p1`` is left-multiplied by
    ``(p2^T p2 + reg * I)^{-1}`` and the gradient of ``p2`` is
    right-multiplied by ``(p1 p1^T + reg * I)^{-1}``.

    Args:
        params: Iterable of parameters (or param-group dicts), ordered in pairs.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient (applied before the gradient step).
        betas, eps, correct_bias: Accepted but not used by this class.
        reg: Multiple of the identity added before inversion.
    """

    def __init__(self, params, lr, weight_decay, betas=(0.9, 0.98), eps=1e-6, correct_bias=True, reg=0):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.reg = reg
        print(f'{self.reg=}')

    def _gram_inverse(self, gram):
        """Return ``(gram + reg * I)^{-1}``, or ``None`` if the inversion fails."""
        eye = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
        try:
            return torch.inverse(gram + self.reg * eye)
        except RuntimeError:  # torch.linalg.LinAlgError derives from RuntimeError
            return None

    def step(self, closure=None):
        """Perform one optimization step over all parameter pairs.

        Pairs in which either parameter has no gradient are skipped. When a
        Gram matrix cannot be inverted the raw gradient is used instead.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            params = group["params"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            # Consecutive, non-overlapping pairs: (0, 1), (2, 3), ...
            for p1, p2 in zip(params[::2], params[1::2]):
                if p1.grad is None or p2.grad is None:
                    continue

                # Both scaled gradients are computed before either parameter changes.
                grad1 = p1.grad.data
                scale1 = p2.data
                inv1 = self._gram_inverse(scale1.T @ scale1)
                grad1_scaled = grad1 if inv1 is None else inv1 @ grad1

                grad2 = p2.grad.data
                scale2 = p1.data
                inv2 = self._gram_inverse(scale2 @ scale2.T)
                grad2_scaled = grad2 if inv2 is None else grad2 @ inv2

                if weight_decay > 0.0:
                    p1.data.add_(p1.data, alpha=-lr * weight_decay)
                    p2.data.add_(p2.data, alpha=-lr * weight_decay)

                p1.data.add_(grad1_scaled, alpha=-lr)
                p2.data.add_(grad2_scaled, alpha=-lr)

        return loss

class SGD(Optimizer):
    """Plain SGD with weight decay applied before the gradient step.

    Args:
        params: Iterable of parameters (or param-group dicts).
        lr: Learning rate.
        weight_decay: Weight-decay coefficient.
        betas, eps, correct_bias: Accepted but not used by this class.
    """

    def __init__(self, params, lr, weight_decay, betas=(0.9, 0.98), eps=1e-6, correct_bias=True):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Perform one optimization step.

        Parameters without a gradient are left untouched.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                if weight_decay > 0.0:
                    p.data.add_(p.data, alpha=-lr * weight_decay)
                p.data.add_(p.grad.data, alpha=-lr)

        return loss
            

class SoLoRA(Optimizer):
    """Momentum optimizer for LoRA factor pairs with diagonal preconditioners.

    Each param group is consumed pairwise: entries 0/1, 2/3, ... are treated
    as ``(A, B)`` with ``A`` of shape ``r x n`` and ``B`` of shape ``m x r``.
    Before every ``step`` the caller must store an ``m x n`` tensor in
    ``self.external_grads`` under the key ``id(B)`` (presumably the gradient
    with respect to the full weight -- confirm against the training loop).
    The factors' own ``.grad`` tensors are not read.

    Args:
        params: Iterable of parameters (or param-group dicts), ordered A, B, A, B, ...
        lr: Learning rate.
        weight_decay: Decoupled weight-decay coefficient (applied after the update).
        rank: Stored on the instance; not read by ``step``.
        eps: Accepted but currently ignored; ``self.eps`` is fixed to 0.
        betas: ``(beta1, beta2)`` decay rates for the first moment and the
            preconditioners.
    """

    def __init__(self, params, lr, weight_decay, rank=4, eps=1e-6, betas=(0.9, 0.98)):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.rank = rank
        self.betas = betas
        # Filled in by the training loop: {id(B): m x n tensor}.
        self.external_grads = {}
        self.reg = 1e-6
        self.eps = 0

    def step(self, closure=None):
        """Perform one optimization step over all (A, B) pairs.

        If a param group carries a truthy ``"correct_bias"`` entry, the step
        size is scaled by ``sqrt(1 - beta2^t) / (1 - beta1^t)``.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: If no tensor was registered in ``external_grads``
                for a pair.
        """
        loss = None
        if closure is not None:
            loss = closure()

        beta1, beta2 = self.betas
        for group in self.param_groups:
            params = group["params"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            # Consecutive, non-overlapping pairs: (0, 1), (2, 3), ...
            for p1, p2 in zip(params[::2], params[1::2]):
                # p1: matrix A, r x n, p2: matrix B, m x r
                MatrixA = p1.data
                MatrixB = p2.data
                device = MatrixA.device

                # .get() so that a missing entry reaches the explicit error
                # below instead of surfacing as a KeyError.
                G_W = self.external_grads.get(id(p2))
                if G_W is None:
                    raise RuntimeError("Missing gradient: G_W is None.")

                state = self.state[p1]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(G_W).to(device)
                    state["pre_L"] = torch.zeros(MatrixB.shape[0]).to(device)
                    state["pre_R"] = torch.zeros(MatrixA.shape[1]).to(device)

                state["step"] += 1
                exp_avg = state["exp_avg"]
                pre_L = state["pre_L"]
                pre_R = state["pre_R"]

                # First moment of G_W and moving averages of its row/column
                # sums of squares.
                exp_avg.mul_(beta1).add_(G_W, alpha=1.0 - beta1)
                pre_L.mul_(beta2).add_(torch.einsum('ij,ij->i', G_W, G_W), alpha=1.0 - beta2)
                pre_R.mul_(beta2).add_(torch.einsum('ij,ij->j', G_W, G_W), alpha=1.0 - beta2)

                trace_root = torch.sum(pre_L) ** 0.5

                pre_L_root_inv = (pre_L ** 0.5 + self.eps * torch.ones(MatrixB.shape[0]).to(device)) ** (-1)
                pre_R_root_inv = (pre_R ** 0.5 + self.eps * torch.ones(MatrixA.shape[0 + 1]).to(device)) ** (-1)

                # (B^T diag(pre_L^{1/2}) B + reg I)^{-1}; on failure fall back
                # to a diagonal built from the first r entries of pre_L_root_inv.
                try:
                    B_TLB_inv = torch.inverse(
                        MatrixB.T @ torch.diag(pre_L ** 0.5) @ MatrixB
                        + self.reg * torch.eye(MatrixB.shape[1]).to(device)
                    )
                except RuntimeError:  # torch.linalg.LinAlgError derives from RuntimeError
                    B_TLB_inv = torch.diag(pre_L_root_inv[0:MatrixB.shape[1]])
                # (A diag(pre_R^{1/2}) A^T + reg I)^{-1}, same fallback scheme.
                try:
                    ARA_T_inv = torch.inverse(
                        MatrixA @ torch.diag(pre_R ** 0.5) @ MatrixA.T
                        + self.reg * torch.eye(MatrixA.shape[0]).to(device)
                    )
                except RuntimeError:
                    ARA_T_inv = torch.diag(pre_R_root_inv[0:MatrixA.shape[0]])

                # Update direction for A/p1:
                # part @ diag(pre_R_root_inv) - 0.5 * part @ A^T @ ARA_T_inv @ A
                delta_A_part = B_TLB_inv @ MatrixB.T @ exp_avg * trace_root
                delta_A = (delta_A_part @ torch.diag(pre_R_root_inv)
                           - 0.5 * delta_A_part @ MatrixA.T @ ARA_T_inv @ MatrixA)

                # Compared against the parameter shape so .grad is not required.
                assert delta_A.shape == p1.data.shape

                # Update direction for B/p2:
                # diag(pre_L_root_inv) @ part - 0.5 * B @ B_TLB_inv @ B^T @ part
                delta_B_part = exp_avg @ MatrixA.T @ ARA_T_inv * trace_root
                delta_B = (torch.diag(pre_L_root_inv) @ delta_B_part
                           - 0.5 * MatrixB @ B_TLB_inv @ MatrixB.T @ delta_B_part)

                assert delta_B.shape == p2.data.shape

                step_size = lr
                if group.get("correct_bias"):  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                # Apply the (optionally bias-corrected) step size; it equals
                # the plain lr whenever "correct_bias" is absent or falsy.
                p1.data.add_(delta_A, alpha=-step_size)
                p2.data.add_(delta_B, alpha=-step_size)

                if weight_decay > 0.0:
                    p1.data.add_(p1.data, alpha=-lr * weight_decay)
                    p2.data.add_(p2.data, alpha=-lr * weight_decay)

        return loss

class LoRAPro_AdamW(Optimizer):
    """AdamW variant that keeps Adam moments on a reconstructed full-weight gradient.

    Each param group is consumed pairwise: entries 0/1, 2/3, ... are treated
    as ``(A, B)`` with ``A`` of shape ``r x n`` and ``B`` of shape ``m x r``.
    The factor gradients are combined into ``gradW``, Adam moments are kept on
    ``gradW``, and the normalised result is mapped back to factor updates
    (with a Sylvester-equation correction from the second step on).

    Args:
        params: Iterable of parameters (or param-group dicts), ordered A, B, A, B, ...
        lr: Learning rate; must be non-negative.
        betas: ``(beta1, beta2)``, each in ``[0, 1)``.
        eps: Non-negative term added to the Adam denominator.
        weight_decay: Weight-decay coefficient.
        correct_bias: Whether to apply Adam bias correction.
        rank: Used for ``scaling = alpha / rank``; the regulariser size is
            taken from the factors themselves.
        reg: Multiple of the identity added before inverting ``A A^T`` / ``B^T B``.
        alpha: Numerator of ``scaling``.

    Raises:
        ValueError: If ``lr``, a beta, or ``eps`` is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True, rank=4, reg=1e-8, alpha=4):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)
        self.rank = rank
        self.reg = reg
        self.eps = eps
        self.scaling = alpha / rank
        print(f"self.reg = {self.reg}")
        print(f"scaling = {self.scaling}")

    def _regularized_inverse(self, gram):
        """Return ``(gram + reg * I)^{-1}`` with ``I`` sized from ``gram`` itself."""
        eye = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
        return torch.inverse(gram + self.reg * eye)

    def solve_sylvester(self, A, B, C, X=None):
        """Solve ``A X + X B = C`` for ``X`` via eigendecompositions of ``A`` and ``-B``.

        From the answer here:
        https://stackoverflow.com/questions/73713072/solving-sylvester-equations-in-pytorch

        Args:
            A: Square matrix multiplying ``X`` from the left.
            B: Square matrix multiplying ``X`` from the right.
            C: Right-hand side.
            X: Unused; kept for interface compatibility.

        Returns:
            The solution ``X``; its real part when ``A``, ``B`` and ``C`` are real.
        """
        B = -B
        m = B.shape[-1]
        n = A.shape[-1]
        try:
            R, U = torch.linalg.eig(A)
        except RuntimeError:  # torch.linalg.LinAlgError derives from RuntimeError
            print(A)
            # Retry on a slightly shifted matrix, built on A's device/dtype.
            jitter = 1e-6 * torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
            R, U = torch.linalg.eig(A + jitter)

        S, V = torch.linalg.eig(B)

        # In the eigenbases the equation decouples: Y_ij = F_ij / (R_i - S_j).
        F = torch.linalg.solve(U, (C + 0j) @ V)
        W = R[..., :, None] - S[..., None, :]
        Y = F / W
        X = U[..., :n, :n] @ Y[..., :n, :m] @ torch.linalg.inv(V)[..., :m, :m]
        inputs_real = all(torch.isreal(x.flatten()[0]) for x in [A, B, C])
        return X.real if inputs_real else X

    def step(self, closure=None):
        """Perform one optimization step over all (A, B) pairs.

        Pairs in which either factor has no gradient are skipped.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        scaling = self.scaling
        inv_scale_sq = 1 / scaling ** 2
        for group in self.param_groups:
            params = group["params"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            beta1, beta2 = group["betas"]
            # Consecutive, non-overlapping pairs: (0, 1), (2, 3), ...
            for p1, p2 in zip(params[::2], params[1::2]):
                # B/p2: m x r, A/p1: r x n
                if p1.grad is None or p2.grad is None:
                    continue
                MatrixA = p1.data
                MatrixB = p2.data
                grad_loraA = p1.grad.data
                grad_loraB = p2.grad.data

                state = self.state[p1]
                if len(state) == 0:
                    state["step"] = 0
                state["step"] += 1
                first_step = state["step"] == 1

                # compute (AA^T)^-1 and, after the first step, (B^TB)^-1
                AA_T_inv = self._regularized_inverse(MatrixA @ MatrixA.T)
                if first_step:
                    gradA = grad_loraA
                    gradB = inv_scale_sq * grad_loraB @ AA_T_inv
                else:
                    B_TB_inv = self._regularized_inverse(MatrixB.T @ MatrixB)
                    # gradA = (B^TB)^-1 grad_loraA
                    # gradB = [I - B (B^TB)^-1 B^T] grad_loraB (AA^T)^-1
                    gradA = inv_scale_sq * B_TB_inv @ grad_loraA
                    gradB = (inv_scale_sq * grad_loraB @ AA_T_inv
                             - inv_scale_sq * MatrixB @ B_TB_inv @ MatrixB.T @ grad_loraB @ AA_T_inv)

                # gradW = scaling * (gradB @ A + B @ gradA)
                gradW = scaling * (gradB @ MatrixA + MatrixB @ gradA)

                # Adam moments live on gradW (m x n), created on the first step.
                if first_step:
                    state["exp_avg"] = (1 - beta1) * gradW
                    state["exp_avg_sq"] = (1 - beta2) * gradW * gradW
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                else:
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                    exp_avg.mul_(beta1).add_(gradW, alpha=1.0 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(gradW, gradW, value=1.0 - beta2)

                # Always bind both corrections: without this, correct_bias=False
                # hit an unbound name (or reused values from a previous pair).
                if group.get("correct_bias"):  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                else:
                    bias_correction1 = 1.0
                    bias_correction2 = 1.0

                # m_t / (1-beta1^t) / (sqrt[v_t / (1-beta2^t)] + eps)
                denom = torch.div(exp_avg_sq, bias_correction2).sqrt().add_(group["eps"])
                gradW_tilde = torch.div(exp_avg, bias_correction1 * denom)

                # Project the normalised direction back onto the factors:
                # gradA_tilde = s * B^T gradW_tilde, gradB_tilde = s * gradW_tilde A^T
                gradA_tilde = scaling * MatrixB.T @ gradW_tilde
                gradB_tilde = scaling * gradW_tilde @ MatrixA.T

                if first_step:
                    delta_A = gradA_tilde
                    delta_B = inv_scale_sq * gradB_tilde @ AA_T_inv
                else:
                    # solve (B^T B) X + X (A A^T) = -(1/s^2) (B^T B)^-1 gradA_tilde A^T
                    MatrixX = self.solve_sylvester(
                        MatrixB.T @ MatrixB,
                        MatrixA @ MatrixA.T,
                        -inv_scale_sq * B_TB_inv @ gradA_tilde @ MatrixA.T,
                    )

                    # delta_A = (B^TB)^-1 gradA_tilde + X A
                    delta_A = inv_scale_sq * B_TB_inv @ gradA_tilde + MatrixX @ MatrixA

                    # delta_B = [I - B (B^TB)^-1 B^T] gradB_tilde (AA^T)^-1 - B X
                    delta_B = (inv_scale_sq * gradB_tilde @ AA_T_inv
                               - inv_scale_sq * MatrixB @ B_TB_inv @ MatrixB.T @ gradB_tilde @ AA_T_inv
                               - MatrixB @ MatrixX)

                assert delta_A.shape == p1.grad.data.shape
                assert delta_B.shape == p2.grad.data.shape

                # Weight decay shrinks each factor by sqrt(1 - lr * wd) before its update.
                if weight_decay > 0.0:
                    p1.data = p1.data * math.sqrt(1 - lr * weight_decay)
                p1.data.add_(delta_A, alpha=-lr)

                if weight_decay > 0.0:
                    p2.data = p2.data * math.sqrt(1 - lr * weight_decay)
                p2.data.add_(delta_B, alpha=-lr)

        return loss
    

class AdamWr(Optimizer):
    """AdamW on factor pairs where each gradient is rescaled by the other factor's Gram inverse.

    Each param group is consumed pairwise: entries 0/1, 2/3, ... form
    ``(p1, p2)``. The gradient of ``p1`` is left-multiplied by
    ``(p2^T p2 + reg * I)^{-1}`` and the gradient of ``p2`` is
    right-multiplied by ``(p1 p1^T + reg * I)^{-1}`` before the Adam update.

    Args:
        params: Iterable of parameters (or param-group dicts), ordered in pairs.
        lr: Learning rate; must be non-negative.
        betas: ``(beta1, beta2)``, each in ``[0, 1)``.
        eps: Non-negative term added to the Adam denominator.
        weight_decay: Decoupled weight-decay coefficient (applied after the update).
        correct_bias: Whether to apply Adam bias correction to the step size.
        reg: Multiple of the identity added before inversion.

    Raises:
        ValueError: If ``lr``, a beta, or ``eps`` is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.0, correct_bias=False, reg=0):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)
        self.reg = reg
        print(f'{self.reg=}')

    def reset_state(self):
        """Reset step count and both Adam moments to zero for every parameter."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state["exp_avg"] = torch.zeros_like(p.data)
                state["exp_avg_sq"] = torch.zeros_like(p.data)

    def _gram_inverse(self, gram):
        """Return ``(gram + reg * I)^{-1}``, or the identity if the inversion fails."""
        eye = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
        try:
            return torch.inverse(gram + self.reg * eye)
        except RuntimeError:  # torch.linalg.LinAlgError derives from RuntimeError
            return eye

    def _adam_update(self, group, param, scaled_grad):
        """Apply one Adam step plus decoupled weight decay to ``param`` in place."""
        state = self.state[param]
        if len(state) == 0:
            state["step"] = 0
            state["exp_avg"] = torch.zeros_like(param.data)
            state["exp_avg_sq"] = torch.zeros_like(param.data)

        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        beta1, beta2 = group["betas"]
        state["step"] += 1

        exp_avg.mul_(beta1).add_(scaled_grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(scaled_grad, scaled_grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])

        step_size = group["lr"]
        if group.get("correct_bias"):  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

        # Keyword `value=` replaces the deprecated addcdiv_(scalar, t1, t2) overload.
        param.data.addcdiv_(exp_avg, denom, value=-step_size)
        if group["weight_decay"] > 0.0:
            param.data.add_(param.data, alpha=-group["lr"] * group["weight_decay"])

    def step(self, closure=None):
        """Perform one optimization step over all parameter pairs.

        Pairs in which either parameter has no gradient are skipped.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            params = group["params"]
            # Consecutive, non-overlapping pairs: (0, 1), (2, 3), ...
            for p1, p2 in zip(params[::2], params[1::2]):
                if p1.grad is None or p2.grad is None:
                    continue

                other = p2.data
                grad1_scaled = self._gram_inverse(other.T @ other) @ p1.grad.data
                assert grad1_scaled.shape == p1.grad.data.shape
                self._adam_update(group, p1, grad1_scaled)

                # NOTE(review): p1.data is read *after* its in-place update, so
                # p2's preconditioner uses the already-updated p1. This matches
                # the previous behaviour (which aliased p1.data) -- confirm it
                # is intended.
                updated = p1.data
                grad2_scaled = p2.grad.data @ self._gram_inverse(updated @ updated.T)
                assert grad2_scaled.shape == p2.grad.data.shape
                self._adam_update(group, p2, grad2_scaled)

        return loss
    
class AdamW(Optimizer):
    """Adam with decoupled weight decay applied after the Adam update.

    Args:
        params: Iterable of parameters (or param-group dicts).
        lr: Learning rate; must be non-negative.
        betas: ``(beta1, beta2)``, each in ``[0, 1)``.
        eps: Non-negative term added to the denominator.
        weight_decay: Decoupled weight-decay coefficient.
        correct_bias: Whether to apply Adam bias correction to the step size.

    Raises:
        ValueError: If ``lr``, a beta, or ``eps`` is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.0, correct_bias=False):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def reset_state(self):
        """Reset step count and both Adam moments to zero for every parameter."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state["exp_avg"] = torch.zeros_like(p.data)
                state["exp_avg_sq"] = torch.zeros_like(p.data)

    def step(self, closure=None):
        """Perform one optimization step.

        Parameters without a gradient are skipped.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1
                # Exponential moving averages of the gradient and its square.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = lr
                if group.get("correct_bias"):  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                # Keyword `value=` replaces the deprecated addcdiv_(scalar, t1, t2) overload.
                p.data.addcdiv_(exp_avg, denom, value=-step_size)
                if weight_decay > 0.0:
                    p.data.add_(p.data, alpha=-lr * weight_decay)

        return loss