import torch
from torch.nn import Module
from copy import deepcopy
from torch.nn.modules.batchnorm import _BatchNorm


class AveragedModel(Module):
    """
    Modified AveragedModel from torch swa_utils that supports EMA and SWA
    updates and (optionally) averaging of batch norm running statistics.

    The wrapper holds a deep copy of ``model`` in ``self.module`` and a
    ``n_averaged`` buffer counting how many times ``update_parameters`` has
    been called. Calling the wrapper forwards to the averaged copy.

    Args:
        model: module whose parameters are tracked; it is deep-copied.
        avg_type: ``"ema"`` (exponential moving average) or ``"swa"``
            (equal-weight running mean of all models seen so far).
        ema_decay: upper bound on the EMA decay factor (used for ``"ema"``).
        avg_batchnorm: if True, ``running_mean`` / ``running_var`` of batch
            norm layers are averaged with the same rule as the parameters.
        device: optional device for the averaged copy and the counter.
    """

    def __init__(
        self, model, avg_type="ema", ema_decay=0.990, avg_batchnorm=False, device=None
    ):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer(
            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
        )

        assert avg_type in ["ema", "swa"], "avg_type must be 'ema' or 'swa'"
        self.avg_type = avg_type
        self.ema_decay = ema_decay
        self.avg_batchnorm = avg_batchnorm

    def forward(self, *args, **kwargs):
        """Run the averaged copy of the model on the given inputs."""
        return self.module(*args, **kwargs)

    @staticmethod
    def _blend_(averaged, current, n, avg_fn):
        """Write the updated average of ``averaged`` and ``current`` into ``averaged``.

        On the first update (``n == 0``) the current values are copied
        verbatim; afterwards ``avg_fn(averaged, current)`` is stored.
        """
        target = averaged.detach()
        source = current.detach().to(target.device)
        target.copy_(source if n == 0 else avg_fn(target, source))

    def update_parameters(self, model):
        """Fold the current weights of ``model`` into the running average.

        Args:
            model: module with the same structure as the one passed to the
                constructor; its parameters (and, when ``avg_batchnorm`` is
                set, its batch norm running statistics) are read.

        Raises:
            NotImplementedError: if ``self.avg_type`` is neither ``"ema"``
                nor ``"swa"``.
        """
        n = self.n_averaged.item()
        if self.avg_type == "ema":
            # Warm-up: the effective decay grows with n and is capped at
            # ema_decay, so early updates weight the current model more.
            decay = min(self.ema_decay, (1 + n) / (10 + n))

            def avg_fn(averaged, current):
                return decay * averaged + (1.0 - decay) * current

        elif self.avg_type == "swa":

            def avg_fn(averaged, current):
                # Incremental mean: new_mean = old_mean + (x - old_mean) / (n + 1)
                return averaged + (current - averaged) / (n + 1)

        else:
            raise NotImplementedError()

        for p_avg, p_model in zip(self.module.parameters(), model.parameters()):
            self._blend_(p_avg, p_model, n, avg_fn)

        if self.avg_batchnorm:
            for avg_mod, model_mod in zip(self.module.modules(), model.modules()):
                if not isinstance(model_mod, _BatchNorm):
                    continue
                # Layers built with track_running_stats=False carry no
                # running statistics (the buffers are None); skip them.
                if avg_mod.running_mean is None or model_mod.running_mean is None:
                    continue
                self._blend_(avg_mod.running_mean, model_mod.running_mean, n, avg_fn)
                self._blend_(avg_mod.running_var, model_mod.running_var, n, avg_fn)

        self.n_averaged += 1
