import torch
from torch.optim import AdamW

class NormalizedAdamW(AdamW):
    """AdamW that rescales each parameter's second-moment buffer to unit L2 norm.

    Before delegating to ``AdamW.step``, the stored ``exp_avg_sq`` buffer is
    divided by its L2 norm (taken over all of its elements). This applies to
    every parameter that has a gradient. AdamW then folds the current gradient
    into the rescaled buffer as usual.

    Buffers are left untouched when:
    - the parameter has no optimizer state yet (e.g. the very first step), or
    - the buffer is all zeros.

    Accepts exactly the same constructor arguments as ``torch.optim.AdamW``.
    """

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is evaluated exactly once, before the buffers are
                rescaled. The rescaling and the AdamW update therefore both use
                the gradients it produced.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # The closure normally calls backward(), so autograd must be on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                # No second-moment buffer yet: nothing to rescale.
                if 'exp_avg_sq' not in state:
                    continue

                exp_avg_sq = state['exp_avg_sq']  # v_t
                norm = torch.linalg.vector_norm(exp_avg_sq)
                if norm > 0:
                    # In place, so the tensor stored in the state is the one updated.
                    exp_avg_sq.div_(norm)

        # Gradients already exist. Passing the closure again would re-run the
        # forward/backward pass and desynchronize the returned loss.
        super().step()

        return loss

import torch
from torch.optim import AdamW

class ClippedAdamW_norm(AdamW):
    """AdamW that lifts small second-moment buffers up to unit L2 norm.

    Before delegating to ``AdamW.step``, the L2 norm of each parameter's
    stored ``exp_avg_sq`` buffer is computed over all of its elements.
    - If the norm is strictly between 0 and 1, the buffer is divided by it,
      so its norm becomes 1.
    - Buffers with norm >= 1, all-zero buffers, and parameters with no
      optimizer state yet are left unchanged.

    Accepts exactly the same constructor arguments as ``torch.optim.AdamW``.
    """

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is evaluated exactly once, before the buffers are
                adjusted.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # The closure normally calls backward(), so autograd must be on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                # No second-moment buffer yet: nothing to adjust.
                if 'exp_avg_sq' not in state:
                    continue

                exp_avg_sq = state['exp_avg_sq']  # v_t
                norm = torch.linalg.vector_norm(exp_avg_sq)
                if 0 < norm < 1:
                    # In place, so the tensor stored in the state is the one updated.
                    exp_avg_sq.div_(norm)

        # Gradients already exist. Passing the closure again would re-run the
        # forward/backward pass and desynchronize the returned loss.
        super().step()

        return loss


class ClippedAdamW(AdamW):
    """AdamW with an element-wise floor on the second-moment buffer.

    Before delegating to ``AdamW.step``, each parameter's stored
    ``exp_avg_sq`` buffer is clamped element-wise to be at least the clip
    threshold. Parameters with no gradient or no optimizer state yet are
    skipped.

    Args:
        *args, **kwargs: Forwarded unchanged to ``torch.optim.AdamW``.
        clip_threshold: Element-wise lower bound applied to ``exp_avg_sq``.
            A param group may override it with its own ``'clip_threshold'``
            key; otherwise this value is used.
    """

    def __init__(self, *args, clip_threshold=1e-2, **kwargs):
        super().__init__(*args, **kwargs)
        self.clip_threshold = clip_threshold

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is evaluated exactly once, before clamping.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # The closure normally calls backward(), so autograd must be on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            threshold = group.get('clip_threshold', self.clip_threshold)
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                # No second-moment buffer yet: nothing to clamp.
                if 'exp_avg_sq' not in state:
                    continue

                # In place, so the tensor stored in the state is the one updated.
                state['exp_avg_sq'].clamp_(min=threshold)  # v_t

        # Gradients already exist. Passing the closure again would re-run the
        # forward/backward pass and desynchronize the returned loss.
        super().step()

        return loss