import torch

class HamOptimizerWrapper(torch.optim.Optimizer):
    """Wrap a ``torch.optim.Optimizer`` and add a multiplicative (HAM-style) update.

    Each :meth:`step`:

    1. clips gradients, separately per parameter group, to ``max_grad_norm``;
    2. runs the wrapped optimizer's ``step``;
    3. for every parameter that has a gradient and no NaN in its data/gradient:
       multiplies tensors with more than two dimensions element-wise by
       ``exp(clamp(lr * (-h1 * sign(w) * grad - h2), -1, 1))``, then rescales the
       parameter if its L2 norm exceeds ``max_weight_norm``.

    The multiplicative factor is strictly positive, so step 3 never flips the sign
    of an entry (the wrapped optimizer's own step in step 2 still can). The update
    is described as based on the HAM (Hellinger-Kantorovich Minimization) paper.

    ``torch.optim.Optimizer.__init__`` is not called, so the wrapper holds no
    parameter groups of its own: attributes not defined here (``param_groups``,
    ``state``, ``defaults``, ...) resolve to the wrapped optimizer through
    :meth:`__getattr__`.
    """

    def __init__(self, optimizer, max_weight_norm=1000.0, max_grad_norm=20.0,
                 h1=50.0, h2=1e-3):
        """
        Args:
            optimizer (torch.optim.Optimizer): The base optimizer to wrap.
            max_weight_norm (float): Maximum allowed L2 norm of each parameter
                tensor after a step; larger tensors are rescaled down to it.
            max_grad_norm (float): Maximum total L2 norm of the (finite)
                gradients of each parameter group, enforced before the wrapped
                optimizer's step.
            h1 (float): Gradient scale inside the multiplicative update's
                exponent. Default 50.
            h2 (float): Constant term subtracted inside the exponent.
                Default 1e-3.
        """
        self.optimizer = optimizer
        self.max_weight_norm = max_weight_norm
        self.max_grad_norm = max_grad_norm
        self.h1 = h1
        self.h2 = h2
        # True when the most recent step() reported a NaN; updated by step().
        self.nan_detected = False

    def __getattr__(self, name):
        """Forward attribute lookups that fail on the wrapper to the wrapped optimizer.

        Python only calls this when normal lookup fails. If ``optimizer`` itself
        is missing (e.g. the instance was created without running ``__init__``),
        raise ``AttributeError`` instead of recursing forever.
        """
        if name == "optimizer":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'optimizer'"
            )
        return getattr(self.optimizer, name)

    def step(self, closure=None):
        """Perform one optimization step with gradient clipping and NaN protection.

        Gradients are clipped before the wrapped optimizer consumes them. With a
        closure, clipping runs inside the closure, right after it recomputes
        the gradients, so it also applies to optimizers that evaluate the
        closure several times.

        Args:
            closure (callable, optional): Re-evaluates the model and returns the
                loss; forwarded (wrapped) to the wrapped optimizer.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns (for the built-in
            optimizers, the closure's loss, or ``None`` without a closure).
        """
        if closure is None:
            self._clip_gradients()
            loss = self.optimizer.step()
        else:
            def clipped_closure():
                closure_loss = closure()
                self._clip_gradients()
                return closure_loss

            loss = self.optimizer.step(clipped_closure)

        nan_detected = False
        for group in self.optimizer.param_groups:
            lr = group['lr']
            for param in group['params']:
                if param.grad is None:
                    continue

                # Check for NaNs in the current parameters and gradients.
                if torch.isnan(param.data).any():
                    print(f"NaN detected in parameter data: shape={param.shape}")
                    nan_detected = True
                    continue

                if torch.isnan(param.grad).any():
                    print(f"NaN detected in gradient: shape={param.shape}")
                    nan_detected = True
                    continue

                # The multiplicative update only targets tensors with more than 2
                # dimensions.
                # NOTE(review): ">2" excludes 2-D weights (e.g. nn.Linear) and only
                # matches higher-rank tensors such as conv kernels - confirm intended.
                if param.dim() > 2 and not self._ham_update(param, lr):
                    nan_detected = True
                    continue

                # Norm clipping applies to every parameter that reaches this point,
                # regardless of rank (1-D parameters such as biases included).
                self._clip_weight_norm(param)

        self.nan_detected = nan_detected
        return loss

    def _clip_gradients(self):
        """Clip each group's finite gradients in place to ``max_grad_norm`` (L2)."""
        for group in self.optimizer.param_groups:
            # Non-finite gradients are left out: a single NaN would make the
            # group's total norm NaN, which clip_grad_norm_ (default
            # error_if_nonfinite=False) would propagate into every gradient of
            # the group. NaN gradients are reported by the checks in step().
            finite = [
                p for p in group['params']
                if p.grad is not None and torch.isfinite(p.grad).all()
            ]
            if finite:
                torch.nn.utils.clip_grad_norm_(finite, max_norm=self.max_grad_norm)

    def _ham_update(self, param, lr):
        """Multiply ``param`` in place by ``exp(clamp(lr * (-h1*sign(w)*grad - h2), -1, 1))``.

        The exponent is clamped to [-1, 1] and checked for NaN, so the factor lies
        in [e^-1, e]: no entry changes sign and zero entries stay zero.

        Returns:
            bool: True on success; False if a NaN appeared, in which case
            ``param`` is left at (or restored to) its pre-update value.
        """
        # Very conservative clamping to prevent extreme values.
        exponent = torch.clamp(
            (-self.h1 * torch.sign(param.data) * param.grad - self.h2) * lr,
            -1.0,
            1.0,
        )
        if torch.isnan(exponent).any():
            print("NaN detected in exponent calculation")
            return False

        orig_data = param.data.clone()
        # Multiplying zeros by a finite factor keeps them zero, so no mask is needed.
        param.data.mul_(torch.exp(exponent))
        if torch.isnan(param.data).any():
            print("NaN detected after parameter update - reverting")
            param.data.copy_(orig_data)
            return False
        return True

    def _clip_weight_norm(self, param):
        """Rescale ``param`` in place so its L2 norm does not exceed ``max_weight_norm``."""
        weight_norm = torch.norm(param.data)
        if weight_norm > self.max_weight_norm:
            print(f"Weight norm {weight_norm:.4f} exceeds max_weight_norm {self.max_weight_norm:.4f}, clipping")
            param.data.mul_(self.max_weight_norm / weight_norm)

    def zero_grad(self, set_to_none=None):
        """Clear gradients in the wrapped optimizer.

        Args:
            set_to_none (bool, optional): Forwarded to the wrapped optimizer's
                ``zero_grad`` when given; when omitted, that optimizer's own
                default applies.
        """
        if set_to_none is None:
            self.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer.

        Defined explicitly because the inherited ``Optimizer.load_state_dict``
        stores the loaded ``state``/``param_groups`` on the wrapper's own
        ``__dict__`` (via ``__setstate__``), which would leave the wrapped
        optimizer used by :meth:`step` unchanged.
        """
        self.optimizer.load_state_dict(state_dict)

    def add_param_group(self, param_group):
        """Add a parameter group to the wrapped optimizer."""
        self.optimizer.add_param_group(param_group)