import torch

class SAM(torch.optim.Optimizer):
    """
    Sharpness-Aware Minimization Optimizer

    Unlike SABCD, SAM applies perturbation to all parameters, not just sensitive ones

    One optimization step is split into two halves, with a forward/backward
    pass expected in between:

    1. ``first_step`` folds the current gradients into the momentum buffers,
       then moves every parameter that has a gradient to ``w + e(w)``, where
       ``e(w)`` is the gradient (scaled by ``w ** 2`` when ``adaptive``)
       multiplied by ``rho / (grad_norm + 1e-12)``.
    2. ``second_step`` restores the saved weights and applies an Adam-style
       update: the momentum buffer is the numerator and the second-moment
       estimate built from the gradients present at that time is the
       denominator.

    ``step(closure)`` runs both halves with ``closure()`` in between.

    Args:
        params: Parameters to optimize
        base_optimizer: Base optimizer class (e.g., Adam, SGD, etc.); it is
            instantiated here with the parameter groups and ``**kwargs``
        rho: Perturbation radius
        alpha: First momentum coefficient (similar to beta1 in Adam)
        beta: Second momentum coefficient (similar to beta2 in Adam)
        adaptive: Whether to use adaptive perturbation calculation (True for ASAM, False for SAM)
        epsilon: Numerical stability constant added to the update denominator
        **kwargs: Stored in the parameter-group defaults and forwarded to the
            base optimizer constructor (e.g. ``lr``, ``weight_decay``).
            ``second_step`` reads ``group['lr']``, so a learning rate must be
            supplied here or by the base optimizer's defaults.
    """
    def __init__(
        self,
        params,
        base_optimizer: "type[torch.optim.Optimizer]",
        rho: float = 0.05,
        alpha: float = 0.9,
        beta: float = 0.999,
        adaptive: bool = False,
        epsilon: float = 1e-8,
        **kwargs
    ):
        assert 0.0 <= rho, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(
            rho=rho,
            adaptive=adaptive,
            alpha=alpha,
            beta=beta,
            epsilon=epsilon,
            **kwargs
        )
        super().__init__(params, defaults)

        # The base optimizer is built on the very same group dicts, so both
        # optimizers always see identical hyper-parameters.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

        # Initialize momentum cache for each parameter
        for group in self.param_groups:
            for p in group['params']:
                self._ensure_state(p)

    def _ensure_state(self, p):
        """Return the state dict of ``p``, creating any missing buffers."""
        state = self.state[p]
        if 'momentum_buffer' not in state:
            state['momentum_buffer'] = torch.zeros_like(p.data)
        if 'velocity_buffer' not in state:
            state['velocity_buffer'] = torch.zeros_like(p.data)
        state.setdefault('step', 0)
        return state

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        """Compute and apply the worst-case perturbation

        Updates the momentum buffers from the current gradients, saves the
        current weights under ``state[p]["old_p"]`` and perturbs every
        parameter that has a gradient. Parameters without a gradient are left
        untouched.

        Args:
            zero_grad: If True, call ``zero_grad()`` after perturbing.
        """
        # Update momentum buffer
        for group in self.param_groups:
            alpha = group['alpha']

            for p in group['params']:
                if p.grad is None:
                    continue

                mom = self._ensure_state(p)['momentum_buffer']
                mom.mul_(alpha).add_(p.grad, alpha=1.0 - alpha)

        # Calculate gradient norm - using all parameters
        grad_norm = self._grad_norm()

        # Apply perturbation
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue

                # Save original parameters
                self.state[p]["old_p"] = p.data.clone()

                # The norm lives on a single device; move the scalar to the
                # parameter's device so multi-device models work.
                local_scale = scale.to(p.device)

                # Compute perturbation
                if group["adaptive"]:
                    e_w = (torch.pow(p, 2) * p.grad) * local_scale
                else:
                    e_w = p.grad * local_scale

                # Apply perturbation
                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        """Restore parameters and apply optimization step

        Every parameter that was perturbed by ``first_step`` is restored,
        whether or not it currently has a gradient. Only parameters with a
        gradient are then updated.

        Args:
            zero_grad: If True, call ``zero_grad()`` after updating.
        """
        # Restore parameters. This depends only on a saved copy existing, so a
        # parameter whose gradient is None at this point is not left perturbed.
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state.get(p)
                if state and "old_p" in state:
                    # Drop the copy so a stale one can never be restored later.
                    p.data = state.pop("old_p")

        # Update parameters - using all parameters
        # NOTE(review): the update direction is the momentum buffer, which is
        # only fed by the gradients seen in first_step; the gradients present
        # here (and the weight decay term) only enter the second-moment
        # estimate. Confirm this is the intended algorithm.
        for group in self.param_groups:
            alpha = group['alpha']
            beta = group['beta']
            epsilon = group['epsilon']
            lr = group['lr']
            weight_decay = group.get('weight_decay', 0)

            for p in group["params"]:
                if p.grad is None:
                    continue

                # Lazily created so parameters added after construction work.
                state = self._ensure_state(p)
                state['step'] += 1
                step = state['step']

                mom = state['momentum_buffer']
                vel = state['velocity_buffer']

                if weight_decay != 0:
                    p.grad.data = p.grad.data.add(p.data, alpha=weight_decay)

                vel.mul_(beta).addcmul_(p.grad, p.grad, value=1.0 - beta)

                bias_correction1 = 1 - alpha ** step
                bias_correction2 = 1 - beta ** step

                # Plain float square root: no temporary tensor per parameter.
                denom = vel.sqrt().div_(bias_correction2 ** 0.5).add_(epsilon)

                step_size = lr / bias_correction1
                p.addcdiv_(mom, denom, value=-step_size)

        if zero_grad:
            self.zero_grad()

    def step(self, closure=None):
        """Execute optimization step

        Args:
            closure: Callable that re-evaluates the model and returns the
                loss; it is called between ``first_step`` and ``second_step``
                and must leave fresh gradients on the parameters.

        Returns:
            The value returned by ``closure``, or None when no closure is
            given (in which case only the base optimizer's step is executed
            and a warning is issued).
        """
        if closure is not None:
            # If closure is provided, handle in original way
            self.first_step(zero_grad=True)
            loss = closure()
            self.second_step()
            return loss

        # If no closure provided, only execute base optimizer step
        import warnings

        self.base_optimizer.step()
        warnings.warn(
            "Warning: SAM step called without closure, this should not happen.",
            stacklevel=2,
        )
        return None

    def _grad_norm(self):
        """Calculate gradient norm of all parameters

        Returns:
            A 0-dim tensor holding the L2 norm over all gradients (each
            gradient is multiplied element-wise by its parameter first when
            the group is ``adaptive``). A zero tensor is returned when no
            parameter has a gradient.
        """
        norm_sqr = None
        for group in self.param_groups:
            adaptive = group["adaptive"]
            for p in group["params"]:
                if p.grad is None:
                    continue

                weighted = p.data * p.grad if adaptive else p.grad
                term = torch.sum(torch.pow(weighted, 2))
                if norm_sqr is None:
                    norm_sqr = term
                else:
                    # Accumulate on the device of the first gradient found.
                    norm_sqr = norm_sqr + term.to(norm_sqr.device)

        if norm_sqr is None:
            # No gradients at all: nothing to perturb, the norm is zero.
            return torch.zeros(())
        return torch.sqrt(norm_sqr)