from federatedscope.register import register_regularizer
try:
    from torch.nn import Module
    import torch
except ImportError:
    Module = object
    torch = None

# Key under which this regularizer is registered (see the
# register_regularizer call at the bottom of this module) and against which
# call_proximal_regularizer compares the requested type.
REGULARIZER_NAME = "proximal_regularizer"


class ProximalRegularizer(Module):
    """Proximal penalty on the distance between current and reference weights.

        ``forward`` pairs every reference tensor in ``ctx.weight_init`` with
        the corresponding entry of ``ctx.model.parameters()`` and accumulates
        ``||w - w_init||_p ** p`` over the pairs, then divides by ``p``.
    """
    def __init__(self):
        super().__init__()

    def forward(self, ctx, p=2):
        """Evaluate the penalty for the weights held by ``ctx``.

            Arguments:
                ctx: Object exposing ``weight_init`` (an iterable of
                    reference tensors) and ``model`` (an object whose
                    ``parameters()`` yields the current tensors). The two
                    are matched pairwise in iteration order; ``zip`` stops
                    at the shorter of the two.
                p (int): The order of the norm. It is also used as the
                    exponent applied to each norm and as the final divisor.

            Returns:
                The accumulated penalty. This is a tensor when at least one
                pair exists and the plain float ``0.0`` otherwise.
        """
        weight_pairs = zip(ctx.weight_init, ctx.model.parameters())
        # Starting the sum from a float keeps the "no parameters" result a
        # plain 0.0 instead of an int.
        penalty = sum(
            (torch.pow(torch.norm(current - reference, p), p)
             for reference, current in weight_pairs),
            0.,
        )
        return penalty / float(p)


def call_proximal_regularizer(type):
    """Resolve a regularizer name to the proximal regularizer class.

    Arguments:
        type: The requested regularizer name. The parameter name shadows
            the builtin ``type`` but is kept so existing callers that pass
            it by keyword keep working.

    Returns:
        The ``ProximalRegularizer`` class itself (not an instance) when
        ``type`` equals ``REGULARIZER_NAME``; ``None`` for any other value.
    """
    return ProximalRegularizer if type == REGULARIZER_NAME else None


# Import-time side effect: hand the builder function to the project registry
# under REGULARIZER_NAME.
# NOTE(review): how the registry later invokes the builder is defined in
# federatedscope.register, not here — confirm against that module.
register_regularizer(REGULARIZER_NAME, call_proximal_regularizer)
