import torch
from .sam import SAM

class SAMWrapper(torch.optim.Optimizer):
    """Expose the two-phase SAM optimizer through a single ``step()`` method.

    SAM is driven through ``first_step()`` and ``second_step()``.  Training
    loops that only ever call ``optimizer.step()`` (the class was written for
    the HuggingFace Trainer) cannot do that, so this wrapper alternates
    between the two phases on successive ``step()`` calls:

    * odd calls snapshot the current gradients and run ``SAM.first_step``;
    * even calls put the snapshot back into ``.grad`` and run
      ``SAM.second_step``.

    The wrapped SAM instance is created lazily on the first ``step()``.

    NOTE(review): in the second phase the gradients the caller computed after
    ``first_step`` are overwritten with the snapshot taken *before* it.  SAM
    is usually described as updating with the gradient taken at the perturbed
    weights -- confirm that restoring the earlier gradients is intended.
    """

    def __init__(self, params, base_optimizer, **kwargs):
        """Register the parameters and remember how to build SAM later.

        Args:
            params: Iterable of tensors or of parameter-group dicts, as
                accepted by ``torch.optim.Optimizer``.
            base_optimizer: Optimizer class handed to ``SAM`` as its base
                optimizer.
            **kwargs: Forwarded to ``SAM`` and also used as this optimizer's
                per-group defaults.
        """
        # Materialise once so a generator is not exhausted by the base class.
        self.params = list(params)
        self.base_optimizer_cls = base_optimizer
        self.kwargs = kwargs

        super().__init__(self.params, dict(kwargs))

        # SAM itself is only built when step() first needs it.
        self.sam = None
        # Checkpointed SAM state waiting for the lazily created instance.
        self._pending_sam_state = None
        self.first_step_done = False
        self.accumulated_grads = {}

    def _iter_params(self):
        """Yield every parameter registered in ``self.param_groups``."""
        for group in self.param_groups:
            yield from group['params']

    def _ensure_sam(self):
        """Create the SAM optimizer on first use and return it."""
        if self.sam is None:
            # Build from param_groups rather than the raw constructor input:
            # it is always a list of group dicts and reflects any state loaded
            # since construction.
            self.sam = SAM(self.param_groups, self.base_optimizer_cls, **self.kwargs)
            if self._pending_sam_state is not None:
                self.sam.load_state_dict(self._pending_sam_state)
                self._pending_sam_state = None
        return self.sam

    def _sync_hyperparams(self):
        """Copy group hyperparameters (e.g. ``lr``) from this wrapper to SAM.

        External code such as LR schedulers edits ``self.param_groups``.  When
        SAM holds different group dicts those edits would otherwise never
        reach it.  Objects without ``param_groups`` are tolerated.
        """
        sam_groups = getattr(self.sam, 'param_groups', ())
        for own_group, sam_group in zip(self.param_groups, sam_groups):
            if own_group is sam_group:
                continue
            for key, value in own_group.items():
                if key != 'params':
                    sam_group[key] = value

    def step(self, closure=None):
        """Run one phase of the two-phase SAM update.

        Args:
            closure: Accepted for interface compatibility; it is never
                evaluated.

        Returns:
            Always ``None``.
        """
        sam = self._ensure_sam()
        self._sync_hyperparams()

        if not self.first_step_done:
            # Phase one: snapshot gradients (cloned, so a later zero_grad()
            # cannot clear them), then run SAM's first step.
            self.accumulated_grads = {
                p: p.grad.detach().clone()
                for p in self._iter_params()
                if p.grad is not None
            }
            sam.first_step(zero_grad=False)
            self.first_step_done = True
            return None

        # Phase two: put the snapshot back into .grad, then run SAM's second
        # step.
        for p in self._iter_params():
            saved = self.accumulated_grads.get(p)
            if saved is None:
                continue
            if p.grad is None:
                p.grad = saved
            else:
                p.grad.copy_(saved)

        sam.second_step()

        self.first_step_done = False
        self.accumulated_grads = {}
        return None

    def zero_grad(self, set_to_none=False):
        """Clear gradients, delegating to SAM once it exists.

        Args:
            set_to_none: If true, gradients are replaced by ``None`` instead
                of being zeroed in place.
        """
        delegate = getattr(self.sam, 'zero_grad', None)
        if delegate is not None:
            delegate(set_to_none)
            return

        for p in self._iter_params():
            if p.grad is None:
                continue
            if set_to_none:
                p.grad = None
            else:
                p.grad.zero_()

    def state_dict(self):
        """Return the wrapper state, the SAM state and the phase flag.

        ``'sam_state'`` is SAM's own state dict once SAM exists; before that
        it is whatever ``load_state_dict`` stored for later use, or ``None``.
        The gradient snapshot is not included, so a checkpoint taken between
        the two phases resumes without it.
        """
        if self.sam is not None:
            sam_state = self.sam.state_dict()
        else:
            sam_state = self._pending_sam_state
        return {
            'base_state': super().state_dict(),
            'sam_state': sam_state,
            'first_step_done': self.first_step_done,
        }

    def load_state_dict(self, state_dict):
        """Restore a dict produced by ``state_dict()``.

        If SAM has not been created yet, its state is kept and applied when
        ``step()`` builds it, rather than being discarded.
        """
        super().load_state_dict(state_dict['base_state'])

        sam_state = state_dict['sam_state']
        if sam_state:
            if self.sam is not None:
                self.sam.load_state_dict(sam_state)
            else:
                self._pending_sam_state = sam_state

        self.first_step_done = state_dict['first_step_done']