import torch

class SABCD(torch.optim.Optimizer):
    """
    Sharpness-Aware Block Coordinate Descent Optimizer
    Combining SAM (Sharpness-Aware Minimization) and Block Coordinate Descent

    Each iteration is split in two: ``first_step`` selects the "sensitive"
    coordinates (largest momentum magnitude across all parameters that have a
    gradient) and perturbs only those; ``second_step`` restores the weights
    and applies an Adam-style update restricted to the same coordinates.

    Args:
        params: Parameters to optimize
        base_optimizer: Base optimizer class (e.g., Adam, SGD, etc.); it is
            instantiated here with the param groups and ``**kwargs``
        rho: Perturbation radius
        alpha: First momentum coefficient (similar to beta1 in Adam)
        beta: Second momentum coefficient (similar to beta2 in Adam)
        selection_percent: Percentage of parameters to update (between 0 and 1)
        adaptive: Whether to use adaptive perturbation calculation (True for ASAM, False for SAM)
        epsilon: Numerical stability constant
        **kwargs: Forwarded to ``base_optimizer`` and stored in the group
            defaults (``second_step`` reads ``lr`` and, if present,
            ``weight_decay`` from the param groups)
    """
    def __init__(
        self, 
        params, 
        base_optimizer: torch.optim.Optimizer, 
        rho: float = 0.05, 
        alpha: float = 0.9,
        beta: float = 0.999,
        selection_percent: float = 0.3,
        adaptive: bool = False,
        epsilon: float = 1e-8,
        **kwargs
    ):
        assert 0.0 <= rho, f"Invalid rho, should be non-negative: {rho}"
        assert 0.0 <= selection_percent <= 1.0, f"selection_percent must be between 0 and 1, got: {selection_percent}"

        defaults = dict(
            rho=rho,
            adaptive=adaptive,
            alpha=alpha,
            beta=beta,
            selection_percent=selection_percent,
            epsilon=epsilon,
            **kwargs
        )
        super().__init__(params, defaults)

        # Share one list of param groups with the base optimizer so that
        # hyper-parameter changes (e.g. by an LR scheduler) are seen by both.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

        # Initialize momentum cache for each parameter
        for group in self.param_groups:
            for p in group['params']:
                self._ensure_state(p)

    def _ensure_state(self, p):
        """Create any missing per-parameter buffers and return the state dict.

        Used both eagerly in ``__init__`` and lazily in the step methods so
        that parameters registered later (e.g. via ``add_param_group``) work.
        """
        state = self.state[p]
        if 'momentum_buffer' not in state:
            state['momentum_buffer'] = torch.zeros_like(p.data)
        if 'velocity_buffer' not in state:
            state['velocity_buffer'] = torch.zeros_like(p.data)
        state.setdefault('step', 0)
        return state

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        """Compute and apply the worst-case perturbation.

        Updates the momentum buffer of every parameter that has a gradient,
        marks the coordinates with the largest momentum magnitude as
        sensitive, and moves only those coordinates along the (optionally
        adaptive) gradient direction scaled to radius ``rho``.

        Args:
            zero_grad: If True, clear gradients after perturbing.

        Returns:
            None. Returns early (without touching gradients) when no
            parameter has a gradient.
        """
        candidates = []

        # Update momentum buffer and collect parameters
        for group in self.param_groups:
            alpha = group['alpha']

            for p in group['params']:
                if p.grad is None:
                    continue

                # Update momentum estimate
                mom = self._ensure_state(p)['momentum_buffer']
                mom.mul_(alpha).add_(p.grad, alpha=1.0 - alpha)

                # abs() already allocates a new tensor, no extra copy needed
                candidates.append((p, mom.abs()))

        if not candidates:
            return

        # Find threshold to select sensitive parameters. The selection is
        # global: it uses the optimizer-level default of selection_percent
        # over the concatenation of all magnitudes, not a per-group value.
        all_magnitudes = torch.cat([m.flatten() for _, m in candidates])
        num_params = len(all_magnitudes)
        k = max(1, int(self.defaults['selection_percent'] * num_params))

        if k < num_params:
            # k-th largest magnitude; ties at the threshold are all selected
            threshold = torch.topk(all_magnitudes, k, sorted=False)[0].min()
        else:
            threshold = -float('inf')

        # Mark sensitive parameters
        for p, magnitude in candidates:
            self.state[p]['sensitive_mask'] = magnitude >= threshold

        # Calculate gradient norm
        grad_norm = self._grad_norm_sensitive()

        # Apply perturbation
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None or 'sensitive_mask' not in self.state[p]:
                    continue

                state = self.state[p]
                mask = state['sensitive_mask']
                state["old_p"] = p.data.clone()

                # Compute perturbation (ASAM scales the gradient by p**2)
                direction = torch.pow(p, 2) * p.grad if group["adaptive"] else p.grad
                e_w = torch.zeros_like(p)
                e_w[mask] = direction[mask] * scale

                # Apply perturbation
                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        """Restore parameters and apply optimization step.

        Every parameter perturbed by ``first_step`` is put back to its saved
        value, then the sensitive coordinates of parameters that currently
        have a gradient receive a bias-corrected, Adam-style update.

        Args:
            zero_grad: If True, clear gradients after the update.
        """
        # Restore parameters. This must not depend on p.grad: after
        # first_step(zero_grad=True) a parameter that receives no gradient in
        # the second backward pass would otherwise stay perturbed.
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state.get(p)
                if state is None or "old_p" not in state:
                    continue
                # pop: the snapshot is single-use and should not linger in
                # the optimizer state once it has been applied
                p.data = state.pop("old_p")

        # Update sensitive parameters
        for group in self.param_groups:
            alpha = group['alpha']
            beta = group['beta']
            epsilon = group['epsilon']
            lr = group['lr']
            weight_decay = group.get('weight_decay', 0)

            for p in group["params"]:
                if p.grad is None or 'sensitive_mask' not in self.state[p]:
                    continue

                state = self._ensure_state(p)
                mask = state['sensitive_mask']
                if not mask.any():
                    continue

                state['step'] += 1
                step = state['step']

                mom = state['momentum_buffer']
                vel = state['velocity_buffer']

                # NOTE(review): the momentum buffer was already updated in
                # first_step from the gradient present at that time, so the
                # decay added here (and the current gradient in general) only
                # feeds the second-moment estimate - confirm this is intended.
                if weight_decay != 0:
                    p.grad.data[mask] = p.grad.data[mask].add(p.data[mask], alpha=weight_decay)

                vel.mul_(beta).addcmul_(p.grad, p.grad, value=1.0 - beta)

                bias_correction1 = 1 - alpha ** step
                bias_correction2 = 1 - beta ** step

                # Plain float sqrt: avoids allocating a tensor per parameter
                denom = (vel.sqrt() / (bias_correction2 ** 0.5)).add_(epsilon)

                step_size = lr / bias_correction1
                p.data[mask] = p.data[mask].addcdiv_(mom[mask], denom[mask], value=-step_size)

        if zero_grad:
            self.zero_grad()

    def step(self, closure=None):
        """Execute optimization step - modified version does not require closure.

        Args:
            closure: Optional callable that re-evaluates the model, runs the
                backward pass and returns the loss. Gradients from an earlier
                backward pass must already be present when this is called.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        if closure is not None:
            # The closure has to build a graph and call backward(), so make
            # sure autograd is on even if the caller disabled it.
            closure = torch.enable_grad()(closure)

            self.first_step(zero_grad=True)
            loss = closure()
            self.second_step()
            return loss
        else:
            # If no closure provided, only execute base optimizer step
            # This part should not be called; the two-step logic is expected
            # to be driven externally (the original note refers to an
            # "SABCDWrapper", which is not defined in this class).
            self.base_optimizer.step()
            print("Error!!!!!: SABCD step called without closure, this should not happen.")
            return None

    def _grad_norm_sensitive(self):
        """Calculate gradient norm of sensitive parameters.

        Returns:
            A 0-dim tensor with the L2 norm of the gradients restricted to the
            sensitive coordinates (weighted by the parameter values when the
            group is adaptive). A zero tensor is returned when no coordinate
            contributes, so callers can always do tensor arithmetic on it.
        """
        total = None
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or 'sensitive_mask' not in self.state[p]:
                    continue

                mask = self.state[p]['sensitive_mask']
                if not mask.any():
                    continue

                if group["adaptive"]:
                    term = torch.sum(torch.pow(p.data[mask] * p.grad[mask], 2))
                else:
                    term = torch.sum(torch.pow(p.grad[mask], 2))

                total = term if total is None else total + term

        if total is None:
            # Nothing selected (e.g. called before first_step, or every mask
            # is empty): torch.sqrt cannot take a Python float, so build the
            # zero on the device of the first parameter, if there is one.
            device = next(
                (p.device for group in self.param_groups for p in group["params"]),
                None,
            )
            return torch.zeros((), device=device)

        return torch.sqrt(total)