# Reference: DiffusionPolicy [https://github.com/real-stanford/diffusion_policy]

import torch
from torch.nn.modules.batchnorm import _BatchNorm


class EMAModel:
    """
    Exponential moving average (EMA) of a model's weights.

    The model handed to the constructor is used directly as the averaged
    copy: it is not cloned, and it is switched to eval mode with gradients
    disabled. Pass a separate copy (e.g. ``copy.deepcopy(model)``) when the
    training model itself must stay trainable.

    Only parameters are averaged; buffers (such as BatchNorm running
    statistics) are never touched by ``step``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        *,
        decay: float = 0.99,
    ) -> None:
        """
        Args:
            model: Module that will hold the averaged weights. It is
                modified in place (eval mode, ``requires_grad=False``).
            decay: Weight kept from the previous average on every step;
                ``1 - decay`` is the weight given to the new parameters.

        Raises:
            ValueError: If ``decay`` is outside the closed interval [0, 1].
        """
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in [0, 1], got {decay!r}")

        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        self.decay = decay

    @torch.no_grad()
    def step(self, new_model: torch.nn.Module) -> None:
        """
        Blend the parameters of ``new_model`` into the averaged model.

        Each averaged parameter becomes
        ``decay * averaged + (1 - decay) * new``.

        Args:
            new_model: Module with the same parameter layout as the averaged
                model. Parameters are paired by iteration order.
        """
        decay = self.decay
        # parameters() yields every tensor exactly once, in module traversal
        # order. Walking module by module would visit a weight shared between
        # two modules twice and apply the decay twice in a single step.
        pairs = zip(new_model.parameters(), self.averaged_model.parameters())
        for param, ema_param in pairs:
            # Match both device and dtype so the averaged copy may live on a
            # different device or in a different precision than the source.
            source = param.detach().to(
                device=ema_param.device, dtype=ema_param.dtype
            )
            ema_param.mul_(decay).add_(source, alpha=1 - decay)

