import torch
from .util import enable_running_stats, disable_running_stats
import contextlib
from torch.distributed import ReduceOp
# # Original SAGM + quantization s
# class SAGM(torch.optim.Optimizer):
#     def __init__(self, params, base_optimizer, model, alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', q_base_optimizer = None, **kwargs):
#         defaults = dict(adaptive=adaptive, **kwargs)
#         super(SAGM, self).__init__(params, defaults)
#         self.model = model
#         self.base_optimizer = base_optimizer
#         self.param_groups = self.base_optimizer.param_groups
#         self.q_base_optimizer = q_base_optimizer
#         if self.q_base_optimizer != None:
#             self.q_param_groups = self.q_base_optimizer.param_groups
#         self.adaptive = adaptive
#         self.rho_scheduler = rho_scheduler
#         self.perturb_eps = perturb_eps
#         self.alpha = alpha
#         self.update_q = 0
#         # initialize self.rho_t
#         self.update_rho_t()
        
#         # set up reduction for gradient across workers
#         if grad_reduce.lower() == 'mean':
#             if hasattr(ReduceOp, 'AVG'):
#                 self.grad_reduce = ReduceOp.AVG
#                 self.manual_average = False
#             else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
#                 self.grad_reduce = ReduceOp.SUM
#                 self.manual_average = True
#         elif grad_reduce.lower() == 'sum':
#             self.grad_reduce = ReduceOp.SUM
#             self.manual_average = False
#         else:
#             raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')
    
#     @torch.no_grad()
#     def update_rho_t(self):
#         self.rho_t = self.rho_scheduler.step()
#         return self.rho_t

#     @torch.no_grad()
#     def perturb_weights(self, rho=0.0):
        
#         grad_norm = self._grad_norm( weight_adaptive = self.adaptive )
#         for group in self.param_groups:
#             # print(group)
#             scale = (rho / (grad_norm + self.perturb_eps) - self.alpha)

#             for p in group["params"]:
                
#                 if p.grad is None: continue
#                 self.state[p]["old_g"] = p.grad.data.clone()
#                 e_w = p.grad * scale.to(p)
#                 if self.adaptive:
#                     e_w *= torch.pow(p, 2)
#                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
#                 self.state[p]['e_w'] = e_w
                
#         if self.q_base_optimizer != None:
#             for group in self.q_param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     self.state[p]["old_g"] = p.grad.data.clone()
#                     # print("s", p)

#     @torch.no_grad()
#     def unperturb(self):
#         # 需要修改，修改成原始梯度，一开始传入的应该是量化之后反量化的梯度
#         for group in self.param_groups:
#             for p in group['params']:
#                 if 'e_w' in self.state[p].keys():
#                     p.data.sub_(self.state[p]['e_w'])

#     @torch.no_grad()
#     def gradient_decompose(self, alpha=0.0):
#         # 这里需要好好看看
#         for group in self.param_groups:
#             for p in group['params']:
#                 if p.grad is None: continue
#                 sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
#                 p.grad.data.add_(sam_grad)
#         if self.q_base_optimizer != None:
#             for group in self.q_param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     # sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
                    
#                     # if (self.state[p]['old_g'] > 0 and p.grad > 0) or (self.state[p]['old_g'] < 0 and p.grad < 0):
#                     #     # sam_grad = -p.grad + torch.sqrt(self.state[p]['old_g'] * p.grad)
#                     #     sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
#                     # else:
#                     #     # sam_grad = -p.grad + self.state[p]['old_g']
#                     #     sam_grad = -p.grad
#                     sam_grad = self.state[p]['old_g'] - p.grad
#                     # print("old", self.state[p]['old_g'], "new", p.grad)
#                     p.grad.data.add_(sam_grad)
#         # print(self.old_grads)
#         # print("____________")
#         # print(self.new_grads)
        
        
#         # # toy
#         # tensor_old = torch.unsqueeze(self.old_grads[1], 0)
#         # tensor_new = torch.unsqueeze(self.new_grads[1], 0)
#         # vector_old = torch.cat((self.old_grads[0], tensor_old), dim=1)
#         # vector_new = torch.cat((self.new_grads[0], tensor_new), dim=1)
#         # dot_product = torch.dot(vector_old.view(-1), vector_new.view(-1))
#         # # 计算向量的模
#         # norm1 = torch.norm(vector_old)
#         # norm2 = torch.norm(vector_new)

#         # # 计算余弦值
#         # cosine_sim = dot_product / (norm1 * norm2)
#         # self.cos_sims.append(cosine_sim.item()) 

#         # cos_sim = torch.cosine_similarity(torch.cat((self.old_grads[0], tensor_old), dim=1), torch.cat((self.new_grads[0], tensor_new), dim=1))
        
#         # print("cos_sim:", cosine_sim)
#     @torch.no_grad()
#     def _sync_grad(self):
#         if torch.distributed.is_initialized(): # synchronize final gardients
#             for group in self.param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     if self.manual_average:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#                         world_size = torch.distributed.get_world_size()
#                         p.grad.div_(float(world_size))
#                     else:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#             if self.q_base_optimizer != None:
#                 for group in self.q_param_groups:
#                     for p in group['params']:
#                         if p.grad is None: continue
#                         if self.manual_average:
#                             torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#                             world_size = torch.distributed.get_world_size()
#                             p.grad.div_(float(world_size))
#                         else:
#                             torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#         return

#     @torch.no_grad()
#     def _grad_norm(self, by=None, weight_adaptive=False):
#         #shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
#         if not by:

#             norm = torch.norm(
#                     torch.stack([
#                         ( (torch.abs(p.data) if weight_adaptive else 1.0) * p.grad).norm(p=2)
#                         for group in self.param_groups for p in group["params"]
#                         if p.grad is not None
#                     ]),
#                     p=2
#                )

#         else:

#             norm = torch.norm(
#                 torch.stack([
#                     ( (torch.abs(p.data) if weight_adaptive else 1.0) * self.state[p][by]).norm(p=2)
#                     for group in self.param_groups for p in group["params"]
#                     if p.grad is not None
#                 ]),
#                 p=2
#             )

#         return norm

#     # def norm(tensor_list: List[torch.tensor], p=2):
#     #     """Compute p-norm for tensor list"""
#     #     return torch.cat([x.flatten() for x in tensor_list]).norm(p)

#     def load_state_dict(self, state_dict):
#         super().load_state_dict(state_dict)
#         self.base_optimizer.param_groups = self.param_groups
        
#     def maybe_no_sync(self):
#         if torch.distributed.is_initialized():
#             return self.model.no_sync()
#         else:
#             return contextlib.ExitStack()

#     @torch.no_grad()
#     def set_closure(self, loss_fn, inputs, targets, **kwargs):
#         # create self.forward_backward_func, which is a function such that
#         # self.forward_backward_func() automatically performs forward and backward passes.
#         # This function does not take any arguments, and the inputs and targets data
#         # should be pre-set in the definition of partial-function

#         def get_grad():
#             self.base_optimizer.zero_grad()
#             if self.q_base_optimizer != None:
#                 self.q_base_optimizer.zero_grad()
#             with torch.enable_grad():
#                 outputs = self.model(inputs)
#                 # print("1 2", outputs.shape, targets.shape)
#                 # print(outputs)
#                 loss = loss_fn(outputs, targets, **kwargs)
#             loss_value = loss.data.clone().detach()
#             loss.backward()

#             return outputs, loss_value

#         self.forward_backward_func = get_grad

#     @torch.no_grad()
#     def step(self, closure=None):

#         if closure:
#             get_grad = closure
#         else:
#             get_grad = self.forward_backward_func

#         with self.maybe_no_sync():
#             # get gradient
#             # 获得量化反量化之后运算的梯度
#             outputs, loss_value = get_grad()

#             # perturb weights
#             self.perturb_weights(rho=self.rho_t)

#             # disable running stats for second pass
#             disable_running_stats(self.model)

#             # get gradient at perturbed weights
#             get_grad()

#             # decompose and get new update direction
#             self.gradient_decompose(self.alpha)

#             # unperturb
#             self.unperturb()
            
#         # synchronize gradients across workers
#         self._sync_grad()    
#         # for name, p in self.model.named_parameters():
#         #     if 'layer1.0.conv1.weight' in name:
#         #         print("Gradient1", p.grad[0])
#         #     elif 'layer2.0.conv2.weight' in name:
#         #         print("Gradient2", p.grad[0])
#         #     elif 'layer4.0.conv3.weight' in name:
#         #         print("Gradient3", p.grad[0])
#         # update with new directions
#         # if self.update_q % 2 == 1:
#         #     self.base_optimizer.step()
#         # else:
#         #     self.q_base_optimizer.step()
#         # if self.q_base_optimizer != None and self.update_q  > 1000:
#         #     self.q_base_optimizer.step()
#         # self.update_q += 1
        
#         self.base_optimizer.step()
#         if self.q_base_optimizer != None:
#             self.q_base_optimizer.step()

#         # enable running stats
#         enable_running_stats(self.model)

#         return outputs, loss_value


# toy
# Original SAGM + quantization s
class SAGM(torch.optim.Optimizer):
    """Two-pass wrapper optimizer that perturbs weights along the gradient.

    One call to :meth:`step` does the following:

    1. runs a forward/backward pass to obtain the gradient at the current
       weights;
    2. moves every parameter of ``base_optimizer`` by
       ``grad * (rho_t / (||grad|| + perturb_eps) - alpha)``;
    3. runs a second forward/backward pass at the perturbed weights;
    4. replaces each gradient with the mean of the first-pass and
       second-pass gradients (also for the parameters of the optional
       ``q_base_optimizer``, which are never perturbed);
    5. restores the weights, all-reduces the gradients when
       ``torch.distributed`` is initialised, and steps the wrapped optimizer(s).

    Several attributes exist purely as diagnostics for experiments:
    ``old_grads``/``new_grads`` (gradients of the last step),
    ``cosine_similarity`` (running sum of per-parameter cosine similarities
    between the two passes), ``cnt_opposite`` (running count of q-parameter
    elements whose gradient changed sign), ``weight_list`` and ``s_list``
    (snapshots of parameter values).

    NOTE(review): ``weight_list`` and ``s_list`` are appended to on every step
    and never cleared by this class, so they grow with the number of steps.
    """

    def __init__(self, params, base_optimizer, model, alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', q_base_optimizer=None, **kwargs):
        """Build the wrapper.

        Args:
            params: parameters handed to ``torch.optim.Optimizer.__init__``.
            base_optimizer: optimizer whose ``param_groups`` are perturbed and
                updated by this wrapper.
            model: module called by the closure built in :meth:`set_closure`;
                its ``no_sync()`` is used when ``torch.distributed`` is
                initialised.
            alpha: constant subtracted from the perturbation scale.
            rho_scheduler: object whose ``step()`` returns the current rho.
            adaptive: if true, weight the gradient norm by ``|p|`` and the
                perturbation by ``p ** 2``.
            perturb_eps: small constant added to norms before dividing.
            grad_reduce: ``'mean'`` or ``'sum'`` (case-insensitive); how
                gradients are combined across workers.
            q_base_optimizer: optional second optimizer whose parameters get
                their gradients averaged but are never perturbed.
            **kwargs: extra entries stored in the optimizer defaults.

        Raises:
            ValueError: if ``grad_reduce`` is neither ``'mean'`` nor ``'sum'``.
        """
        defaults = dict(adaptive=adaptive, **kwargs)
        super(SAGM, self).__init__(params, defaults)
        self.model = model
        self.base_optimizer = base_optimizer
        # Share the very same group list with the base optimizer so both see
        # identical parameters and hyper-parameters.
        self.param_groups = self.base_optimizer.param_groups
        self.q_base_optimizer = q_base_optimizer
        if self.q_base_optimizer is not None:
            self.q_param_groups = self.q_base_optimizer.param_groups
        self.adaptive = adaptive
        self.rho_scheduler = rho_scheduler
        self.perturb_eps = perturb_eps
        self.alpha = alpha

        # Diagnostics (see class docstring).
        self.old_grads = []
        self.new_grads = []
        self.cos_sims = []
        self.weight_list = []
        self.s_list = []
        self.update_q = 0  # number of completed step() calls
        self.cnt_opposite = 0
        self.cosine_similarity = 0

        # Initialise self.rho_t.
        self.update_rho_t()

        # Set up the reduction used for gradients across workers.
        if grad_reduce.lower() == 'mean':
            if hasattr(ReduceOp, 'AVG'):
                self.grad_reduce = ReduceOp.AVG
                self.manual_average = False
            else:  # PyTorch <= 1.11.0 has no AVG: sum, then divide manually
                self.grad_reduce = ReduceOp.SUM
                self.manual_average = True
        elif grad_reduce.lower() == 'sum':
            self.grad_reduce = ReduceOp.SUM
            self.manual_average = False
        else:
            raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')

    @torch.no_grad()
    def update_rho_t(self):
        """Advance ``rho_scheduler`` by one step, cache and return the new rho."""
        self.rho_t = self.rho_scheduler.step()
        return self.rho_t

    @torch.no_grad()
    def perturb_weights(self, rho=0.0):
        """Save the current gradients and move the weights along them.

        Every parameter ``p`` of ``base_optimizer`` that has a gradient is
        shifted in place by ``p.grad * (rho / (||grad|| + perturb_eps) - alpha)``
        (additionally multiplied by ``p ** 2`` when ``adaptive``). The shift is
        stored in ``state[p]['e_w']`` and the gradient in ``state[p]['old_g']``.
        Parameters of ``q_base_optimizer`` only get ``old_g`` recorded.

        Args:
            rho: perturbation radius for this step.
        """
        self.old_grads = []

        grad_norm = self._grad_norm(weight_adaptive=self.adaptive)
        # The scale does not depend on the group, so compute it once.
        scale = rho / (grad_norm + self.perturb_eps) - self.alpha

        for group in self.param_groups:
            for p in group["params"]:
                # Diagnostic snapshot; taken even for parameters without grad.
                self.weight_list.append(p.data.clone())
                if p.grad is None:
                    continue
                self.state[p]["old_g"] = p.grad.data.clone()
                self.old_grads.append(self.state[p]["old_g"])
                e_w = p.grad * scale.to(p)
                if self.adaptive:
                    e_w *= torch.pow(p, 2)
                p.add_(e_w)  # move to the perturbed point "w + e(w)"
                self.state[p]['e_w'] = e_w

        if self.q_base_optimizer is not None:
            for group in self.q_param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    self.state[p]["old_g"] = p.grad.data.clone()
                    self.s_list.append(p.data.clone())

    @torch.no_grad()
    def unperturb(self):
        """Undo the shift applied by :meth:`perturb_weights`.

        The stored ``e_w`` is removed from the state once it has been
        subtracted. Otherwise a parameter that had a gradient on an earlier
        step but has none on the current one would get its stale shift
        subtracted again.
        """
        # TODO (from original author): revisit so this works with the original
        # gradient; the gradient passed in at first should be the one obtained
        # after quantization followed by de-quantization.
        for group in self.param_groups:
            for p in group['params']:
                e_w = self.state[p].pop('e_w', None)
                if e_w is not None:
                    p.data.sub_(e_w)

    @torch.no_grad()
    def gradient_decompose(self, alpha=0.0):
        """Replace each gradient by the mean of the first- and second-pass ones.

        Expects ``state[p]['old_g']`` from :meth:`perturb_weights` and the
        second-pass gradient in ``p.grad``. Also updates the diagnostics
        ``new_grads``, ``cosine_similarity`` and ``cnt_opposite``.

        Args:
            alpha: accepted for interface compatibility; not used here.
        """
        self.new_grads = []

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                old_g = self.state[p]['old_g']
                self.new_grads.append(p.grad.data.clone())

                # Cosine similarity between the two passes (diagnostic only).
                # flatten() copes with non-contiguous gradients; perturb_eps
                # keeps an all-zero gradient from yielding NaN, which would
                # poison the running sum for good.
                old_flat = old_g.flatten()
                new_flat = p.grad.flatten()
                denom = old_flat.norm() * new_flat.norm() + self.perturb_eps
                self.cosine_similarity += torch.dot(old_flat, new_flat) / denom

                # p.grad <- p.grad + (old_g - p.grad) / 2 == (old_g + p.grad) / 2
                p.grad.data.add_(old_g * 0.5 - p.grad * 0.5)

        if self.q_base_optimizer is not None:
            for group in self.q_param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    old_g = self.state[p]['old_g']

                    # Count elements whose gradient changed sign between the
                    # two passes. Element-wise masks are used so this works for
                    # tensors of any size (one element contributes 0 or 1).
                    flipped = ((old_g < 0) & (p.grad > 0)) | ((old_g > 0) & (p.grad < 0))
                    self.cnt_opposite += int(flipped.sum())

                    p.grad.data.add_(old_g * 0.5 - p.grad * 0.5)

    def _all_reduce_grads(self, groups):
        """All-reduce ``p.grad`` in place for every parameter in ``groups``."""
        for group in groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
                if self.manual_average:
                    # No ReduceOp.AVG available: turn the sum into a mean.
                    p.grad.div_(float(torch.distributed.get_world_size()))

    @torch.no_grad()
    def _sync_grad(self):
        """Synchronise the final gradients across workers, if distributed."""
        if torch.distributed.is_initialized():
            self._all_reduce_grads(self.param_groups)
            if self.q_base_optimizer is not None:
                self._all_reduce_grads(self.q_param_groups)
        return

    @torch.no_grad()
    def _grad_norm(self, by=None, weight_adaptive=False):
        """Return the global L2 norm over the parameters of ``base_optimizer``.

        Args:
            by: if given, take the norm of ``state[p][by]`` instead of
                ``p.grad``.
            weight_adaptive: if true, multiply each tensor by ``|p|`` first.

        Returns:
            A 0-dim tensor: the L2 norm of the per-parameter L2 norms.

        Raises:
            RuntimeError: from ``torch.stack`` when no parameter has a gradient.
        """
        per_param_norms = []
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                vec = self.state[p][by] if by else p.grad
                if weight_adaptive:
                    vec = torch.abs(p.data) * vec
                per_param_norms.append(vec.norm(p=2))
        return torch.norm(torch.stack(per_param_norms), p=2)

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` and re-share the param groups with the base optimizer."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

    def maybe_no_sync(self):
        """Return ``model.no_sync()`` when distributed, else a no-op context."""
        if torch.distributed.is_initialized():
            return self.model.no_sync()
        return contextlib.ExitStack()

    @torch.no_grad()
    def set_closure(self, loss_fn, inputs, targets, **kwargs):
        """Bind one batch and store the forward/backward closure used by ``step``.

        The stored ``self.forward_backward_func`` takes no arguments. It zeroes
        the gradients of the wrapped optimizer(s), evaluates
        ``loss_fn(self.model(inputs), targets, **kwargs)``, back-propagates and
        returns ``(outputs, loss_value)`` with ``loss_value`` detached.
        """

        def get_grad():
            self.base_optimizer.zero_grad()
            if self.q_base_optimizer is not None:
                self.q_base_optimizer.zero_grad()
            # step() runs under no_grad, so autograd is re-enabled here.
            with torch.enable_grad():
                outputs = self.model(inputs)
                loss = loss_fn(outputs, targets, **kwargs)
            loss_value = loss.data.clone().detach()
            loss.backward()

            return outputs, loss_value

        self.forward_backward_func = get_grad

    @torch.no_grad()
    def step(self, closure=None):
        """Run one two-pass update.

        Args:
            closure: optional callable that performs forward + backward and
                returns ``(outputs, loss_value)``. When omitted, the closure
                stored by :meth:`set_closure` is used.

        Returns:
            ``(outputs, loss_value)`` from the first (unperturbed) pass.

        When ``q_base_optimizer`` is set, the two optimizers are stepped on
        alternating calls: ``q_base_optimizer`` on even call counts (starting
        with the first call) and ``base_optimizer`` on odd ones. Without a
        ``q_base_optimizer``, ``base_optimizer`` is stepped on every call.
        """
        get_grad = closure if closure else self.forward_backward_func

        with self.maybe_no_sync():
            # Gradient at the current weights (computed after quantization
            # followed by de-quantization, per the original author's note).
            outputs, loss_value = get_grad()

            # Perturb weights.
            self.perturb_weights(rho=self.rho_t)

            # Disable running stats for the second pass.
            disable_running_stats(self.model)

            # Gradient at the perturbed weights.
            get_grad()

            # Combine both gradients into the new update direction.
            self.gradient_decompose(self.alpha)

            # Restore the original weights.
            self.unperturb()

        # Synchronise gradients across workers.
        self._sync_grad()

        # Update with the new direction. The q optimizer is optional, so fall
        # back to the base optimizer on every call when it is absent instead
        # of dereferencing None on even calls.
        if self.q_base_optimizer is None or self.update_q % 2 == 1:
            self.base_optimizer.step()
        else:
            self.q_base_optimizer.step()
        self.update_q += 1

        # Re-enable running stats.
        enable_running_stats(self.model)

        return outputs, loss_value


# # Dequantized SAGM + quantization s
# class SAGM(torch.optim.Optimizer):
#     def __init__(self, params, base_optimizer, model, alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', q_base_optimizer = None, **kwargs):
#         defaults = dict(adaptive=adaptive, **kwargs)
#         super(SAGM, self).__init__(params, defaults)
#         self.model = model
#         self.base_optimizer = base_optimizer
#         self.param_groups = self.base_optimizer.param_groups
#         self.q_base_optimizer = q_base_optimizer
#         self.adaptive = adaptive
#         self.rho_scheduler = rho_scheduler
#         self.perturb_eps = perturb_eps
#         self.alpha = alpha
        
#         # initialize self.rho_t
#         self.update_rho_t()
        
#         # set up reduction for gradient across workers
#         if grad_reduce.lower() == 'mean':
#             if hasattr(ReduceOp, 'AVG'):
#                 self.grad_reduce = ReduceOp.AVG
#                 self.manual_average = False
#             else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
#                 self.grad_reduce = ReduceOp.SUM
#                 self.manual_average = True
#         elif grad_reduce.lower() == 'sum':
#             self.grad_reduce = ReduceOp.SUM
#             self.manual_average = False
#         else:
#             raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')
    
#     @torch.no_grad()
#     def update_rho_t(self):
#         self.rho_t = self.rho_scheduler.step()
#         return self.rho_t

#     @torch.no_grad()
#     def perturb_weights(self, rho=0.0):
#         grad_norm = self._grad_norm( weight_adaptive = self.adaptive )
#         for group in self.param_groups:
#             # scale是不变的，不管现在的权重到底是不是量化过的
#             scale = (rho / (grad_norm + self.perturb_eps) - self.alpha)

#             for p in group["params"]:
#                 if p.grad is None: continue
#                 self.state[p]["old_g"] = p.grad.data.clone()
                
#                 e_w = p.grad * scale.to(p)
#                 if self.adaptive:
#                     e_w *= torch.pow(p, 2)
                
#                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
#                 self.state[p]['e_w'] = e_w
                
#     @torch.no_grad()
#     def unperturb(self):
#         # 需要修改，修改成原始梯度，一开始传入的应该是量化之后反量化的梯度
#         for group in self.param_groups:
#             for p in group['params']:
#                 if 'e_w' in self.state[p].keys():
#                     p.data.sub_(self.state[p]['e_w'])

#     @torch.no_grad()
#     def gradient_decompose(self, alpha=0.0):
#         # 这里需要好好看看
#         for group in self.param_groups:
#             for p in group['params']:
#                 if p.grad is None: continue
#                 sam_grad = self.state[p]['old_g'] * 0.5 - p.grad * 0.5
#                 p.grad.data.add_(sam_grad)

#     @torch.no_grad()
#     def _sync_grad(self):
#         if torch.distributed.is_initialized(): # synchronize final gardients
#             for group in self.param_groups:
#                 for p in group['params']:
#                     if p.grad is None: continue
#                     if self.manual_average:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#                         world_size = torch.distributed.get_world_size()
#                         p.grad.div_(float(world_size))
#                     else:
#                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
#         return

#     @torch.no_grad()
#     def _grad_norm(self, by=None, weight_adaptive=False):
#         #shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
#         if not by:

#             norm = torch.norm(
#                     torch.stack([
#                         ( (torch.abs(p.data) if weight_adaptive else 1.0) * p.grad).norm(p=2)
#                         for group in self.param_groups for p in group["params"]
#                         if p.grad is not None
#                     ]),
#                     p=2
#                )

#         else:

#             norm = torch.norm(
#                 torch.stack([
#                     ( (torch.abs(p.data) if weight_adaptive else 1.0) * self.state[p][by]).norm(p=2)
#                     for group in self.param_groups for p in group["params"]
#                     if p.grad is not None
#                 ]),
#                 p=2
#             )

#         return norm

#     # def norm(tensor_list: List[torch.tensor], p=2):
#     #     """Compute p-norm for tensor list"""
#     #     return torch.cat([x.flatten() for x in tensor_list]).norm(p)

#     def load_state_dict(self, state_dict):
#         super().load_state_dict(state_dict)
#         self.base_optimizer.param_groups = self.param_groups
        
#     def maybe_no_sync(self):
#         if torch.distributed.is_initialized():
#             return self.model.no_sync()
#         else:
#             return contextlib.ExitStack()

#     @torch.no_grad()
#     def set_closure(self, loss_fn, inputs, targets, **kwargs):
#         # create self.forward_backward_func, which is a function such that
#         # self.forward_backward_func() automatically performs forward and backward passes.
#         # This function does not take any arguments, and the inputs and targets data
#         # should be pre-set in the definition of partial-function

#         def get_grad():
#             self.base_optimizer.zero_grad()
#             if self.q_base_optimizer != None:
#                 self.q_base_optimizer.zero_grad()
#             with torch.enable_grad():
#                 outputs = self.model(inputs)
#                 loss = loss_fn(outputs, targets, **kwargs)
#             loss_value = loss.data.clone().detach()
#             loss.backward()

#             return outputs, loss_value

#         self.forward_backward_func = get_grad

#     @torch.no_grad()
#     def step(self, closure=None):

#         if closure:
#             get_grad = closure
#         else:
#             get_grad = self.forward_backward_func

#         with self.maybe_no_sync():
#             # get gradient
#             # 获得量化反量化之后运算的梯度
#             outputs, loss_value = get_grad()

#             # perturb weights
#             self.perturb_weights(rho=self.rho_t)

#             # disable running stats for second pass
#             disable_running_stats(self.model)

#             # get gradient at perturbed weights
#             get_grad()

#             # decompose and get new update direction
#             self.gradient_decompose(self.alpha)

#             # unperturb
#             self.unperturb()
            
#         # synchronize gradients across workers
#         self._sync_grad()    
#         # for name, p in self.model.named_parameters():
#         #     if 'layer1.0.conv1.weight' in name:
#         #         print("Gradient1", p.grad[0])
#         #     elif 'layer2.0.conv2.weight' in name:
#         #         print("Gradient2", p.grad[0])
#         #     elif 'layer4.0.conv3.weight' in name:
#         #         print("Gradient3", p.grad[0])
#         # update with new directions
#         self.base_optimizer.step()
#         if self.q_base_optimizer != None:
#             self.q_base_optimizer.step()

#         # enable running stats
#         enable_running_stats(self.model)

#         return outputs, loss_value