
import torch
import copy

class EMA:
    """Exponential moving average (EMA) of a model's parameters.

    A deep copy of the model (the "shadow") holds the averaged weights.
    Each call to :meth:`update` moves every shadow parameter ``P'`` toward
    the corresponding live parameter ``P``::

        P' = decay * P + (1 - decay) * P'

    Note that ``decay`` here is the weight given to the *new* value, so a
    small ``decay`` means a slowly moving average.

    :meth:`apply_shadow` swaps the averaged weights into the live model
    (e.g. for evaluation) and :meth:`restore` puts the live weights back.
    """

    def __init__(self, model: torch.nn.Module, decay: float = 0.01) -> None:
        """Create the shadow copy of ``model``.

        Args:
            model: The live model whose parameters are tracked.
            decay: Weight of the live parameters in each update step.
        """
        self.model = model
        self.decay = decay
        self.shadow = copy.deepcopy(model)
        # The shadow is never trained directly, so it needs no gradients.
        for param in self.shadow.parameters():
            param.requires_grad = False
        # State dict of the live model saved by apply_shadow(); None until
        # apply_shadow() has been called at least once.
        self.backup = None
        # True while the shadow weights are loaded into the live model.
        self._shadow_applied = False

    def update(self) -> None:
        """Blend the live parameters into the shadow and sync its buffers."""
        with torch.no_grad():
            for shadow_param, param in zip(
                self.shadow.parameters(), self.model.parameters()
            ):
                shadow_param.copy_(
                    self.decay * param + (1 - self.decay) * shadow_param
                )
            # Buffers (e.g. BatchNorm running statistics) are not returned by
            # parameters(); copy them verbatim so that apply_shadow() does not
            # load values frozen at construction time into the live model.
            for shadow_buf, buf in zip(
                self.shadow.buffers(), self.model.buffers()
            ):
                shadow_buf.copy_(buf)

    def apply_shadow(self) -> None:
        """Load the averaged weights into the live model.

        The live weights are saved first so :meth:`restore` can bring them
        back. Calling this again before :meth:`restore` keeps the first
        backup instead of overwriting it with the shadow weights.
        """
        if not self._shadow_applied:
            self.backup = copy.deepcopy(self.model.state_dict())
            self._shadow_applied = True
        self.model.load_state_dict(self.shadow.state_dict())

    def restore(self) -> None:
        """Load the weights saved by :meth:`apply_shadow` back into the model.

        Raises:
            RuntimeError: If :meth:`apply_shadow` has never been called.
        """
        if self.backup is None:
            raise RuntimeError(
                "EMA.restore() called before EMA.apply_shadow(); "
                "there is no backup to restore."
            )
        self.model.load_state_dict(self.backup)
        self._shadow_applied = False