
import torch.nn as nn


def weights_init(m: nn.Module, init=nn.init.kaiming_normal_) -> None:
    """Initialize the weights (and zero the bias) of a single layer.

    Intended to be passed to ``Module.apply`` so it visits every submodule:

    call: model.apply(weights_init)

    Only ``nn.Conv2d`` and ``nn.Linear`` modules are touched; any other
    module is left unchanged.

    Args:
        m: The module to initialize.
        init: Callable invoked as ``init(m.weight)`` to initialize the
            weight tensor in place. Defaults to ``nn.init.kaiming_normal_``
            called with its default arguments.
    """
    # The in-place (trailing underscore) initializer is used as the default:
    # the un-suffixed ``nn.init.kaiming_normal`` alias is deprecated and
    # emits a warning for every layer it is applied to.
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return

    init(m.weight)
    # Layers built with ``bias=False`` expose ``bias`` as ``None``.
    if m.bias is not None:
        nn.init.zeros_(m.bias)
