import torch
from typing import Callable, Optional
import math

class SABCD(torch.optim.Optimizer):
    """
    Implement Sharpness-Aware Block Coordinate Descent (SABCD) optimizer

    Combines characteristics of SAM (Sharpness-Aware Minimization) and block coordinate descent:
    1. Identify the most sensitive parameter subsets based on gradient momentum magnitude
    2. Calculate worst-case perturbations to find flat regions of the loss surface
    3. Update only sensitive parameters while keeping others unchanged to mitigate catastrophic forgetting

    Usage follows the two-pass SAM pattern: gradients must already be populated
    (``loss.backward()``) before ``first_step`` / ``step`` is called, because the
    first pass reads ``p.grad`` to select entries and build the perturbation.

    The base optimizer is instantiated so that its ``param_groups`` and
    ``defaults`` (e.g. ``lr``, ``weight_decay``) are shared with this optimizer;
    this class performs the parameter update itself in ``second_step`` and never
    calls ``base_optimizer.step()``.

    Args:
        params: Parameters to optimize
        base_optimizer: Base optimizer class (e.g., Adam, SGD, etc.); called as
            ``base_optimizer(param_groups, **kwargs)``
        rho: Perturbation radius
        alpha: First momentum coefficient (similar to beta1 in Adam)
        beta: Second momentum coefficient (similar to beta2 in Adam)
        selection_percent: Percentage of parameters to update (between 0 and 1).
            At least one entry is always selected.
        adaptive: Whether to use adaptive perturbation calculation (True for ASAM, False for SAM)
        epsilon: Numerical stability constant
        **kwargs: Additional parameters passed to the base optimizer. ``lr`` must
            end up in every param group (via kwargs or the base optimizer's
            defaults) because ``second_step`` reads ``group['lr']``.
    """

    def __init__(
        self,
        params,
        base_optimizer: Callable[..., torch.optim.Optimizer],
        rho: float = 0.05,
        alpha: float = 0.9,
        beta: float = 0.999,
        selection_percent: float = 0.3,
        adaptive: bool = False,
        epsilon: float = 1e-8,
        **kwargs
    ):
        assert 0.0 <= rho, f"Invalid rho, should be non-negative: {rho}"
        assert 0.0 <= selection_percent <= 1.0, f"selection_percent must be between 0 and 1, got: {selection_percent}"

        defaults = dict(
            rho=rho,
            adaptive=adaptive,
            alpha=alpha,
            beta=beta,
            selection_percent=selection_percent,
            epsilon=epsilon,
            **kwargs
        )
        super().__init__(params, defaults)

        # Share param groups with the base optimizer so hyperparameters such as
        # lr / weight_decay (and any scheduler acting on them) are visible here.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

        # Initialize momentum buffers for each parameter
        for group in self.param_groups:
            for p in group['params']:
                self._ensure_state(p)

    def _ensure_state(self, p) -> dict:
        """
        Return the state dict of ``p``, creating any missing buffers.

        Buffers are created lazily so parameters registered after construction
        (e.g. through ``add_param_group``) work in both passes.
        """
        state = self.state[p]
        if 'momentum_buffer' not in state:
            state['momentum_buffer'] = torch.zeros_like(p.data)
        if 'velocity_buffer' not in state:
            state['velocity_buffer'] = torch.zeros_like(p.data)
        state.setdefault('step', 0)
        return state

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        """
        Calculate sensitive parameters and apply perturbations - First step of SAM

        1. Select the most sensitive parameters to update based on momentum magnitude
        2. Calculate worst-case perturbations for these parameters
        3. Apply perturbations to parameters

        Returns early, without touching anything, when no parameter has a gradient.

        Args:
            zero_grad: If True, clear gradients after the perturbation is applied.
        """
        candidates = []

        # 1. Update momentum buffers and collect parameters that have gradients
        for group in self.param_groups:
            alpha = group['alpha']

            for p in group['params']:
                if p.grad is None:
                    # Drop any mask left over from an earlier iteration so the
                    # second pass cannot update this parameter with stale data.
                    stale = self.state.get(p)
                    if stale is not None:
                        stale.pop('sensitive_mask', None)
                    continue

                mom = self._ensure_state(p)['momentum_buffer']
                mom.mul_(alpha).add_(p.grad, alpha=1.0 - alpha)

                # abs() already allocates a fresh tensor, no extra copy needed
                candidates.append((p, mom.abs()))

        if not candidates:
            return

        # 2. Find threshold to select top-k sensitive entries. Selection is
        #    global across all param groups and uses the optimizer-level
        #    selection_percent default.
        all_magnitudes = torch.cat([m.flatten() for _, m in candidates])
        num_entries = all_magnitudes.numel()
        k = max(1, int(self.defaults['selection_percent'] * num_entries))

        # Find the k-th largest value as threshold - use topk for efficiency
        if k < num_entries:
            threshold = torch.topk(all_magnitudes, k, sorted=False)[0].min()
        else:
            threshold = -float('inf')  # If k equals num_entries, select all

        # Mark sensitive entries (True = entry will be perturbed and updated).
        # Ties at the threshold are all kept, so more than k may be selected.
        for p, magnitude in candidates:
            self.state[p]['sensitive_mask'] = magnitude >= threshold

        # 3. Calculate gradient norm, considering only sensitive entries
        grad_norm = self._grad_norm_sensitive()

        # 4. Apply perturbations to sensitive entries
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                mask = state.get('sensitive_mask')
                if mask is None:
                    continue

                # Backup used by second_step to undo the perturbation
                state["old_p"] = p.data.clone()

                if group["adaptive"]:
                    # ASAM: perturbation scaled element-wise by w^2
                    direction = torch.pow(p, 2) * p.grad
                else:
                    # SAM: perturbation along the raw gradient
                    direction = p.grad

                e_w = torch.zeros_like(p)
                e_w[mask] = direction[mask] * scale

                # Apply perturbation (w += e_w)
                p.add_(e_w)

        # Zero gradients if needed
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        """
        Restore parameters and apply optimization step - Second step of SAM

        1. Restore parameters to values before perturbation
        2. Update sensitive parameters using Adam-style update rules
        3. Keep other parameters unchanged

        Args:
            zero_grad: If True, clear gradients after the update.
        """
        # Restore every perturbed parameter. This must not depend on p.grad:
        # a parameter that received no gradient in the second pass still has
        # to be moved back to its unperturbed value.
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state.get(p)
                if state is None or "old_p" not in state:
                    continue
                # pop() so the backup does not linger in (or alias) the state
                p.data = state.pop("old_p")

        # Update sensitive parameters
        for group in self.param_groups:
            alpha = group['alpha']
            beta = group['beta']
            epsilon = group['epsilon']
            lr = group['lr']
            weight_decay = group.get('weight_decay', 0)

            for p in group["params"]:
                if p.grad is None:
                    continue

                # Skip parameters that were not part of this iteration's selection
                mask = self.state[p].get('sensitive_mask')
                if mask is None or not mask.any():
                    continue

                state = self._ensure_state(p)
                state['step'] += 1
                step = state['step']

                mom = state['momentum_buffer']
                vel = state['velocity_buffer']

                # Weight decay (applied only to sensitive entries); folded into
                # the gradient in place, as in L2-regularised Adam.
                if weight_decay != 0:
                    p.grad[mask] = p.grad[mask].add(p.data[mask], alpha=weight_decay)

                # Update velocity buffer (similar to second moment estimate in Adam)
                vel.mul_(beta).addcmul_(p.grad, p.grad, value=1.0 - beta)

                # Bias corrections
                bias_correction1 = 1 - alpha ** step
                bias_correction2 = 1 - beta ** step

                # Adaptive per-entry denominator
                denom = (vel.sqrt() / math.sqrt(bias_correction2)).add_(epsilon)

                # NOTE(review): `mom` was last updated in first_step from the
                # gradient at the *unperturbed* point; the perturbed gradient
                # only enters through `vel` (and weight decay only through
                # `vel` as well). Confirm this is the intended update rule.
                step_size = lr / bias_correction1
                p.data[mask] = p.data[mask].addcdiv(mom[mask], denom[mask], value=-step_size)

        # Zero gradients if needed
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """
        Perform single optimization step

        Gradients for the current parameters must already be populated before
        calling this method. The closure is evaluated exactly once, at the
        perturbed parameters, and must run a full forward/backward pass.

        Args:
            closure: Callable that re-evaluates the model, calls ``backward()``
                and returns the loss. Required.

        Returns:
            The value returned by the closure (loss at the perturbed point).
        """
        assert closure is not None, "SABCD requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # Ensure closure enables gradients

        self.first_step(zero_grad=True)
        loss = closure()
        self.second_step()

        return loss

    def _grad_norm_sensitive(self) -> torch.Tensor:
        """
        Calculate gradient norm for sensitive parameters

        Returns:
            L2 norm of gradients for sensitive parameters (element-wise scaled
            by the weights when ``adaptive`` is set). A zero tensor is returned
            when no entry is selected.
        """
        norm_sqr = None
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                mask = self.state[p].get('sensitive_mask')
                if mask is None or not mask.any():
                    continue

                if group["adaptive"]:
                    # ASAM: Adaptive gradient norm calculation
                    term = torch.sum(torch.pow(p.data[mask] * p.grad[mask], 2))
                else:
                    # SAM: Standard gradient norm calculation
                    term = torch.sum(torch.pow(p.grad[mask], 2))

                norm_sqr = term if norm_sqr is None else norm_sqr + term

        if norm_sqr is None:
            # torch.sqrt() needs a tensor; nothing selected means zero norm
            return torch.zeros(())

        return torch.sqrt(norm_sqr)