import torch.nn as nn
from torchvision.models import mobilenet_v2, mobilenet_v3_small, mobilenet_v3_large


def reshape_mobilenet(net, input_shape=(3, 224, 224), num_classes=1000):
    """Adapt a MobileNet's stem convolution and classifier head in place.

    The first convolution (``net.features[0][0]``) is replaced when its
    ``in_channels`` differs from ``input_shape[0]``.  The last classifier
    layer (``net.classifier[-1]``) is replaced when its ``out_features``
    differs from ``num_classes``.  Layers that already match are left
    untouched, so any weights they hold (e.g. pretrained ones) are preserved.

    Replacement layers copy the geometry (kernel size, stride, padding,
    dilation, padding mode, bias presence), device and dtype of the layer
    they replace.  They are then freshly initialised: Kaiming-normal
    (``fan_out``) for the conv, N(0, 0.01) for the linear weight, and zeros
    for any bias.

    Args:
        net: Model exposing ``features[0][0]`` (an ``nn.Conv2d``) and
            ``classifier[-1]`` (an ``nn.Linear``), as torchvision MobileNets do.
        input_shape: ``(channels, height, width)`` of the expected input.
            Only the channel count is used here.
        num_classes: Desired number of output logits.

    Returns:
        The same ``net`` object, modified in place.
    """
    in_channels = input_shape[0]

    stem = net.features[0][0]
    if stem.in_channels != in_channels:
        new_stem = nn.Conv2d(
            in_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            dilation=stem.dilation,
            padding_mode=stem.padding_mode,
            bias=stem.bias is not None,
        ).to(device=stem.weight.device, dtype=stem.weight.dtype)
        nn.init.kaiming_normal_(new_stem.weight, mode="fan_out")
        if new_stem.bias is not None:
            nn.init.zeros_(new_stem.bias)
        net.features[0][0] = new_stem

    head = net.classifier[-1]
    if head.out_features != num_classes:
        new_head = nn.Linear(
            head.in_features,
            num_classes,
            bias=head.bias is not None,
        ).to(device=head.weight.device, dtype=head.weight.dtype)
        nn.init.normal_(new_head.weight, 0, 0.01)
        if new_head.bias is not None:
            nn.init.zeros_(new_head.bias)
        net.classifier[-1] = new_head

    return net


def MobileNet_V2(input_shape=(3, 224, 224), num_classes=1000, **kwargs):
    """Build a torchvision MobileNetV2 adapted to ``input_shape`` / ``num_classes``.

    Extra keyword arguments (e.g. ``weights=...``) are forwarded unchanged to
    ``mobilenet_v2``.  The resulting net is then passed to
    ``reshape_mobilenet``.  That function swaps in a new, freshly initialised
    stem conv when ``input_shape[0]`` is not 3, and a new classifier layer
    when ``num_classes`` is not 1000.  Any other layers keep the weights the
    builder produced.
    """
    return reshape_mobilenet(mobilenet_v2(**kwargs), input_shape, num_classes)

def MobileNet_V3_Small(input_shape=(3, 224, 224), num_classes=1000, **kwargs):
    """Build a torchvision MobileNetV3-Small adapted to ``input_shape`` / ``num_classes``.

    Extra keyword arguments (e.g. ``weights=...``) are forwarded unchanged to
    ``mobilenet_v3_small``.  The resulting net is then passed to
    ``reshape_mobilenet``.  That function swaps in a new, freshly initialised
    stem conv when ``input_shape[0]`` is not 3, and a new classifier layer
    when ``num_classes`` is not 1000.  Any other layers keep the weights the
    builder produced.
    """
    return reshape_mobilenet(mobilenet_v3_small(**kwargs), input_shape, num_classes)

def MobileNet_V3_Large(input_shape=(3, 224, 224), num_classes=1000, **kwargs):
    """Build a torchvision MobileNetV3-Large adapted to ``input_shape`` / ``num_classes``.

    Extra keyword arguments (e.g. ``weights=...``) are forwarded unchanged to
    ``mobilenet_v3_large``.  The resulting net is then passed to
    ``reshape_mobilenet``.  That function swaps in a new, freshly initialised
    stem conv when ``input_shape[0]`` is not 3, and a new classifier layer
    when ``num_classes`` is not 1000.  Any other layers keep the weights the
    builder produced.
    """
    return reshape_mobilenet(mobilenet_v3_large(**kwargs), input_shape, num_classes)
