from collections import OrderedDict
import torch.nn as nn
from torchvision.models import vgg11, vgg13, vgg16, vgg19


def reshape_vgg(net, input_shape=(3, 224, 224), num_classes=1000):
    """Adapt a VGG-style model to a given input shape and number of classes.

    Up to three independent adjustments are applied to ``net`` in place:

    * If ``input_shape[0]`` (the channel count) is not 3, ``net.features[0]``
      is replaced by a freshly initialised 3x3 convolution (padding 1) that
      takes ``input_shape[0]`` channels and keeps the previous number of
      output channels.
    * If the height or width is below 32 pixels, ``net.features`` is wrapped
      in a ``Sequential`` whose first stage (``"input_padding"``) zero-pads
      the input up to 32 along each short axis. The padding is split as
      evenly as possible between the two sides, with the extra pixel going
      to the right/bottom. The original features follow as ``"features"``.
    * If ``num_classes`` is not 1000, ``net.classifier[6]`` is replaced by a
      freshly initialised ``nn.Linear`` with ``num_classes`` outputs and the
      same number of inputs as before.

    Args:
        net: Model exposing indexable ``features`` and ``classifier``
            sequences, e.g. one built by ``torchvision.models.vgg*``.
        input_shape: Input shape as ``(channels, height, width)``.
        num_classes: Number of output classes.

    Returns:
        The same ``net`` object, modified in place.
    """
    in_channels = input_shape[0]
    height, width = input_shape[1], input_shape[2]

    if in_channels != 3:
        out_channels = net.features[0].out_channels
        net.features[0] = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Same init scheme torchvision uses for VGG convolutions.
        nn.init.kaiming_normal_(net.features[0].weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(net.features[0].bias, 0)

    # Minimum spatial size of 32: presumably because torchvision's VGG feature
    # extractors halve the resolution in five max-pool stages (2**5 == 32).
    # input_shape is (C, H, W), so index 1 is the height and index 2 the width.
    pad_h = max(0, 32 - height)
    pad_w = max(0, 32 - width)
    if pad_h > 0 or pad_w > 0:
        # nn.ZeroPad2d expects (left, right, top, bottom): width pads first.
        padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
        net.features = nn.Sequential(
            OrderedDict(
                [
                    ("input_padding", nn.ZeroPad2d(padding)),
                    ("features", net.features),
                ]
            )
        )

    if num_classes != 1000:
        net.classifier[6] = nn.Linear(net.classifier[6].in_features, num_classes)
        # Same init scheme torchvision uses for VGG linear layers.
        nn.init.normal_(net.classifier[6].weight, 0, 0.01)
        nn.init.constant_(net.classifier[6].bias, 0)

    return net


def VGG11(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision VGG-11 and adapt it with ``reshape_vgg``.

    The model is constructed with torchvision's default arguments (no weights
    are requested here).

    Args:
        input_shape: Input shape as ``(channels, height, width)``.
        num_classes: Number of output classes.
    """
    model = vgg11()
    return reshape_vgg(model, input_shape=input_shape, num_classes=num_classes)

def VGG13(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision VGG-13 and adapt it with ``reshape_vgg``.

    The model is constructed with torchvision's default arguments (no weights
    are requested here).

    Args:
        input_shape: Input shape as ``(channels, height, width)``.
        num_classes: Number of output classes.
    """
    model = vgg13()
    return reshape_vgg(model, input_shape=input_shape, num_classes=num_classes)

def VGG16(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision VGG-16 and adapt it with ``reshape_vgg``.

    The model is constructed with torchvision's default arguments (no weights
    are requested here).

    Args:
        input_shape: Input shape as ``(channels, height, width)``.
        num_classes: Number of output classes.
    """
    model = vgg16()
    return reshape_vgg(model, input_shape=input_shape, num_classes=num_classes)

def VGG19(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision VGG-19 and adapt it with ``reshape_vgg``.

    The model is constructed with torchvision's default arguments (no weights
    are requested here).

    Args:
        input_shape: Input shape as ``(channels, height, width)``.
        num_classes: Number of output classes.
    """
    model = vgg19()
    return reshape_vgg(model, input_shape=input_shape, num_classes=num_classes)
