import torch
import torch.nn as nn
import torch_pruning as tp
import torch.nn.functional as F
        
class GradGapPruner(tp.pruner.MetaPruner):
    """MetaPruner variant that adds a gradient-gap-scaled L1/L2 penalty to weight grads.

    ``compute_grad_gap`` builds per-key absolute differences between two
    gradient dictionaries; ``regularize`` turns those differences into a
    per-element penalty strength that is added in place to the ``weight.grad``
    of BatchNorm / Conv2d / Linear modules.

    NOTE(review): the "member" / "non-member" naming suggests the two gradient
    sets come from different data splits -- confirm against the caller.
    """

    # Module types whose ``weight`` gradient receives the penalty.
    _REGULARIZED_LAYERS = (
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.Conv2d,
        nn.Linear,
    )

    def compute_grad_gap(self, member_grads, nonmember_grads):
        """Return the element-wise absolute difference of two gradient dicts.

        Args:
            member_grads: Mapping from name to gradient tensor.
            nonmember_grads: Mapping holding a tensor for every key of
                ``member_grads``; each must be subtractable from its
                counterpart.

        Returns:
            A new dict with the keys of ``member_grads`` and values
            ``|member_grads[k] - nonmember_grads[k]|``.

        Raises:
            KeyError: If a key of ``member_grads`` is missing from
                ``nonmember_grads``.
        """
        return {
            name: torch.abs(grad - nonmember_grads[name])
            for name, grad in member_grads.items()
        }

    def regularize(self, model, grad_gaps, reg_weight, adaptive_strength=5, args=None, **kwargs):
        """Add an adaptive L1 or L2 penalty gradient to ``weight.grad`` in place.

        For every BatchNorm1d/2d/3d, Conv2d and Linear module in ``model`` the
        penalty strength is::

            reg_weight * clamp((gap / max(mean(gap), eps)) ** adaptive_strength,
                               0, args.reg_clamp)

        where ``gap = grad_gaps[module_name]``. The strength multiplies
        ``sign(weight)`` for ``args.reg_norm == "l1"`` or ``2 * weight`` for
        ``"l2"``, and the product is added to ``weight.grad``.

        Modules whose ``weight`` is ``None`` (e.g. BatchNorm built with
        ``affine=False``) or whose ``weight.grad`` is ``None`` (no backward
        pass yet, or a frozen parameter) are skipped, since there is no
        gradient to modify.

        Args:
            model: Module whose ``named_modules()`` are traversed.
            grad_gaps: Mapping keyed by module name (as produced by
                ``model.named_modules()``); each value must broadcast against
                the corresponding module's weight gradient.
            reg_weight: Global multiplier for the penalty.
            adaptive_strength: Exponent applied to the mean-normalized gap.
            args: Object exposing ``reg_norm`` (``"l1"`` or ``"l2"``) and
                ``reg_clamp`` (upper bound of the adaptive factor). Required
                despite the ``None`` default.
            **kwargs: Accepted and ignored.

        Raises:
            TypeError: If ``args`` is not supplied.
            KeyError: If a regularized module has no entry in ``grad_gaps``.
        """
        if args is None:
            raise TypeError(
                "regularize() requires `args` with `reg_norm` and `reg_clamp` attributes"
            )

        norm = args.reg_norm
        if norm not in ("l1", "l2"):
            # Any other value has always meant "apply no penalty".
            return

        max_reg_factor = args.reg_clamp
        # Lower bound for the mean so an all-zero gap does not divide by zero.
        min_value = torch.finfo(torch.float32).eps

        for name, m in model.named_modules():
            if not isinstance(m, self._REGULARIZED_LAYERS):
                continue
            if m.weight is None or m.weight.grad is None:
                continue

            gap = grad_gaps[name]
            mean_value = torch.clamp(gap.mean(), min=min_value)
            adaptive_reg = reg_weight * torch.clamp(
                (gap / mean_value) ** adaptive_strength,
                min=0, max=max_reg_factor
            )

            weight = m.weight.data
            # d|w|/dw = sign(w) for L1; d(w^2)/dw = 2w for L2.
            penalty_direction = torch.sign(weight) if norm == "l1" else 2 * weight
            m.weight.grad.data.add_(adaptive_reg * penalty_direction)
                