"""
Exponential Moving Average (EMA) helper for PyTorch models.

This class maintains a "shadow" copy of a model's weights, which are
updated with a slow-moving average of the main model's weights. This
often leads to more stable training and better final performance.
"""
import torch
from torch import nn

class EMA:
    """
    Exponential Moving Average of model weights.

    For every parameter that has ``requires_grad=True`` when the EMA is
    constructed, a shadow tensor ``s`` is kept and updated from the live
    parameter ``p`` as ``s = decay * s + (1 - decay) * p``.

    Usage:
        ema = EMA(model, decay=0.999)
        for batch in loader:
            ...
            optimizer.step()
            ema.update(model)

        with ema:               # model (from the constructor) holds EMA weights
            evaluate(model)
        # original weights are restored here

    Args:
        model (nn.Module): The model to track. Its trainable parameters are
            cloned as the initial shadow weights. A reference to it is kept
            as the default target of ``with ema: ...``.
        decay (float): The decay factor for the moving average, in
            ``[0, 1]``. Values close to 1 make the average move slowly.

    Raises:
        ValueError: If ``decay`` is not within ``[0, 1]``.
    """

    def __init__(self, model: nn.Module, decay: float):
        # The negated chained comparison also rejects NaN.
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in [0, 1], got {decay!r}")
        self.decay = decay
        self.shadow = {}
        # Original weights saved by __enter__; emptied again by __exit__.
        self.backup = {}
        # Default model for ``with ema:`` (the with-statement cannot pass one).
        self._model = model
        # Model currently holding the shadow weights inside a ``with`` block.
        self._active_model = None

        # Register the shadow parameters as independent copies (no shared
        # storage with the model).
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    @staticmethod
    def _require(store: dict, name: str, label: str) -> torch.Tensor:
        """Return ``store[name]``, raising a descriptive KeyError if absent."""
        try:
            return store[name]
        except KeyError:
            raise KeyError(
                f"no {label} tensor for trainable parameter {name!r}; the "
                "model's trainable parameters differ from the ones recorded"
            ) from None

    def update(self, model: nn.Module):
        """
        Update the shadow weights with the current model weights.
        This should be called after each optimizer.step().

        Raises:
            KeyError: If ``model`` has a trainable parameter with no shadow.
        """
        decay = self.decay
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            shadow = self._require(self.shadow, name, "shadow")
            # The expression allocates a fresh tensor, so no extra clone is
            # needed to keep it independent of the model's storage.
            self.shadow[name] = (1.0 - decay) * param.detach() + decay * shadow

    def apply_shadow(self, model: nn.Module):
        """
        Copy the shadow weights into the model. This is used for evaluation.

        Values are copied into the existing parameter storage instead of
        rebinding ``param.data``. Otherwise the model and ``self.shadow``
        would share memory, and any in-place change to the model would
        silently corrupt the moving average.

        Raises:
            KeyError: If ``model`` has a trainable parameter with no shadow.
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self._require(self.shadow, name, "shadow"))

    def restore(self, model: nn.Module, backup: dict):
        """
        Restore the original model weights from a backup.

        Args:
            model: Model whose trainable parameters are overwritten.
            backup: Mapping from parameter name to tensor, e.g. the tensors
                saved before calling ``apply_shadow``. Values are copied, so
                ``backup`` stays independent of the model afterwards.

        Raises:
            KeyError: If a trainable parameter of ``model`` is missing from
                ``backup``.
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self._require(backup, name, "backup"))

    def __enter__(self, model: "nn.Module | None" = None):
        """
        Context manager to apply shadow weights for evaluation.
        Usage:
            with ema:
                # the model passed to EMA(...) now has EMA weights
                ...

        ``ema.__enter__(other_model)`` may also be called explicitly. Pair it
        with ``ema.__exit__()``.

        Raises:
            RuntimeError: If shadow weights are already applied (nesting would
                overwrite the saved original weights).
        """
        if self._active_model is not None:
            raise RuntimeError("EMA weights are already applied; nesting is not supported")
        target = self._model if model is None else model
        self.backup = {
            name: param.detach().clone()
            for name, param in target.named_parameters()
            if param.requires_grad
        }
        try:
            self.apply_shadow(target)
        except BaseException:
            # Undo a partially applied shadow so the model is left untouched.
            self.restore(target, self.backup)
            self.backup = {}
            raise
        self._active_model = target
        return self

    def __exit__(self, *exc_info):
        """
        Restore original weights after evaluation.

        Restores the model that ``__enter__`` modified. Positional arguments
        (the exception info, or a model from legacy direct calls) are ignored.
        Exceptions raised inside the ``with`` body are not suppressed.
        """
        model = self._active_model
        if model is None:
            raise RuntimeError("__exit__ called without a matching __enter__")
        try:
            self.restore(model, self.backup)
        finally:
            # Release the backup copy of the weights.
            self.backup = {}
            self._active_model = None

