"""Copyright (c) Meta Platforms, Inc. and affiliates."""
#https://github.com/kksniak/metric-flow-matching/blob/main/mfm/flow_matchers/ema.py
import torch


class EMA(torch.nn.Module):
    """Wrap a model and keep an exponential moving average (EMA) of its weights.

    Only parameters that have ``requires_grad=True`` at construction time are
    tracked. In training mode the wrapped model holds its live weights.
    Switching to eval mode (``eval()`` / ``train(False)``) backs those weights
    up and loads the EMA weights into the model. Switching back (``train()``)
    restores the backed-up weights.

    Args:
        model: Module whose trainable parameters are averaged; ``forward``
            delegates to it.
        decay: Upper bound on the EMA decay factor (see ``update_ema``).
    """

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        super().__init__()
        self.model = model
        self.decay = decay
        # Re-expose the wrapped model's ``time_geopath`` (if it has one) so it
        # can be reached directly on the wrapper.
        if hasattr(self.model, "time_geopath"):
            self.time_geopath = self.model.time_geopath

        # Put this in a buffer so that it gets included in the state dict
        self.register_buffer("num_updates", torch.tensor(0))

        # Non-trainable Parameters so the EMA weights are part of the state dict.
        self.shadow_params = torch.nn.ParameterList(
            [
                torch.nn.Parameter(p.clone().detach(), requires_grad=False)
                for p in model.parameters()
                if p.requires_grad
            ]
        )
        # Plain (unregistered) list holding the live weights while in eval mode.
        self.backup_params = []

    # TODO: test this
    # def reinit(self):
    #     self.model.reinit()
    #     self.num_updates.zero_()
    #     self.shadow_params = torch.nn.ParameterList(
    #         [
    #             torch.nn.Parameter(p.clone().detach(), requires_grad=False)
    #             for p in self.model.parameters()
    #             if p.requires_grad
    #         ]
    #     )
    #     # Clear backup parameters
    #     self.backup_params = []
    #     # Ensure we're in training mode after reinit
    #     self.train(True)

    def train(self, mode: bool = True):
        """Set training mode, swapping the EMA weights in or out of the model.

        Train -> eval: back up the model's current weights, then load the
        shadow (EMA) weights into it. Eval -> train: restore the backed-up
        weights. A call that does not change the mode leaves weights untouched.

        Args:
            mode: ``True`` for training mode, ``False`` for eval mode. Defaults
                to ``True`` to match ``torch.nn.Module.train``.

        Returns:
            ``self``, like ``torch.nn.Module.train``, so ``eval()`` chains.
        """
        if self.training and not mode:
            # Switching from train mode to eval mode.
            self.backup()
            self.copy_to_model()
        elif not self.training and mode:
            # Switching from eval mode to train mode.
            self.restore_to_model()

        return super().train(mode)

    def update_ema(self):
        """Move each shadow parameter toward the model's current value.

        The effective decay is ``min(self.decay, (1 + n) / (10 + n))``, where
        ``n`` is the update count after incrementing. Early updates therefore
        track the model more closely. Each shadow tensor is updated in place as
        ``shadow -= (1 - decay) * (shadow - param)``.
        """
        self.num_updates += 1
        num_updates = self.num_updates.item()
        decay = min(self.decay, (1 + num_updates) / (10 + num_updates))
        with torch.no_grad():
            # NOTE(review): pairing relies on the set of requires_grad params
            # being the same as at construction time.
            params = [p for p in self.model.parameters() if p.requires_grad]
            for shadow, param in zip(self.shadow_params, params):
                shadow.sub_((1 - decay) * (shadow - param))

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model."""
        return self.model(*args, **kwargs)

    def copy_to_model(self):
        """Copy the shadow (EMA) weights into the model's trainable parameters."""
        params = [p for p in self.model.parameters() if p.requires_grad]
        with torch.no_grad():
            for shadow, param in zip(self.shadow_params, params):
                param.copy_(shadow)

    def backup(self):
        """Snapshot all of the model's current parameters into ``backup_params``.

        Existing backup tensors are reused when their count matches the model's
        parameters. Otherwise fresh detached copies are made, so the backups do
        not carry autograd history.
        """
        params = list(self.model.parameters())
        with torch.no_grad():
            if len(self.backup_params) == len(params) and params:
                for param, saved in zip(params, self.backup_params):
                    saved.copy_(param)
            else:
                self.backup_params = [param.detach().clone() for param in params]

    def restore_to_model(self):
        """Copy the backed-up weights back into the model's parameters."""
        with torch.no_grad():
            for param, saved in zip(self.model.parameters(), self.backup_params):
                param.copy_(saved)
