import torch
import pytorch_lightning as pl
from typing import Dict, Any
import copy


class AdaptiveEMACallback(pl.callbacks.Callback):
    """Exponential Moving Average (EMA) of trainable weights for PyTorch Lightning.

    A "shadow" copy of every parameter with ``requires_grad=True`` is updated
    after each training batch. For validation and testing the shadow weights
    are copied into the module, and the raw weights are restored afterwards so
    training continues on them.

    The decay is warmed up: after ``n`` updates the effective decay is
    ``min(decay, (1 + n) / (10 + n))``, so early in training the average
    follows the raw weights closely and later approaches ``decay``.

    Parameters are matched to shadow tensors by position among the trainable
    parameters. NOTE(review): toggling ``requires_grad`` mid-training
    (freezing/unfreezing layers) would break this positional mapping.
    """

    def __init__(self, decay=0.9999):
        """
        Initialize the AdaptiveEMACallback.

        Args:
            decay: Upper bound on the EMA decay rate. The effective rate is
                ``min(decay, (1 + n) / (10 + n))`` after ``n`` updates.
        """
        super().__init__()
        self.decay = decay
        # One EMA tensor per trainable parameter; None until initialized/loaded.
        self.shadow_params = None
        # Number of EMA updates so far; drives the decay warm-up.
        self.num_updates = 0
        # Raw (non-EMA) weights stashed while the EMA weights are swapped in.
        self.temp_stored_params = None

    @staticmethod
    def _trainable_params(pl_module):
        """Return the module's parameters that require gradients, in order."""
        return [p for p in pl_module.parameters() if p.requires_grad]

    def _swap_in_ema(self, pl_module):
        """Stash the module's current weights and copy the EMA weights into it.

        Does nothing when no EMA weights exist (e.g. validate/test without a
        preceding fit or a checkpoint containing EMA state).
        """
        if self.shadow_params is None:
            return
        params = self._trainable_params(pl_module)
        with torch.no_grad():
            # If a previous swap was never undone, the module already holds EMA
            # weights; re-stashing would overwrite the real raw weights.
            if self.temp_stored_params is None:
                self.temp_stored_params = [p.detach().clone() for p in params]
            for param, s_param in zip(params, self.shadow_params):
                param.copy_(s_param)

    def _restore_original(self, pl_module):
        """Copy the stashed raw weights back into the module, if any."""
        if self.temp_stored_params is None:
            return
        params = self._trainable_params(pl_module)
        with torch.no_grad():
            for param, orig_param in zip(params, self.temp_stored_params):
                param.copy_(orig_param)
        self.temp_stored_params = None

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Create the shadow weights, or reuse ones restored from a checkpoint.

        If shadow weights already exist (e.g. set by ``on_load_checkpoint``
        when resuming) and match the number of trainable parameters, they are
        kept and moved onto each parameter's device. Otherwise they are
        (re)initialized from the current weights and the update count is reset.
        """
        params = self._trainable_params(pl_module)
        if self.shadow_params is not None and len(self.shadow_params) == len(params):
            # Keep the restored average instead of discarding it. It may have
            # been loaded before the module reached its training device.
            self.shadow_params = [s.to(p.device) for s, p in zip(self.shadow_params, params)]
        else:
            self.shadow_params = [p.detach().clone() for p in params]
            self.num_updates = 0

    def on_train_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
                           outputs: Dict[str, Any], batch: Any, batch_idx: int) -> None:
        """Blend the current weights into the shadow weights after each batch.

        NOTE(review): this runs on every training batch. With gradient
        accumulation it may run on batches where no optimizer step occurred;
        confirm that is intended.
        """
        parameters = self._trainable_params(pl_module)

        self.num_updates += 1
        # Warm-up: the effective decay starts near 0.18 (n=1) and rises toward self.decay.
        decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            for s_param, param in zip(self.shadow_params, parameters):
                # shadow <- shadow + (1 - decay) * (param - shadow)
                s_param.lerp_(param, one_minus_decay)

    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Evaluate with the EMA weights (raw weights are stashed)."""
        self._swap_in_ema(pl_module)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Restore the raw weights after validation."""
        self._restore_original(pl_module)

    def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Test with the EMA weights (raw weights are stashed)."""
        self._swap_in_ema(pl_module)

    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Restore the raw weights after testing.

        To keep the EMA weights in the module after testing instead, remove
        this restore (or use ``get_ema_model`` to obtain an EMA copy).
        """
        self._restore_original(pl_module)

    def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
                           checkpoint: Dict[str, Any]) -> None:
        """Save the EMA weights (on CPU) and update count into the checkpoint."""
        if self.shadow_params is not None:
            checkpoint['ema_shadow_params'] = [p.clone().detach().cpu() for p in self.shadow_params]
            checkpoint['ema_num_updates'] = self.num_updates

    def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
                           checkpoint: Dict[str, Any]) -> None:
        """Load EMA weights and update count from the checkpoint, if present."""
        if 'ema_shadow_params' in checkpoint:
            self.shadow_params = [p.to(pl_module.device) for p in checkpoint['ema_shadow_params']]
            self.num_updates = checkpoint.get('ema_num_updates', 0)

    def get_ema_model(self, pl_module: pl.LightningModule) -> pl.LightningModule:
        """Return a deep copy of ``pl_module`` whose trainable weights are the EMA weights.

        Raises:
            RuntimeError: if no EMA weights exist yet.
        """
        if self.shadow_params is None:
            raise RuntimeError(
                "No EMA weights available; run training or load a checkpoint with EMA state first."
            )
        ema_model = copy.deepcopy(pl_module)
        with torch.no_grad():
            for param, ema_param in zip(self._trainable_params(ema_model), self.shadow_params):
                param.copy_(ema_param)
        return ema_model
    

