import torch.nn as nn


class EMAHelper(object):
  """Keeps an exponential moving average (EMA) of a module's parameters.

  Only parameters with ``requires_grad=True`` are tracked. The averages live
  in ``self.shadow``, a dict mapping each parameter name (as reported by
  ``named_parameters()`` on the unwrapped module) to a tensor.

  Every method that takes a module also accepts an ``nn.DataParallel``
  wrapper; the wrapped module is used, so shadow keys never carry the
  ``module.`` prefix.

  Args:
    mu: Decay factor. Each ``update`` sets
      ``shadow = mu * shadow + (1 - mu) * param``.
  """

  def __init__(self, mu=0.999):
    self.mu = mu
    self.shadow = {}

  @staticmethod
  def _unwrap(module):
    """Return ``module.module`` for an ``nn.DataParallel``, else ``module``."""
    return module.module if isinstance(module, nn.DataParallel) else module

  @classmethod
  def _trainable_parameters(cls, module):
    """Yield ``(name, param)`` for parameters that have ``requires_grad``."""
    for name, param in cls._unwrap(module).named_parameters():
      if param.requires_grad:
        yield name, param

  def register(self, module):
    """Start tracking ``module`` by cloning its current trainable parameters.

    Existing shadow entries with the same name are overwritten.
    """
    for name, param in self._trainable_parameters(module):
      self.shadow[name] = param.data.clone()

  def update(self, module):
    """Fold the current parameter values of ``module`` into the averages.

    Raises:
      KeyError: If a trainable parameter has no shadow entry (it was not
        present, or not trainable, when ``register`` / ``load_state_dict``
        populated the shadow).
    """
    for name, param in self._trainable_parameters(module):
      # In-place form of: shadow = mu * shadow + (1 - mu) * param.
      # Avoids allocating temporaries on every step and keeps the shadow
      # tensor's storage stable.
      self.shadow[name].data.mul_(self.mu).add_(
        param.data, alpha=1. - self.mu)

  def ema(self, module):
    """Overwrite the trainable parameters of ``module`` with the averages.

    Raises:
      KeyError: If a trainable parameter has no shadow entry.
    """
    for name, param in self._trainable_parameters(module):
      param.data.copy_(self.shadow[name].data)

  def ema_copy(self, module):
    """Return a new module holding the averaged parameters.

    The copy is built as ``type(m)(m.config).to(m.config.device)`` where ``m``
    is the unwrapped module, so the module must expose a ``config`` attribute
    that its constructor accepts and that has a ``device`` attribute. The
    copy first receives ``m.state_dict()`` and then has its trainable
    parameters replaced by the shadow values. ``module`` is not modified.

    If ``module`` is an ``nn.DataParallel``, the returned copy is wrapped in
    ``nn.DataParallel`` as well.
    """
    inner_module = self._unwrap(module)
    module_copy = type(inner_module)(
      inner_module.config).to(inner_module.config.device)
    module_copy.load_state_dict(inner_module.state_dict())
    if isinstance(module, nn.DataParallel):
      module_copy = nn.DataParallel(module_copy)
    self.ema(module_copy)
    return module_copy

  def state_dict(self):
    """Return a name -> tensor mapping of the current averages.

    The mapping is a shallow copy: adding or removing keys does not affect
    this helper, but the tensors themselves are shared with it.
    """
    return dict(self.shadow)

  def load_state_dict(self, state_dict):
    """Replace the tracked averages with the entries of ``state_dict``.

    The mapping is shallow-copied so that a later ``register`` does not add
    or replace keys in the caller's dict; the tensors are not cloned.
    """
    self.shadow = dict(state_dict)
