import torch
import torch.nn.functional as F


class L2Constrain(object):

    def __init__(self, axis=-1, eps=1e-6):
        self.axis = axis
        self.eps = eps

    def __call__(self, module):
        if hasattr(module, 'weight'):
            w = module.weight.data
            module.weight.data = F.normalize(w, p=2, dim=self.axis, eps=self.eps)


class NonNegConstrain(object):

    def __init__(self, eps=1e-3):
        self.eps = eps

    def __call__(self, module):
        if hasattr(module, 'weight'):
            w = module.weight.data
            module.weight.data = torch.clamp(w, min=self.eps)
