import torch.nn as nn


def initialize_parameters(model: nn.Module) -> None:
    """Re-initialise the weights of every supported submodule of ``model`` in place.

    Walks ``model.modules()`` (which includes ``model`` itself) and applies:

    * ``nn.Conv2d``: Kaiming-normal weights, ``mode="fan_out"``,
      ``nonlinearity="leaky_relu"`` with the default negative slope ``a=0``.
      Conv biases, if present, are left at PyTorch's default initialisation.
    * ``nn.BatchNorm2d``: scale set to 1 and shift set to 0. This is skipped
      when ``affine=False``, because there are no learnable parameters.
    * ``nn.Linear``: Xavier-normal weights and zero bias. The bias step is
      skipped when the layer was built with ``bias=False``.

    Any other module type is left untouched.

    Args:
        model: The module whose submodules are initialised. It is mutated
            in place.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="leaky_relu")
        elif isinstance(module, nn.BatchNorm2d):
            # With affine=False, weight and bias are None; constant_ would raise on them.
            if module.weight is not None:
                nn.init.constant_(module.weight, 1.0)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            # Linear(bias=False) stores bias as None.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)


class Network(nn.Module):
    """Compose a ``backbone`` module with a ``head`` module.

    ``forward(input)`` computes ``head(backbone(input))``. ``feature`` and
    ``predict`` expose the two stages separately.

    Args:
        backbone: Module applied first to the raw input.
        head: Module applied to the backbone's output.
        init_weights: If True (the default, matching the original behaviour),
            call ``initialize_parameters`` on the whole network, which covers
            both ``backbone`` and ``head``. Pass False to keep the submodules'
            current weights. The default would otherwise overwrite them, for
            example those of a pretrained backbone.
    """

    def __init__(self, backbone: nn.Module, head: nn.Module, *, init_weights: bool = True) -> None:
        super().__init__()
        self.backbone = backbone
        self.head = head
        if init_weights:
            # Called after both submodules are registered, so it reaches both of them.
            initialize_parameters(self)

    # NOTE: ``input`` shadows the builtin. It is kept as the parameter name so
    # that keyword callers (e.g. ``feature(input=x)``) keep working.
    def feature(self, input):
        """Return the output of ``backbone`` for ``input``."""
        return self.backbone(input)

    def predict(self, input):
        """Apply ``head`` to ``input``.

        ``input`` is presumably the output of ``feature``; this is not checked.
        """
        return self.head(input)

    def forward(self, input):
        """Run the full pipeline, ``head(backbone(input))``."""
        return self.head(self.backbone(input))