import torch
import torch.nn as nn
from antgine.regularizer import AbstractRegularizer


class L2(AbstractRegularizer):
    """
        L2 regularizer.

    Penalizes the sum of squared entries of every parameter tensor collected
    by the base class in ``self._layerparams``, scaled by ``lambda_``.
    """
    def __init__(self, model, lambda_, modules_attrs=None):
        """
            See :meth:`antgine.regularizer.AbstractRegularizer.__init__`.
        `modules_attrs` must contain a 'weight' key.

        :param model: Model to regularize; forwarded to the base class.
        :param float lambda_: Penalty value.
        :param dict modules_attrs: Forwarded to the base class. Defaults to
            ``{'weight': [nn.Linear, nn.Conv2d, nn.BatchNorm1d, nn.BatchNorm2d]}``;
            a fresh dict is built on every call so instances never share
            (and can never mutate) a common default object.
        :raises ValueError: If `modules_attrs` has no 'weight' key.
        """
        if modules_attrs is None:
            modules_attrs = {'weight': [nn.Linear, nn.Conv2d, nn.BatchNorm1d, nn.BatchNorm2d]}
        # Validate before handing the argument to the base class; an explicit
        # raise (unlike `assert`) survives `python -O`.
        if 'weight' not in modules_attrs:
            raise ValueError("modules_attrs must contain a 'weight' key")
        super().__init__(model=model, modules_attrs=modules_attrs)
        self._lambda = lambda_

    def forward(self, epoch, it):
        """
        Compute the L2 penalty.

        :param epoch: Unused by this regularizer.
        :param it: Unused by this regularizer.
        :return: ``lambda_ * sum(p ** 2)`` summed over every value of every
            mapping in ``self._layerparams`` (``lambda_ * 0`` if there are none).
        """
        reg = 0
        for layer_params in self._layerparams:
            # Each entry is iterated via `.values()`; presumably a dict of
            # attribute name -> parameter tensor built by the base class.
            for p in layer_params.values():
                reg += torch.sum(p ** 2)
        return self._lambda * reg
