import torch
from .util import enable_running_stats, disable_running_stats
import contextlib
from torch.distributed import ReduceOp


# # 原始sagm + 量化s参与扰动
# class SAGM(torch.optim.Optimizer):
#     def __init__(self, params, base_optimizer, model, alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', q_base_optimizer = None, **kwargs):
#         defaults = dict(adaptive=adaptive, **kwargs)
#         super(SAGM, self).__init__(params, defaults)
#         self.model = model
#         self.base_optimizer = base_optimizer
#         self.param_groups = self.base_optimizer.param_groups
#         self.q_base_optimizer = q_base_optimizer
#         if self.q_base_optimizer != None:
#             self.q_param_groups = self.q_base_optimizer.param_groups
#         self.adaptive = adaptive
#         self.rho_scheduler = rho_scheduler
#         self.perturb_eps = perturb_eps
#         self.alpha = alpha
#         self.update_q = 0
#         # initialize self.rho_t
#         self.update_rho_t()
        
#         # set up reduction for gradient across workers
#         if grad_reduce.lower() == 'mean':
#             if hasattr(ReduceOp, 'AVG'):
#                 self.grad_reduce = ReduceOp.AVG
#                 self.manual_average = False
#             else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
#                 self.grad_reduce = ReduceOp.SUM
#                 self.manual_average = True
#         elif grad_reduce.lower() == 'sum':
#             self.grad_reduce = ReduceOp.SUM
#             self.manual_average = False
#         else:
#             raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')
    
#     @torch.no_grad()
#     def update_rho_t(self):
#         self.rho_t = self.rho_scheduler.step()
#         return self.rho_t

#     @torch.no_grad()
#     def perturb_weights(self, rho=0.0):
        
#         grad_norm = self._grad_norm( weight_adaptive = self.adaptive )
#         for group in self.param_groups:
#             # print(group)
#             scale = (rho / (grad_norm + self.perturb_eps) - self.alpha)

#             for p in group["params"]:
                
#                 if p.grad is None: continue
#                 self.state[p]["old_g"] = p.grad.data.clone()
#                 e_w = p.grad * scale.to(p)
#                 if self.adaptive:
#                     e_w *= torch.pow(p, 2)
#                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
#                 self.state[p]['e_w'] = e_w
                
#         if self.q_base_optimizer != None:
#             grad_norm = self._grad_norm_s( weight_adaptive = self.adaptive )
#             for group in self.q_param_groups:
#                 # print(group)
#                 scale = (rho / (grad_norm + self.perturb_eps) - self.alpha)

#                 for p in group["params"]:
                    
#                     if p.grad is None: continue
#                     self.state[p]["old_g"] = p.grad.data.clone()
#                     e_w = p.grad * scale.to(p)
#                     if self.adaptive:
#                         e_w *= torch.pow(p, 2)
#                     p.add_(e_w)  # climb to the local maximum "w + e(w)"
#                     self.state[p]['e_w'] = e_w
#             # for group in self.q_param_groups:
#             #     for p in group['params']:
#             #         if p.grad is None: continue
#             #         self.state[p]["old_g"] = p.grad.data.clone()
#             #         # print("s", p)

#     @torch.no_grad()
#     def unperturb(self):
#         # 需要修改，修改成原始梯度，一开始传入的应该是量化之后反量化的梯度
#         for group in self.param_groups:
#             for p in group['params']:
#                 if 'e_w' in self.state[p].keys():
#                     p.data.sub_(self.state[p]['e_w'])

#         if self.q_base_optimizer != None:
#             for group in self.q_param_groups:
#                 for p in group['params']:
#                     if 'e_w' in self.state[p].keys():
#                         p.data.sub_(self.state[p]['e_w'])


#     @torch.no_grad()
#     def gradient_decompose(self, alpha=0.0):
#         # 这里需要好好看看
#         for group in self.param_groups:
#             for p in group['params']:
#                 if p.grad is None: continue
#                 sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
#                 p.grad.data.add_(sam_grad)
#         if self.q_base_optimizer != None:
#             for group in self.q_param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
                    
#                     # if (self.state[p]['old_g'] > 0 and p.grad > 0) or (self.state[p]['old_g'] < 0 and p.grad < 0):
#                     #     # sam_grad = -p.grad + torch.sqrt(self.state[p]['old_g'] * p.grad)
#                     #     sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
#                     # else:
#                     #     # sam_grad = -p.grad + self.state[p]['old_g']
#                     #     sam_grad = -p.grad
#                     # sam_grad = self.state[p]['old_g'] - p.grad
#                     # print("old", self.state[p]['old_g'], "new", p.grad)
#                     p.grad.data.add_(sam_grad)
#         # print(self.old_grads)
#         # print("____________")
#         # print(self.new_grads)
        
        
#         # # toy
#         # tensor_old = torch.unsqueeze(self.old_grads[1], 0)
#         # tensor_new = torch.unsqueeze(self.new_grads[1], 0)
#         # vector_old = torch.cat((self.old_grads[0], tensor_old), dim=1)
#         # vector_new = torch.cat((self.new_grads[0], tensor_new), dim=1)
#         # dot_product = torch.dot(vector_old.view(-1), vector_new.view(-1))
#         # # 计算向量的模
#         # norm1 = torch.norm(vector_old)
#         # norm2 = torch.norm(vector_new)

#         # # 计算余弦值
#         # cosine_sim = dot_product / (norm1 * norm2)
#         # self.cos_sims.append(cosine_sim.item()) 

#         # cos_sim = torch.cosine_similarity(torch.cat((self.old_grads[0], tensor_old), dim=1), torch.cat((self.new_grads[0], tensor_new), dim=1))
        
#         # print("cos_sim:", cosine_sim)
#     @torch.no_grad()
#     def _sync_grad(self):
#         if torch.distributed.is_initialized(): # synchronize final gardients
#             for group in self.param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     if self.manual_average:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#                         world_size = torch.distributed.get_world_size()
#                         p.grad.div_(float(world_size))
#                     else:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#             if self.q_base_optimizer != None:
#                 for group in self.q_param_groups:
#                     for p in group['params']:
#                         if p.grad is None: continue
#                         if self.manual_average:
#                             torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#                             world_size = torch.distributed.get_world_size()
#                             p.grad.div_(float(world_size))
#                         else:
#                             torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#         return

#     @torch.no_grad()
#     def _grad_norm(self, by=None, weight_adaptive=False):
#         #shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
#         if not by:

#             norm = torch.norm(
#                     torch.stack([
#                         ( (torch.abs(p.data) if weight_adaptive else 1.0) * p.grad).norm(p=2)
#                         for group in self.param_groups for p in group["params"]
#                         if p.grad is not None
#                     ]),
#                     p=2
#                )

#         else:

#             norm = torch.norm(
#                 torch.stack([
#                     ( (torch.abs(p.data) if weight_adaptive else 1.0) * self.state[p][by]).norm(p=2)
#                     for group in self.param_groups for p in group["params"]
#                     if p.grad is not None
#                 ]),
#                 p=2
#             )

#         return norm


#     @torch.no_grad()
#     def _grad_norm_s(self, by=None, weight_adaptive=False):
#         #shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
#         if not by:

#             norm = torch.norm(
#                     torch.stack([
#                         ( (torch.abs(p.data) if weight_adaptive else 1.0) * p.grad).norm(p=2)
#                         for group in self.q_param_groups for p in group["params"]
#                         if p.grad is not None
#                     ]),
#                     p=2
#                )

#         else:

#             norm = torch.norm(
#                 torch.stack([
#                     ( (torch.abs(p.data) if weight_adaptive else 1.0) * self.state[p][by]).norm(p=2)
#                     for group in self.q_param_groups for p in group["params"]
#                     if p.grad is not None
#                 ]),
#                 p=2
#             )

#         return norm


#     # def norm(tensor_list: List[torch.tensor], p=2):
#     #     """Compute p-norm for tensor list"""
#     #     return torch.cat([x.flatten() for x in tensor_list]).norm(p)

#     def load_state_dict(self, state_dict):
#         super().load_state_dict(state_dict)
#         self.base_optimizer.param_groups = self.param_groups
        
#     def maybe_no_sync(self):
#         if torch.distributed.is_initialized():
#             return self.model.no_sync()
#         else:
#             return contextlib.ExitStack()

#     @torch.no_grad()
#     def set_closure(self, loss_fn, inputs, targets, **kwargs):
#         # create self.forward_backward_func, which is a function such that
#         # self.forward_backward_func() automatically performs forward and backward passes.
#         # This function does not take any arguments, and the inputs and targets data
#         # should be pre-set in the definition of partial-function

#         def get_grad():
#             self.base_optimizer.zero_grad()
#             if self.q_base_optimizer != None:
#                 self.q_base_optimizer.zero_grad()
#             with torch.enable_grad():
#                 outputs = self.model(inputs)
#                 # print("1 2", outputs.shape, targets.shape)
#                 # print(outputs)
#                 loss = loss_fn(outputs, targets, **kwargs)
#             loss_value = loss.data.clone().detach()
#             loss.backward()

#             return outputs, loss_value

#         self.forward_backward_func = get_grad

#     @torch.no_grad()
#     def step(self, closure=None):

#         if closure:
#             get_grad = closure
#         else:
#             get_grad = self.forward_backward_func

#         with self.maybe_no_sync():
#             # get gradient
#             # 获得量化反量化之后运算的梯度
#             outputs, loss_value = get_grad()

#             # perturb weights
#             self.perturb_weights(rho=self.rho_t)

#             # disable running stats for second pass
#             disable_running_stats(self.model)

#             # get gradient at perturbed weights
#             get_grad()

#             # decompose and get new update direction
#             self.gradient_decompose(self.alpha)

#             # unperturb
#             self.unperturb()
            
#         # synchronize gradients across workers
#         self._sync_grad()    
#         # for name, p in self.model.named_parameters():
#         #     if 'layer1.0.conv1.weight' in name:
#         #         print("Gradient1", p.grad[0])
#         #     elif 'layer2.0.conv2.weight' in name:
#         #         print("Gradient2", p.grad[0])
#         #     elif 'layer4.0.conv3.weight' in name:
#         #         print("Gradient3", p.grad[0])
#         # update with new directions
#         # if self.update_q % 2 == 1:
#         #     self.base_optimizer.step()
#         # else:
#         #     self.q_base_optimizer.step()
#         # if self.q_base_optimizer != None and self.update_q  > 1000:
#         #     self.q_base_optimizer.step()
#         # self.update_q += 1
        
#         self.base_optimizer.step()
#         if self.q_base_optimizer != None:
#             self.q_base_optimizer.step()

#         # enable running stats
#         enable_running_stats(self.model)

#         return outputs, loss_value


import random
import numpy as np
import math
# Original SAGM; quantization scales (s) do not take part in the weight perturbation
class SAGM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, model, alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', q_base_optimizer = None, optimizer_q_fix_5000steps = None, s_grad = 'both', clip_range = 0,fix_reverse=False, fix_alter_steps=0, freeze_steps=200,freeze_old_ratio=0, sort_steps=400, freeze_new_ratio=0, st_freeze_steps=0, continue_freeze=True, no_unfreeze=False, reverse_freeze=False, **kwargs):
    # def __init__(self, params, base_optimizer, model, alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', q_base_optimizer = None, optimizer_q_fix_5000steps = None, s_grad = 'both', clip_range = 0,fix_reverse=False, fix_alter_steps=0,**kwargs):
    
        defaults = dict(adaptive=adaptive, **kwargs)
        super(SAGM, self).__init__(params, defaults)
        self.model = model
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.q_base_optimizer = q_base_optimizer
        if self.q_base_optimizer != None:
            self.q_param_groups = self.q_base_optimizer.param_groups
        self.optimizer_q_fix_5000steps = optimizer_q_fix_5000steps
        if self.optimizer_q_fix_5000steps != None:
            self.q_param_groups_fix_5000steps = self.optimizer_q_fix_5000steps.param_groups
        self.adaptive = adaptive
        self.rho_scheduler = rho_scheduler
        self.perturb_eps = perturb_eps
        self.alpha = alpha
        self.update_q = 0
        self.s_grad = s_grad
        self.clip_range = clip_range
        self.step_cnt = 0
        self.fix_reverse = fix_reverse
        self.fix_alter_steps = fix_alter_steps
        self.sum_steps = 0
        self.s_grad_sums_old = []
        self.s_grad_sums_new = []
        self.s_grad_sums_old_real = []
        self.s_grad_sums_new_real = []
        self.freeze_new_ratio = freeze_new_ratio
        
        
        self.clip_ratio = 0.01 # 代表有取3%和97%处为clip阈值
        self.out_ratio = 0.03
        
        
        
        
        
        self.freeze_old_ratio = freeze_old_ratio
        
        print("freeze_old_ratio", freeze_old_ratio)
        self.freeze_ratio = 0.03
        self.freeze = []
        self.freeze_old = []
        
        self.freeze_new = []
        self.freeze_same_threshold = 0.9
        self.freeze_reverse_threshold = 0.175
        self.freeze_reverse_threshold_old = 0.15
        self.freeze_reverse_threshold_new = 0.15
        self.freeze_reverse_ratio = 0.02
        self.freeze_reverse_ratio_old = 0.02
        self.freeze_reverse_ratio_new = 0.01
        
        self.clip5 = [1]
        self.clip10 = [17, 22 ,18, 42, 19, 25, 28,32, 41, 13, 36, 30, 0,35, 29, 23, 2, 6, 12, 20, 24, 4, 14, 16, 10, 3, 5, 8, 33, 27, 21, 11, 1, 9, 7, 15, 90, 88, 98, 94, 104, 100, 102, 96, 92, 86, 82, 72, 80, 68, 84, 78, 76, 70, 74, 66, 64, 73, 67, 79, 71, 77, 95, 83, 59, 60, 87, 54, 63, 69, 62, 61, 65, 91, 99, 39, 85, 56, 81, 101, 75, 89, 103, 93, 57, 97, 52, 55, 58, 51, 50, 45, 49, 31, 46, 37, 47, 48, 44, 43, 53, 40, 38, 34, 26]
        self.clip15 = []
        self.st_freeze_steps = st_freeze_steps
        self.ema_alpha = 0.75
        self.ema_multiple_num = 3
        self.sort_steps = freeze_steps
        self.freeze_steps = freeze_steps
        print("freeze_steps", freeze_steps)
        print("sort_steps ", self.sort_steps)
        print("ema_alpha", self.ema_alpha)
        print("ema_multiple_num", self.ema_multiple_num)
        self.s_length = 0
        for group in self.q_param_groups:
            for p in group['params']:
                self.s_length += 1
                self.s_grad_sums_old.append(0)
                self.s_grad_sums_old_real.append(0)
                self.s_grad_sums_new.append(0)
                self.s_grad_sums_new_real.append(0)
        
        self.freeze_length = 12
        self.freeze_times = [0 for i in range(self.s_length)]
        
        
        self.s_grad_new_list = np.zeros(( self.s_length, self.sort_steps))
        self.s_grad_old_list = np.zeros(( self.s_length, self.sort_steps))
        
        self.p_list = np.zeros(( self.s_length, self.sort_steps))

        
        self.ema_s_clip = [10000] * self.s_length
        self.s1_clip_range_pos = [10000] * self.s_length
        self.s1_clip_range_neg = [-10000] * self.s_length
        self.s2_clip_range_pos = [10000] * self.s_length
        self.s2_clip_range_neg = [-10000] * self.s_length
        self.outer_old = [100] * self.s_length
        self.outer_new = [100] * self.s_length
        self.sort_outer_old = [100] * self.s_length
        self.sort_outer_new = [100] * self.s_length
        self.reverse_ratio_list = [100] * self.s_length
        self.reverse_ratio_list_new = [100] * self.s_length
        self.reverse_ratio_list_old = [100] * self.s_length
        self.reverse_ratio_list_delta = [100] * self.s_length
        self.reverse_ratio_list_add = [100] * self.s_length
        self.sort_reverse = [100] * self.s_length
        self.continue_freeze = continue_freeze
        self.reverse_freeze = reverse_freeze
        self.no_unfreeze = no_unfreeze
        # self.s2_clip_range = [10000] * self.s_length
        # initialize self.rho_t
        self.update_rho_t()
        
        # set up reduction for gradient across workers
        if grad_reduce.lower() == 'mean':
            if hasattr(ReduceOp, 'AVG'):
                self.grad_reduce = ReduceOp.AVG
                self.manual_average = False
            else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
                self.grad_reduce = ReduceOp.SUM
                self.manual_average = True
        elif grad_reduce.lower() == 'sum':
            self.grad_reduce = ReduceOp.SUM
            self.manual_average = False
        else:
            raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')
    
    @torch.no_grad()
    def update_rho_t(self):
        """Advance the rho scheduler and cache the result on ``self.rho_t``.

        Returns:
            The value produced by ``self.rho_scheduler.step()``.
        """
        new_rho = self.rho_scheduler.step()
        self.rho_t = new_rho
        return new_rho

    @torch.no_grad()
    def perturb_weights(self, rho=0.0):
        
        grad_norm = self._grad_norm( weight_adaptive = self.adaptive )
        for group in self.param_groups:
            # print(group)
            scale = (rho / (grad_norm + self.perturb_eps) - self.alpha)

            for p in group["params"]:
                
                if p.grad is None: continue
                self.state[p]["old_g"] = p.grad.data.clone()
                e_w = p.grad * scale.to(p)
                if self.adaptive:
                    e_w *= torch.pow(p, 2)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]['e_w'] = e_w
        if self.q_base_optimizer != None:
            for group in self.q_param_groups:
                for p in group['params']:
                    if p.grad is None: continue
                    self.state[p]["old_g"] = p.grad.data.clone()
        if self.optimizer_q_fix_5000steps != None:
            for group in self.q_param_groups_fix_5000steps:
                for p in group['params']:
                    if p.grad is None: continue
                    self.state[p]["old_g"] = p.grad.data.clone()
                    # print("s", p)

    @torch.no_grad()
    def unperturb(self):
        for group in self.param_groups:
            for p in group['params']:
                if 'e_w' in self.state[p].keys():
                    p.data.sub_(self.state[p]['e_w'])
    @torch.no_grad()
    def gradient_decompose(self, alpha=0.0):
        # 这里需要好好看看
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None: continue
                sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
                # sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
                p.grad.data.add_(sam_grad)
        if self.q_base_optimizer != None:

            if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
                
                
                if self.sum_steps == self.sort_steps:
                    for idx in range(self.s_length):
                        self.ema_s_clip[idx] = torch.abs(self.ema_multiple_num * (0.5 * self.s_grad_sums_old_real[idx] + 0.5 * self.s_grad_sums_new_real[idx]) / self.sort_steps)
                        print(idx, "ema clip ", self.ema_s_clip[idx])
                else:
                    for idx in range(self.s_length):
                        self.ema_s_clip[idx] = (1-self.ema_alpha) * self.ema_multiple_num * torch.abs(0.5 * self.s_grad_sums_old_real[idx] + 0.5 * self.s_grad_sums_new_real[idx])/ self.sort_steps + self.ema_alpha * self.ema_s_clip[idx]
                        print(idx, "ema clip ", self.ema_s_clip[idx])
                        
                  
                # for idx in range(self.s_length):
                #     self.s1_clip_range_pos[idx] = 1 * torch.abs(self.s_grad_sums_old[idx] / self.sort_steps)
                #     self.s1_clip_range_neg[idx] = 1 * -torch.abs(self.s_grad_sums_old[idx] / self.sort_steps)
                #     self.s2_clip_range_pos[idx] = 1 * torch.abs(self.s_grad_sums_new[idx] / self.sort_steps)
                #     self.s2_clip_range_neg[idx] = 1 * -torch.abs(self.s_grad_sums_new[idx] / self.sort_steps)
                for idx in range(self.s_length):
                    avg = 3 * torch.abs(self.s_grad_sums_old[idx].cpu() / torch.Tensor(self.sort_steps))
                    
                    greater_than_avg = torch.gt(torch.Tensor(self.s_grad_old_list[idx]), avg)
                    # 找出小于 -avg 的元素
                    less_than_neg_avg = torch.lt(torch.Tensor(self.s_grad_old_list[idx]),-avg)

                    # 结合两个条件
                    combined_conditions = greater_than_avg | less_than_neg_avg
                    count_greater_than_avg = torch.sum(combined_conditions)
                    
                    self.outer_old[idx] = count_greater_than_avg / self.sort_steps
                    
                    avg = 3 * torch.abs((self.s_grad_sums_new[idx]).cpu() / torch.Tensor(self.sort_steps))
                    
                    greater_than_avg = torch.gt(torch.Tensor(self.s_grad_new_list[idx]) , avg)
                    # 找出小于 -avg 的元素
                    less_than_neg_avg = torch.lt(torch.Tensor(self.s_grad_new_list[idx]) , -avg)

                    # 结合两个条件
                    combined_conditions = greater_than_avg | less_than_neg_avg
                    count_greater_than_avg = torch.sum(combined_conditions)
                    
                    self.outer_new[idx] = count_greater_than_avg / self.sort_steps
                self.sort_outer_old = sorted(self.outer_old)
                self.sort_outer_new = sorted(self.outer_new)
                print(self.sort_outer_old)
                    
            
            if self.sum_steps % self.sort_steps == 0:
                
                print(self.sum_steps)
                
                for i in range(self.s_length):
                    print(i, "s1 clip ", self.s1_clip_range_neg[i], self.s1_clip_range_pos[i])
                    print(i, "s2 clip", self.s2_clip_range_neg[i], self.s2_clip_range_pos[i])
            # if self.sum_steps % self.sort_steps == 0:
            #     freeze_length = int(self.freeze_ratio * self.s_length)
            #     self.freeze = random.sample(range(32, self.s_length), k=freeze_length)
            # if self.sum_steps % self.sort_steps == 0:
            #     self.freeze = []
            #     if (self.sum_steps // self.sort_steps) % 2 == 1:
            #         for i in range(0, self.s_length // 2):
            #             self.freeze.append(i)
            #     else:
            #         for i in range(self.s_length // 2, self.s_length):
            #             self.freeze.append(i)       
            # if self.sum_steps % self.sort_steps == 0:
            #     self.freeze = []
            #     if (self.sum_steps // self.sort_steps) % 2 == 0:
            #         for i in range(self.s_length // 2, self.s_length):
            #             self.freeze.append(i)   
            # temp_freeze = self.freeze[:]
            # if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
            #     self.freeze = []
            #     for i in range(self.s_length):
            #         if i not in temp_freeze:
            #         # 计算差的绝对值
            #             equal_to_previous = np.abs(self.p_list[i][:-1] - self.p_list[i][1:]) < 1e-6
                    
            #             # 计算满足条件的元素数量
            #             count_close_elements = np.sum(equal_to_previous)
            #             if count_close_elements / self.sort_steps >= self.freeze_same_threshold:
            #                 self.freeze.append(i)
            #                 print("freeze", i)
            # self.sum_steps += 1
            # 梯度冲突比例大的先冻结
            
            # if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
            #     self.freeze = []
            #     for i in range(self.s_length):
            #         # 计算差的绝对值
            #         equal_to_previous = (self.s_grad_old_list[i] > 0) == (self.s_grad_new_list[i] > 0)

            #         # 计算满足条件的元素数量
            #         count_close_elements = np.sum(equal_to_previous)
            #         self.reverse_ratio_list[i] = count_close_elements / self.sort_steps
            #         print(i, self.reverse_ratio_list[i])
            #     sorted_reverse_list = sorted(self.reverse_ratio_list, reverse=True)
            #     for i in range(self.s_length):
            #         if self.reverse_ratio_list[i] >= sorted_reverse_list[int(self.freeze_reverse_ratio * self.s_length)]:
            #             self.freeze.append(i)
            #             print("freeze", i)
            # if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
            #     self.freeze = []
            #     for i in range(self.s_length):
            #         # 计算差的绝对值
            #         equal_to_previous = (self.s_grad_old_list[i] > 0) != (self.s_grad_new_list[i] > 0)

            #         # 计算满足条件的元素数量
            #         count_close_elements = np.sum(equal_to_previous)
            #         self.reverse_ratio_list[i] = count_close_elements / self.sort_steps
            #         print(i, self.reverse_ratio_list[i])
            #     sorted_reverse_list = sorted(self.reverse_ratio_list, reverse=True)
            #     for i in range(self.s_length):
            #         if self.reverse_ratio_list[i] >= sorted_reverse_list[int(self.freeze_reverse_ratio * self.s_length)]:
            #             self.freeze.append(i)
            #             print("freeze", i)       
            
            
            
            # if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
            #     self.freeze_old = []
            #     self.freeze_new = []
            #     for i in range(self.s_length):
            #         # 计算差的绝对值
            #         equal_to_previous_old = (self.s_grad_old_list[i][:-1] > 0) != (self.s_grad_old_list[i][1:] > 0)
            #         equal_to_previous_new = (self.s_grad_new_list[i][:-1] > 0) != (self.s_grad_new_list[i][1:] > 0)
            #         # 计算满足条件的元素数量
            #         count_close_elements_old = np.sum(equal_to_previous_old)
            #         count_close_elements_new = np.sum(equal_to_previous_new)
            #         self.reverse_ratio_list_old[i] = count_close_elements_old / self.sort_steps
            #         self.reverse_ratio_list_new[i] = count_close_elements_new / self.sort_steps
            #         print(i, self.reverse_ratio_list_old[i],self.reverse_ratio_list_new[i])
            #     sorted_reverse_list_old = sorted(self.reverse_ratio_list_old, reverse=True)
            #     sorted_reverse_list_new = sorted(self.reverse_ratio_list_new, reverse=True)
            #     for i in range(self.s_length):
            #         if self.reverse_ratio_list_old[i] >= sorted_reverse_list_old[int(self.freeze_reverse_ratio * self.s_length)]:
            #             self.freeze_old.append(i)
            #             print("freeze old", i)  
            #         if self.reverse_ratio_list_new[i] >= sorted_reverse_list_new[int(self.freeze_reverse_ratio * self.s_length)]:
            #             self.freeze_new.append(i)
            #             print("freeze new", i)  
            # if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
            #     self.freeze_old = []
            #     self.freeze_new = []
            #     for i in range(self.s_length):
            #         # 计算差的绝对值
            #         equal_to_previous_old = (self.s_grad_old_list[i][:-1] > 0) != (self.s_grad_old_list[i][1:] > 0)
            #         equal_to_previous_new = (self.s_grad_new_list[i][:-1] > 0) != (self.s_grad_new_list[i][1:] > 0)
                    
            #         # 计算满足条件的元素数量
            #         count_close_elements_old = np.sum(equal_to_previous_old)
            #         count_close_elements_new = np.sum(equal_to_previous_new)
            #         self.reverse_ratio_list_old[i] = count_close_elements_old / self.sort_steps
            #         self.reverse_ratio_list_new[i] = count_close_elements_new / self.sort_steps
                    
            #         self.reverse_ratio_list_delta[i] = self.reverse_ratio_list_old[i] - self.reverse_ratio_list_new[i]
            #         print(i, "old new delta", self.reverse_ratio_list_old[i],self.reverse_ratio_list_new[i], self.reverse_ratio_list_delta[i])
            #     sorted_reverse_list_delta = sorted(self.reverse_ratio_list_delta) # 从小到大，最小代表old变化幅度小于new, 应该冻结new, 最大相反
            #     for i in range(self.s_length):
            #         if self.reverse_ratio_list_delta[i] >= sorted_reverse_list_delta[int((1-self.freeze_reverse_ratio_old) * self.s_length)]:
            #             self.freeze_old.append(i)
            #             print("freeze old", i)  
            #         if self.reverse_ratio_list_delta[i] <= sorted_reverse_list_delta[int(self.freeze_reverse_ratio_new * self.s_length)]:
            #             self.freeze_new.append(i)
            #             print("freeze new", i)  
            # if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
            #     self.freeze_old = []
            #     self.freeze_new = []
            #     for i in range(self.s_length):
            #         # 计算差的绝对值
            #         equal_to_previous_old = (self.s_grad_old_list[i][:-1] > 0) != (self.s_grad_old_list[i][1:] > 0)
            #         equal_to_previous_new = (self.s_grad_new_list[i][:-1] > 0) != (self.s_grad_new_list[i][1:] > 0)
                    
            #         # 计算满足条件的元素数量
            #         count_close_elements_old = np.sum(equal_to_previous_old)
            #         count_close_elements_new = np.sum(equal_to_previous_new)
            #         self.reverse_ratio_list_old[i] = count_close_elements_old / self.sort_steps
            #         self.reverse_ratio_list_new[i] = count_close_elements_new / self.sort_steps
                    
            #         self.reverse_ratio_list_delta[i] = self.reverse_ratio_list_old[i] - self.reverse_ratio_list_new[i]
            #         print(i, "old new delta", self.reverse_ratio_list_old[i],self.reverse_ratio_list_new[i], self.reverse_ratio_list_delta[i])
            #     # sorted_reverse_list_delta = sorted(self.reverse_ratio_list_delta) # 从小到大，最小代表old变化幅度小于new, 应该冻结new, 最大相反
            #     for i in range(self.s_length):
            #         if self.reverse_ratio_list_delta[i] >= self.freeze_reverse_threshold_old:
            #             self.freeze_old.append(i)
            #             print("freeze old", i)  
            #         if self.reverse_ratio_list_delta[i] <= -self.freeze_reverse_threshold_new:
            #             self.freeze_new.append(i)
            #             print("freeze new", i)  
            
            
            
            # if self.sum_steps  % self.sort_steps == self.freeze_steps:   
            #     self.freeze_old = []
            temp_freeze = self.freeze_old[:]
            if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
                if self.no_unfreeze == False:
                    self.freeze_old = []
                self.freeze_new = []
                for i in range(self.s_length):
                    # 计算差的绝对值
                    equal_to_previous_old = (self.s_grad_old_list[i][:-1] > 0) != (self.s_grad_old_list[i][1:] > 0)
                    equal_to_previous_new = (self.s_grad_new_list[i][:-1] > 0) != (self.s_grad_new_list[i][1:] > 0)
                    
                    # 计算满足条件的元素数量
                    count_close_elements_old = np.sum(equal_to_previous_old)
                    count_close_elements_new = np.sum(equal_to_previous_new)
                    self.reverse_ratio_list_old[i] = count_close_elements_old / self.sort_steps
                    self.reverse_ratio_list_new[i] = count_close_elements_new / self.sort_steps
                    
                    self.reverse_ratio_list_delta[i] = self.reverse_ratio_list_old[i] - self.reverse_ratio_list_new[i]
                    print(i, "old new delta", self.reverse_ratio_list_old[i],self.reverse_ratio_list_new[i], self.reverse_ratio_list_delta[i])
                # sorted_reverse_list_delta = sorted(self.reverse_ratio_list_delta) # 从小到大，最小代表old变化幅度小于new, 应该冻结new, 最大相反
                for i in range(self.s_length):
                    if self.reverse_freeze:
                        if self.reverse_ratio_list_old[i] > self.freeze_old_ratio:
                            self.freeze_old.append(i)
                            print("freeze old", i)
                    
                    else:    
                        if (self.continue_freeze == False and i not in temp_freeze) or self.continue_freeze == True:
                            if self.reverse_ratio_list_old[i] <= self.freeze_old_ratio:
                                self.freeze_old.append(i)
                                print("freeze old", i)
                # for i in range(self.s_length):
                #     if self.reverse_ratio_list_new[i] <= self.freeze_new_ratio:
                #         self.freeze_new.append(i)
                #         print("freeze new", i)
                #     if self.reverse_ratio_list_old[i] <= 0.3:
                #         if i not in temp_freeze:
                #             if len(temp_freeze) == self.freeze_length:
                #                 max_i = -1
                #                 max_times = -1
                #                 for freeze_idx in temp_freeze:
                #                     if self.freeze_times[freeze_idx] >= max_times:
                #                         max_i = freeze_idx
                #                 self.freeze_times[max_i] = 0
                #                 self.freeze_times[i] = 0
                #                 temp_freeze.remove(max_i)
                #                 temp_freeze.append(i)
                #             else:
                #                 temp_freeze.append(i)
                #                 self.freeze_times[i] = 0
                #     else:
                #         if i in temp_freeze:
                #             temp_freeze.remove(i)
                #             self.freeze_times[i] = 0
                # for i in temp_freeze:
                #     self.freeze_times[i] += 1
                #     self.freeze_old.append(i)
                #     print("freeze old", i)

                            
                                    
                    # if i not in temp_freeze:
                    # if self.reverse_ratio_list_old[i] <= 0.3:
                    #     if i % 2 == 0:
                    #         self.freeze_old.append(i)
                    #         self.freeze_new.append(i)
                    #         print("freeze old", i)
                    #         print("freeze new", i)
                    #     else:
                    #         self.freeze_old.append(i)
                    #         self.freeze_new.append(i)
                    #         print("freeze old", i)
                    #         print("freeze new", i)
                    # else:
                    #     if self.reverse_ratio_list_old[i] > 0.3:
                            
                    #         self.freeze_new.append(i)
                    #         print("freeze new", i)  
                # if self.sum_steps >= 2500:
                # self.freeze_old =  random.sample(range(0, self.s_length, 2), 4)
                    # if self.reverse_ratio_list_delta[i] <= -0.15:
                    #     self.freeze_new.append(i)
                    #     print("freeze new", i)  
                        
                    # if self.reverse_ratio_list_old[i] >= 0.5:
                    #     self.freeze_new.append(i)
                    #     print("freeze new", i)  
                    # # sum_steps一开始以old主导，需要冻结new
                    # # 之后可以交替冻结old和new
                    # if self.sum_steps <= 12000:
                    # if self.reverse_ratio_list_delta[i] <= -0.15:
                    #     self.freeze_new.append(i)
                    #     print("freeze new", i)  
                    # else:
                    #     # pass
                    #     if self.sum_steps % (4 * self.sort_steps) == 0:
                    #         if self.reverse_ratio_list_delta[i] <= -0.15:
                    #             self.freeze_new.append(i)
                    #             print("freeze new", i)  
                    #     else:
                    #         if self.reverse_ratio_list_old[i] <= 0.3:
                    #             self.freeze_old.append(i)
                    #             print("freeze old", i)  
                # self.freeze_old = [1, 100, 102, 104]
            if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
                print(self.sum_steps)
                for i, s_grads in enumerate(self.s_grad_sums_old):
                    print("old", i, s_grads)
                    self.s_grad_sums_old[i] = 0
                for i, s_grads in enumerate(self.s_grad_sums_new):
                    print("new", i, s_grads)
                    self.s_grad_sums_new[i] = 0
                    
                for i in range(self.s_length):
                    print("old real", i, self.s_grad_sums_old_real[i])
                    print("new real", i, self.s_grad_sums_new_real[i])
                    print("old real div scale", i, self.s_grad_sums_old_real[i] / np.mean(self.p_list[i][:]))
                    print("new real div scale", i, self.s_grad_sums_new_real[i] / np.mean(self.p_list[i][:]))
                    self.s_grad_sums_old_real[i] = 0
                    self.s_grad_sums_new_real[i] = 0                
             
            
            # if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
            #     self.freeze = []
                
            #     for i in range(self.s_length):
            #         grad_temp_add_list = self.s_grad_old_list[i] + self.s_grad_new_list[i]
            #         equal_to_previous = (grad_temp_add_list[:-1] > 0) != (grad_temp_add_list[1:] > 0)
            #         count_close_elements = np.sum(equal_to_previous)
            #         self.reverse_ratio_list_add[i] = count_close_elements / self.sort_steps
            #         print(i, "add", self.reverse_ratio_list_add[i])
            #         # 计算差的绝对值
            #         equal_to_previous_old = (self.s_grad_old_list[i][:-1] > 0) != (self.s_grad_old_list[i][1:] > 0)
            #         equal_to_previous_new = (self.s_grad_new_list[i][:-1] > 0) != (self.s_grad_new_list[i][1:] > 0)
                    
            #         计算满足条件的元素数量
            #         count_close_elements_old = np.sum(equal_to_previous_old)
            #         count_close_elements_new = np.sum(equal_to_previous_new)
            #         self.reverse_ratio_list_old[i] = count_close_elements_old / self.sort_steps
            #         self.reverse_ratio_list_new[i] = count_close_elements_new / self.sort_steps
                    
            #         self.reverse_ratio_list_delta[i] = self.reverse_ratio_list_old[i] - self.reverse_ratio_list_new[i]
            #         print(i, "old new delta", self.reverse_ratio_list_old[i],self.reverse_ratio_list_new[i], self.reverse_ratio_list_delta[i])
            #     sorted_reverse_list_delta = sorted(self.reverse_ratio_list_delta) # 从小到大，最小代表old变化幅度小于new, 应该冻结new, 最大相反
            #     for i in range(self.s_length):
            #             if self.reverse_ratio_list_old[i] < 0.35 and i not in temp_freeze:
                            
            #                 self.freeze_old.append(i)
            #                 print("freeze old", i)  
                
            #         if self.reverse_ratio_list_old[i] > 0.5:
            #             self.freeze_new.append(i)
            #             print("freeze new", i)          
            # if self.sum_steps % self.sort_steps == 0 and self.sum_steps != 0:
            #     self.freeze = []
            #     for i in range(self.s_length):
            #         # 计算差的绝对值
            #         equal_to_previous = (self.s_grad_old_list[i] > 0) != (self.s_grad_new_list[i] > 0)

            #         # 计算满足条件的元素数量
            #         count_close_elements = np.sum(equal_to_previous)
            #         self.reverse_ratio_list[i] = count_close_elements / self.sort_steps
            #         print(i, self.reverse_ratio_list[i])
            #     for i in range(self.s_length):
            #         if self.reverse_ratio_list[i] >= self.freeze_reverse_threshold:
            #             self.freeze.append(i)
            #             print("freeze", i)      
            idx = 0
            
            for group in self.q_param_groups:
                for p in group['params']:
                    if p.grad is None: continue
                    self.s_grad_sums_old[idx] += torch.abs(self.state[p]['old_g'])
                    self.s_grad_sums_new[idx] += torch.abs(p.grad)
                    self.s_grad_old_list[idx][self.sum_steps % self.sort_steps] = self.state[p]['old_g'].clone()
                    self.s_grad_new_list[idx][self.sum_steps % self.sort_steps] = p.grad.data.clone()
                 
                    if self.clip_range == -1:
                        # if idx >= 2 and idx <= 84 and idx % 2 == 0:
                        #     if self.sum_steps <= 15000:
                        #         self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'],0, 0)
                        #     else:
                        #         p.grad = torch.clamp(p.grad,0, 0)
                        # pass
                        # out_idx = int((1-self.out_ratio) * self.s_length)
                        # if self.outer_old[idx] > self.sort_outer_old[out_idx]:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], self.s1_clip_range_neg[idx], self.s1_clip_range_pos[idx])
                        # if self.outer_new[idx] > self.sort_outer_new[out_idx]:
                        #     p.grad = torch.clamp(p.grad, self.s2_clip_range_neg[idx], self.s2_clip_range_pos[idx])
                        # if self.outer_old[idx] > self.out_ratio:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], self.s1_clip_range_neg[idx], self.s1_clip_range_pos[idx])
                        # if self.outer_new[idx] > self.out_ratio:
                        #     p.grad = torch.clamp(p.grad, self.s2_clip_range_neg[idx], self.s2_clip_range_pos[idx])
                        
                        
                        # pass
                        # self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], self.s1_clip_range_neg[idx], self.s1_clip_range_pos[idx])
                        
                        # self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], self.s1_clip_range_neg[idx], self.s1_clip_range_pos[idx])
                        # p.grad = torch.clamp(p.grad, self.s2_clip_range_neg[idx], self.s2_clip_range_pos[idx])
                        # if idx in self.clip5:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], self.s1_clip_range_neg[idx], self.s1_clip_range_pos[idx])
                        #     p.grad = torch.clamp(p.grad, self.s2_clip_range_neg[idx], self.s2_clip_range_pos[idx])
                        # self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -self.clip_range, self.clip_range)
                        # p.grad = torch.clamp(p.grad, -self.clip_range, self.clip_range)
                        # if self.state[p]['old_g'] < 0 and p.grad < 0 or self.state[p]['old_g'] > 0 and p.grad > 0: 
                        # if idx in self.freeze:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], 0,0)
                        #     p.grad = torch.clamp(p.grad, 0,0)
                        # if self.sum_steps < 2000:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -50,50)
                        #     p.grad = torch.clamp(p.grad, -50,50)
                        if self.sum_steps >= self.st_freeze_steps:
                            if idx in self.freeze_old:
                                self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], 0,0)
                            if idx in self.freeze_new:
                                p.grad = torch.clamp(p.grad, 0,0)
                        # else:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -10,10)
                        #     p.grad = torch.clamp(p.grad, -10,10)
                        # if idx in self.clip5:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -50,50)
                        #     p.grad = torch.clamp(p.grad, -100,100)
                        # else:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -10,10)
                        #     p.grad = torch.clamp(p.grad, -10,10)
                        # if idx in self.clip10:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -100,100)
                        #     p.grad = torch.clamp(p.grad, -30,30)
                        # if idx in self.clip15:
                        #     self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -15,15)
                        #     p.grad = torch.clamp(p.grad, -15,15)
                        # self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -20,20)
                        # p.grad *= 2
                        # p.grad = torch.clamp(p.grad, -10,10)
                        
                        
                            
                        
                    elif self.clip_range != 0:
                        self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -self.clip_range, self.clip_range)
                        p.grad = torch.clamp(p.grad, -self.clip_range, self.clip_range)
                    
                    self.p_list[idx][self.sum_steps % self.sort_steps] = p.data
                    grad_type = type(p.grad)
                    
                    self.s_grad_sums_old_real[idx] += self.state[p]['old_g']
                    self.s_grad_sums_new_real[idx] += p.grad
                    
                    
                    
                    if self.s_grad == 'both':
                        sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
                        
                    elif self.s_grad == 'first':
                        sam_grad = self.state[p]['old_g'] - p.grad 
                    elif self.s_grad == 'second':
                        sam_grad = 0
                    elif self.s_grad == 'random':
                        sam_grad = self.state[p]['old_g'] - p.grad 
                        # 生成一个0到1之间的随机浮点数
                        random_number = random.random()
                        if random_number > 0.5:
                            sam_grad += p.grad * 0.5  
                    
                    


                    # print(self.state[p]['old_g'], p.grad)
                    # sam_grad = 0
                    
                    # if (self.state[p]['old_g'] > 0 and p.grad > 0) or (self.state[p]['old_g'] < 0 and p.grad < 0):
                    #     # sam_grad = -p.grad + torch.sqrt(self.state[p]['old_g'] * p.grad)
                    #     sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
                    # else:
                    #     # sam_grad = -p.grad + self.state[p]['old_g']
                    #     sam_grad = -p.grad
                    # sam_grad = self.state[p]['old_g'] - p.grad
                    # print("old", self.state[p]['old_g'], "new", p.grad)
                    p.grad.data.add_(sam_grad)
                    # if idx >= 2 and idx <= 84 and idx % 2 == 0:
                    #     p.grad.data = p.grad.data * 2
                    # if self.sum_steps >= self.sort_steps:
                    #     p.grad.data = torch.clamp(p.grad.data, -self.ema_s_clip[idx], self.ema_s_clip[idx])
                    idx += 1
            self.sum_steps = self.sum_steps + 1
            self.clip_ratio -= 1 / self.sum_steps 
            # if self.sum_steps % self.sort_steps == 0:
            #     for sort_idx in range(self.s_length):
            #         sort_list = np.sort(self.s_grad_old_list[sort_idx])
                    
            #         neg_idx = int(self.clip_ratio * self.sort_steps)
            #         pos_idx = int((1-self.clip_ratio) * self.sort_steps)
            #         self.s1_clip_range_pos[sort_idx] = sort_list[pos_idx]
            #         self.s1_clip_range_neg[sort_idx] = sort_list[neg_idx]
            #         sort_list = np.sort(self.s_grad_new_list[sort_idx])
            #         self.s2_clip_range_pos[sort_idx] = sort_list[pos_idx]
            #         self.s2_clip_range_neg[sort_idx] = sort_list[neg_idx]
                    
                
        if self.optimizer_q_fix_5000steps != None:
            for group in self.q_param_groups_fix_5000steps:
                for p in group['params']:
                    if p.grad is None: continue
                    grad_type = type(p.grad)
                    if self.clip_range != 0:
                        self.state[p]['old_g'] = torch.clamp(self.state[p]['old_g'], -self.clip_range, self.clip_range)
                        p.grad = torch.clamp(p.grad, -self.clip_range * 2, self.clip_range * 2)
                        # p.grad = torch.clamp(p.grad, -self.clip_range, self.clip_range)
                    if self.s_grad == 'both':
                        # sam_grad = self.state[p]['old_g'] - p.grad * 0.5
                        sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
                    elif self.s_grad == 'first':
                        sam_grad = self.state[p]['old_g'] - p.grad 
                    elif self.s_grad == 'second':
                        sam_grad = 0
                    elif self.s_grad == 'random':
                        sam_grad = self.state[p]['old_g'] - p.grad 
                        # 生成一个0到1之间的随机浮点数
                        random_number = random.random()
                        if random_number > 0.5:
                            sam_grad += p.grad * 0.5  
                    


                    # print(self.state[p]['old_g'], p.grad)
                    # sam_grad = 0
                    
                    # if (self.state[p]['old_g'] > 0 and p.grad > 0) or (self.state[p]['old_g'] < 0 and p.grad < 0):
                    #     # sam_grad = -p.grad + torch.sqrt(self.state[p]['old_g'] * p.grad)
                    #     sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
                    # else:
                    #     # sam_grad = -p.grad + self.state[p]['old_g']
                    #     sam_grad = -p.grad
                    # sam_grad = self.state[p]['old_g'] - p.grad
                    # print("old", self.state[p]['old_g'], "new", p.grad)
                    p.grad.data.add_(sam_grad)
        # print(self.old_grads)
        # print("____________")
        # print(self.new_grads)
        
        
        # # toy
        # tensor_old = torch.unsqueeze(self.old_grads[1], 0)
        # tensor_new = torch.unsqueeze(self.new_grads[1], 0)
        # vector_old = torch.cat((self.old_grads[0], tensor_old), dim=1)
        # vector_new = torch.cat((self.new_grads[0], tensor_new), dim=1)
        # dot_product = torch.dot(vector_old.view(-1), vector_new.view(-1))
        # # 计算向量的模
        # norm1 = torch.norm(vector_old)
        # norm2 = torch.norm(vector_new)

        # # 计算余弦值
        # cosine_sim = dot_product / (norm1 * norm2)
        # self.cos_sims.append(cosine_sim.item()) 

        # cos_sim = torch.cosine_similarity(torch.cat((self.old_grads[0], tensor_old), dim=1), torch.cat((self.new_grads[0], tensor_new), dim=1))
        
        # print("cos_sim:", cosine_sim)
    @torch.no_grad()
    def _sync_grad(self):
        """All-reduce gradients of every managed parameter group across workers.

        No-op unless a ``torch.distributed`` process group is initialized.
        Covers ``self.param_groups`` plus, when the corresponding optimizer is
        present, ``self.q_param_groups`` and
        ``self.q_param_groups_fix_5000steps``. Each existing ``p.grad`` is
        reduced in place with ``self.grad_reduce``. When
        ``self.manual_average`` is set, the result is also divided by the
        world size. This is presumably the SUM-then-divide fallback used when
        ``ReduceOp.AVG`` is unavailable; both attributes are configured
        elsewhere. Parameters whose ``.grad`` is None are skipped.

        NOTE(review): a parameter that appears in more than one of these group
        lists would be reduced more than once; presumably the lists are
        disjoint -- confirm where they are built.
        """
        if not torch.distributed.is_initialized():
            return

        group_lists = [self.param_groups]
        if self.q_base_optimizer is not None:
            group_lists.append(self.q_param_groups)
        if self.optimizer_q_fix_5000steps is not None:
            group_lists.append(self.q_param_groups_fix_5000steps)

        # The world size cannot change mid-call; query it once, not per tensor.
        world_size = (
            float(torch.distributed.get_world_size()) if self.manual_average else None
        )
        for param_groups in group_lists:
            for group in param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
                    if self.manual_average:
                        p.grad.div_(world_size)

    @torch.no_grad()
    def _grad_norm(self, by=None, weight_adaptive=False):
        """Return the global L2 norm over all parameters that have a gradient.

        Args:
            by: when falsy, the norm is taken over each ``p.grad``; otherwise
                over the per-parameter state entry ``self.state[p][by]``. Only
                parameters whose ``.grad`` is not None are included either way.
            weight_adaptive: when True, each tensor is first multiplied
                element-wise by ``|p.data|``.

        Returns:
            A 0-dim tensor: the L2 norm of the per-parameter L2 norms.
        """
        per_param_norms = []
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                vec = p.grad if not by else self.state[p][by]
                if weight_adaptive:
                    vec = torch.abs(p.data) * vec
                per_param_norms.append(vec.norm(p=2))
        return torch.stack(per_param_norms).norm(p=2)

    # def norm(tensor_list: List[torch.tensor], p=2):
    #     """Compute p-norm for tensor list"""
    #     return torch.cat([x.flatten() for x in tensor_list]).norm(p)

    def load_state_dict(self, state_dict):
        """Restore optimizer state, then re-share param groups with ``base_optimizer``.

        The parent implementation may rebuild ``self.param_groups`` as new
        objects. The wrapped ``base_optimizer`` is pointed at that same list
        again, so both objects keep seeing identical hyper-parameters.
        """
        super().load_state_dict(state_dict)
        shared_groups = self.param_groups
        self.base_optimizer.param_groups = shared_groups
        
    def maybe_no_sync(self):
        """Return a context manager that suppresses gradient sync when distributed.

        With an initialized ``torch.distributed`` process group this is
        ``self.model.no_sync()`` (presumably ``self.model`` is DDP-wrapped in
        that case -- confirm at the call site). Otherwise it is an empty
        ``contextlib.ExitStack``, which acts as a no-op context manager.
        """
        if not torch.distributed.is_initialized():
            return contextlib.ExitStack()
        return self.model.no_sync()

    @torch.no_grad()
    def set_closure(self, loss_fn, inputs, targets, **kwargs):
        """Store a zero-argument forward/backward closure on ``self.forward_backward_func``.

        When called, the closure:
          1. clears gradients on ``base_optimizer`` and on whichever of
             ``q_base_optimizer`` / ``optimizer_q_fix_5000steps`` are present;
          2. runs ``self.model(inputs)`` and ``loss_fn(outputs, targets, **kwargs)``;
          3. back-propagates the loss;
          4. returns ``(outputs, detached copy of the loss)``.

        ``step`` uses this closure when no explicit ``closure`` is given.
        ``inputs``, ``targets`` and ``kwargs`` are captured now and reused on
        every call.
        """
        def forward_backward():
            self.base_optimizer.zero_grad()
            for extra_opt in (self.q_base_optimizer, self.optimizer_q_fix_5000steps):
                if extra_opt is not None:
                    extra_opt.zero_grad()

            # The closure is normally invoked from step(), which runs under
            # torch.no_grad(), so autograd is re-enabled for the forward pass.
            with torch.enable_grad():
                preds = self.model(inputs)
                loss = loss_fn(preds, targets, **kwargs)
            detached_loss = loss.detach().clone()
            loss.backward()
            return preds, detached_loss

        self.forward_backward_func = forward_backward

    @torch.no_grad()
    def step(self, closure=None):
        """Run one full SAGM update and return ``(outputs, loss_value)``.

        Sequence (order matters):
          1. forward/backward at the current weights (``get_grad``);
          2. ``perturb_weights(rho=self.rho_t)``;
          3. ``disable_running_stats(self.model)``, then forward/backward again
             at the perturbed weights;
          4. ``gradient_decompose(self.alpha)`` combines the two gradients;
          5. ``unperturb()``;
          6. ``_sync_grad()`` all-reduces gradients across workers;
          7. ``_apply_optimizer_steps()`` steps the wrapped optimizers.
        Steps 1-5 run inside ``maybe_no_sync()``, so when distributed the two
        backward passes do not trigger gradient sync (assuming ``self.model``
        is DDP-wrapped); the sync happens once in step 6.

        Running stats are re-enabled in a ``finally`` clause. If anything after
        ``disable_running_stats`` raises, a caller that catches the exception
        is therefore not left with running-stat updates permanently disabled.

        Args:
            closure: optional zero-argument callable that performs a
                forward/backward pass and returns ``(outputs, loss_value)``.
                Defaults to the closure stored by ``set_closure``.

        Returns:
            ``(outputs, loss_value)`` from the first (unperturbed) pass.
        """
        self.step_cnt += 1
        get_grad = closure if closure else self.forward_backward_func

        running_stats_disabled = False
        try:
            with self.maybe_no_sync():
                outputs, loss_value = get_grad()
                self.perturb_weights(rho=self.rho_t)
                disable_running_stats(self.model)
                running_stats_disabled = True
                get_grad()
                self.gradient_decompose(self.alpha)
                self.unperturb()
            self._sync_grad()
            self._apply_optimizer_steps()
        finally:
            # Only undo what was actually done; mirrors the original
            # success-path order (re-enable after the optimizer steps).
            if running_stats_disabled:
                enable_running_stats(self.model)
        return outputs, loss_value

    def _apply_optimizer_steps(self):
        """Step ``base_optimizer``, then the optional ``q`` optimizers.

        Schedule (identical to the previous inline logic in ``step``):
          * ``fix_alter_steps == 0`` and ``fix_reverse == False``: step
            ``q_base_optimizer`` and ``optimizer_q_fix_5000steps`` every call.
          * ``fix_alter_steps == 0`` otherwise: step ``q_base_optimizer`` only
            once ``step_cnt >= 5000`` (hard-coded threshold), and
            ``optimizer_q_fix_5000steps`` every call.
          * ``fix_alter_steps != 0``: step ``q_base_optimizer`` every call, and
            ``optimizer_q_fix_5000steps`` only when
            ``step_cnt % fix_alter_steps == 0``.
        Optimizers that are None are skipped.
        """
        self.base_optimizer.step()
        q_opt = self.q_base_optimizer
        fix_opt = self.optimizer_q_fix_5000steps
        if self.fix_alter_steps == 0:
            # ``== False`` (rather than ``not``) is kept on purpose: it preserves
            # the original handling of non-bool values such as None.
            q_due = self.fix_reverse == False or self.step_cnt >= 5000  # noqa: E712
            if q_opt is not None and q_due:
                q_opt.step()
            if fix_opt is not None:
                fix_opt.step()
        else:
            if q_opt is not None:
                q_opt.step()
            if fix_opt is not None and self.step_cnt % self.fix_alter_steps == 0:
                fix_opt.step()


# # toy
# # 原始sagm + 量化s
# class SAGM(torch.optim.Optimizer):
#     def __init__(self, params, base_optimizer, model, alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', q_base_optimizer = None, **kwargs):
#         defaults = dict(adaptive=adaptive, **kwargs)
#         super(SAGM, self).__init__(params, defaults)
#         self.model = model
#         self.base_optimizer = base_optimizer
#         self.param_groups = self.base_optimizer.param_groups
#         self.q_base_optimizer = q_base_optimizer
#         if self.q_base_optimizer != None:
#             self.q_param_groups = self.q_base_optimizer.param_groups
#         self.adaptive = adaptive
#         self.rho_scheduler = rho_scheduler
#         self.perturb_eps = perturb_eps
#         self.alpha = alpha
#         self.old_grads = []
#         self.new_grads = []
#         self.cos_sims = []
#         self.weight_list = []
#         self.s_list = []
#         self.update_q = 0
#         # initialize self.rho_t
#         self.update_rho_t()
        
#         # set up reduction for gradient across workers
#         if grad_reduce.lower() == 'mean':
#             if hasattr(ReduceOp, 'AVG'):
#                 self.grad_reduce = ReduceOp.AVG
#                 self.manual_average = False
#             else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
#                 self.grad_reduce = ReduceOp.SUM
#                 self.manual_average = True
#         elif grad_reduce.lower() == 'sum':
#             self.grad_reduce = ReduceOp.SUM
#             self.manual_average = False
#         else:
#             raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')
    
#     @torch.no_grad()
#     def update_rho_t(self):
#         self.rho_t = self.rho_scheduler.step()
#         return self.rho_t

#     @torch.no_grad()
#     def perturb_weights(self, rho=0.0):
#         self.old_grads = []
        
#         grad_norm = self._grad_norm( weight_adaptive = self.adaptive )
#         for group in self.param_groups:
#             # print(group)
#             scale = (rho / (grad_norm + self.perturb_eps) - self.alpha)

#             for p in group["params"]:
                
#                 self.weight_list.append(p.data.clone())
#                 if p.grad is None: continue
#                 self.state[p]["old_g"] = p.grad.data.clone()
#                 self.old_grads.append(self.state[p]["old_g"])
#                 e_w = p.grad * scale.to(p)
#                 if self.adaptive:
#                     e_w *= torch.pow(p, 2)
#                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
#                 self.state[p]['e_w'] = e_w
                
#         if self.q_base_optimizer != None:
#             for group in self.q_param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     self.state[p]["old_g"] = p.grad.data.clone()
#                     # print("s", p)
#                     self.s_list.append(p.data.clone())
#     @torch.no_grad()
#     def unperturb(self):
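#         # TODO: needs revision — switch to the original (full-precision) gradient; the gradient passed in initially is presumably the quantized-then-dequantized one.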
#         for group in self.param_groups:
#             for p in group['params']:
#                 if 'e_w' in self.state[p].keys():
#                     p.data.sub_(self.state[p]['e_w'])

#     @torch.no_grad()
#     def gradient_decompose(self, alpha=0.0):
#         self.new_grads = []
#         # NOTE: this logic needs a careful review.
#         for group in self.param_groups:
#             for p in group['params']:
#                 if p.grad is None: continue
#                 self.new_grads.append(p.grad.data.clone())
#                 sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
#                 p.grad.data.add_(sam_grad)
#         if self.q_base_optimizer != None:
#             for group in self.q_param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
#                     p.grad.data.add_(sam_grad)
#         # print(self.old_grads)
#         # print("____________")
#         # print(self.new_grads)
        
        
#         # # toy
#         # tensor_old = torch.unsqueeze(self.old_grads[1], 0)
#         # tensor_new = torch.unsqueeze(self.new_grads[1], 0)
#         # vector_old = torch.cat((self.old_grads[0], tensor_old), dim=1)
#         # vector_new = torch.cat((self.new_grads[0], tensor_new), dim=1)
#         # dot_product = torch.dot(vector_old.view(-1), vector_new.view(-1))
#         # # compute the vector norms
#         # norm1 = torch.norm(vector_old)
#         # norm2 = torch.norm(vector_new)

#         # # compute the cosine similarity
#         # cosine_sim = dot_product / (norm1 * norm2)
#         # self.cos_sims.append(cosine_sim.item()) 

#         # cos_sim = torch.cosine_similarity(torch.cat((self.old_grads[0], tensor_old), dim=1), torch.cat((self.new_grads[0], tensor_new), dim=1))
        
#         # print("cos_sim:", cosine_sim)
#     @torch.no_grad()
#     def _sync_grad(self):
#         if torch.distributed.is_initialized(): # synchronize final gardients
#             for group in self.param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     if self.manual_average:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#                         world_size = torch.distributed.get_world_size()
#                         p.grad.div_(float(world_size))
#                     else:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#             if self.q_base_optimizer != None:
#                 for group in self.q_param_groups:
#                     for p in group['params']:
#                         if p.grad is None: continue
#                         if self.manual_average:
#                             torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#                             world_size = torch.distributed.get_world_size()
#                             p.grad.div_(float(world_size))
#                         else:
#                             torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#         return

#     @torch.no_grad()
#     def _grad_norm(self, by=None, weight_adaptive=False):
#         #shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
#         if not by:

#             norm = torch.norm(
#                     torch.stack([
#                         ( (torch.abs(p.data) if weight_adaptive else 1.0) * p.grad).norm(p=2)
#                         for group in self.param_groups for p in group["params"]
#                         if p.grad is not None
#                     ]),
#                     p=2
#                )

#         else:

#             norm = torch.norm(
#                 torch.stack([
#                     ( (torch.abs(p.data) if weight_adaptive else 1.0) * self.state[p][by]).norm(p=2)
#                     for group in self.param_groups for p in group["params"]
#                     if p.grad is not None
#                 ]),
#                 p=2
#             )

#         return norm

#     # def norm(tensor_list: List[torch.tensor], p=2):
#     #     """Compute p-norm for tensor list"""
#     #     return torch.cat([x.flatten() for x in tensor_list]).norm(p)

#     def load_state_dict(self, state_dict):
#         super().load_state_dict(state_dict)
#         self.base_optimizer.param_groups = self.param_groups
        
#     def maybe_no_sync(self):
#         if torch.distributed.is_initialized():
#             return self.model.no_sync()
#         else:
#             return contextlib.ExitStack()

#     @torch.no_grad()
#     def set_closure(self, loss_fn, inputs, targets, **kwargs):
#         # create self.forward_backward_func, which is a function such that
#         # self.forward_backward_func() automatically performs forward and backward passes.
#         # This function does not take any arguments, and the inputs and targets data
#         # should be pre-set in the definition of partial-function

#         def get_grad():
#             self.base_optimizer.zero_grad()
#             if self.q_base_optimizer != None:
#                 self.q_base_optimizer.zero_grad()
#             with torch.enable_grad():
#                 outputs = self.model(inputs)
#                 # print("1 2", outputs.shape, targets.shape)
#                 # print(outputs)
#                 loss = loss_fn(outputs, targets, **kwargs)
#             loss_value = loss.data.clone().detach()
#             loss.backward()

#             return outputs, loss_value

#         self.forward_backward_func = get_grad

#     @torch.no_grad()
#     def step(self, closure=None):

#         if closure:
#             get_grad = closure
#         else:
#             get_grad = self.forward_backward_func

#         with self.maybe_no_sync():
#             # get gradient
#             # get the gradient computed through the quantize-then-dequantize path
#             outputs, loss_value = get_grad()

#             # perturb weights
#             self.perturb_weights(rho=self.rho_t)

#             # disable running stats for second pass
#             disable_running_stats(self.model)

#             # get gradient at perturbed weights
#             get_grad()

#             # decompose and get new update direction
#             self.gradient_decompose(self.alpha)

#             # unperturb
#             self.unperturb()
            
#         # synchronize gradients across workers
#         self._sync_grad()    
#         # for name, p in self.model.named_parameters():
#         #     if 'layer1.0.conv1.weight' in name:
#         #         print("Gradient1", p.grad[0])
#         #     elif 'layer2.0.conv2.weight' in name:
#         #         print("Gradient2", p.grad[0])
#         #     elif 'layer4.0.conv3.weight' in name:
#         #         print("Gradient3", p.grad[0])
#         # update with new directions
#         if self.update_q % 2 == 1:
#             self.base_optimizer.step()
#         else:
#             self.q_base_optimizer.step()
#         # if self.q_base_optimizer != None and self.update_q  > 1000:
#         #     self.q_base_optimizer.step()
#         self.update_q += 1
        
#         # self.base_optimizer.step()
#         # if self.q_base_optimizer != None:
#         #     self.q_base_optimizer.step()

#         # enable running stats
#         enable_running_stats(self.model)

#         return outputs, loss_value


# # SAGM on dequantized weights + quantized s
# class SAGM(torch.optim.Optimizer):
#     def __init__(self, params, base_optimizer, model, alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', q_base_optimizer = None, **kwargs):
#         defaults = dict(adaptive=adaptive, **kwargs)
#         super(SAGM, self).__init__(params, defaults)
#         self.model = model
#         self.base_optimizer = base_optimizer
#         self.param_groups = self.base_optimizer.param_groups
#         self.q_base_optimizer = q_base_optimizer
#         self.adaptive = adaptive
#         self.rho_scheduler = rho_scheduler
#         self.perturb_eps = perturb_eps
#         self.alpha = alpha
        
#         # initialize self.rho_t
#         self.update_rho_t()
        
#         # set up reduction for gradient across workers
#         if grad_reduce.lower() == 'mean':
#             if hasattr(ReduceOp, 'AVG'):
#                 self.grad_reduce = ReduceOp.AVG
#                 self.manual_average = False
#             else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
#                 self.grad_reduce = ReduceOp.SUM
#                 self.manual_average = True
#         elif grad_reduce.lower() == 'sum':
#             self.grad_reduce = ReduceOp.SUM
#             self.manual_average = False
#         else:
#             raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')
    
#     @torch.no_grad()
#     def update_rho_t(self):
#         self.rho_t = self.rho_scheduler.step()
#         return self.rho_t

#     @torch.no_grad()
#     def perturb_weights(self, rho=0.0):
#         grad_norm = self._grad_norm( weight_adaptive = self.adaptive )
#         for group in self.param_groups:
#             # the scale is the same regardless of whether the current weights have been quantized
#             scale = (rho / (grad_norm + self.perturb_eps) - self.alpha)

#             for p in group["params"]:
#                 if p.grad is None: continue
#                 self.state[p]["old_g"] = p.grad.data.clone()
                
#                 e_w = p.grad * scale.to(p)
#                 if self.adaptive:
#                     e_w *= torch.pow(p, 2)
                
#                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
#                 self.state[p]['e_w'] = e_w
                
#     @torch.no_grad()
#     def unperturb(self):
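#         # TODO: needs revision — switch to the original (full-precision) gradient; the gradient passed in initially is presumably the quantized-then-dequantized one.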
#         for group in self.param_groups:
#             for p in group['params']:
#                 if 'e_w' in self.state[p].keys():
#                     p.data.sub_(self.state[p]['e_w'])

#     @torch.no_grad()
#     def gradient_decompose(self, alpha=0.0):
#         # NOTE: this logic needs a careful review.
#         for group in self.param_groups:
#             for p in group['params']:
#                 if p.grad is None: continue
#                 sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
#                 p.grad.data.add_(sam_grad)

#     @torch.no_grad()
#     def _sync_grad(self):
#         if torch.distributed.is_initialized(): # synchronize final gardients
#             for group in self.param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     if self.manual_average:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#                         world_size = torch.distributed.get_world_size()
#                         p.grad.div_(float(world_size))
#                     else:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#         return

#     @torch.no_grad()
#     def _grad_norm(self, by=None, weight_adaptive=False):
#         #shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
#         if not by:

#             norm = torch.norm(
#                     torch.stack([
#                         ( (torch.abs(p.data) if weight_adaptive else 1.0) * p.grad).norm(p=2)
#                         for group in self.param_groups for p in group["params"]
#                         if p.grad is not None
#                     ]),
#                     p=2
#                )

#         else:

#             norm = torch.norm(
#                 torch.stack([
#                     ( (torch.abs(p.data) if weight_adaptive else 1.0) * self.state[p][by]).norm(p=2)
#                     for group in self.param_groups for p in group["params"]
#                     if p.grad is not None
#                 ]),
#                 p=2
#             )

#         return norm

#     # def norm(tensor_list: List[torch.tensor], p=2):
#     #     """Compute p-norm for tensor list"""
#     #     return torch.cat([x.flatten() for x in tensor_list]).norm(p)

#     def load_state_dict(self, state_dict):
#         super().load_state_dict(state_dict)
#         self.base_optimizer.param_groups = self.param_groups
        
#     def maybe_no_sync(self):
#         if torch.distributed.is_initialized():
#             return self.model.no_sync()
#         else:
#             return contextlib.ExitStack()

#     @torch.no_grad()
#     def set_closure(self, loss_fn, inputs, targets, **kwargs):
#         # create self.forward_backward_func, which is a function such that
#         # self.forward_backward_func() automatically performs forward and backward passes.
#         # This function does not take any arguments, and the inputs and targets data
#         # should be pre-set in the definition of partial-function

#         def get_grad():
#             self.base_optimizer.zero_grad()
#             if self.q_base_optimizer != None:
#                 self.q_base_optimizer.zero_grad()
#             with torch.enable_grad():
#                 outputs = self.model(inputs)
#                 loss = loss_fn(outputs, targets, **kwargs)
#             loss_value = loss.data.clone().detach()
#             loss.backward()

#             return outputs, loss_value

#         self.forward_backward_func = get_grad

#     @torch.no_grad()
#     def step(self, closure=None):

#         if closure:
#             get_grad = closure
#         else:
#             get_grad = self.forward_backward_func

#         with self.maybe_no_sync():
#             # get gradient
#             # get the gradient computed through the quantize-then-dequantize path
#             outputs, loss_value = get_grad()

#             # perturb weights
#             self.perturb_weights(rho=self.rho_t)

#             # disable running stats for second pass
#             disable_running_stats(self.model)

#             # get gradient at perturbed weights
#             get_grad()

#             # decompose and get new update direction
#             self.gradient_decompose(self.alpha)

#             # unperturb
#             self.unperturb()
            
#         # synchronize gradients across workers
#         self._sync_grad()    
#         # for name, p in self.model.named_parameters():
#         #     if 'layer1.0.conv1.weight' in name:
#         #         print("Gradient1", p.grad[0])
#         #     elif 'layer2.0.conv2.weight' in name:
#         #         print("Gradient2", p.grad[0])
#         #     elif 'layer4.0.conv3.weight' in name:
#         #         print("Gradient3", p.grad[0])
#         # update with new directions
#         self.base_optimizer.step()
#         if self.q_base_optimizer != None:
#             self.q_base_optimizer.step()

#         # enable running stats
#         enable_running_stats(self.model)

#         return outputs, loss_value