from collections import OrderedDict
import torch.nn as nn
from torchvision.models import alexnet


def reshape_alexnet(net, input_shape=(3, 224, 224), num_classes=1000):
    """Adapt an AlexNet-style model to a new input shape and class count.

    The model is modified in place and also returned.

    Args:
        net: Model exposing ``features`` (an ``nn.Sequential`` whose element 0
            is the input ``nn.Conv2d``) and ``classifier`` (an ``nn.Sequential``
            whose element 6 is the final ``nn.Linear``), e.g. the result of
            ``torchvision.models.alexnet()``.
        input_shape: ``(channels, height, width)`` of a single input sample.
        num_classes: Number of outputs of the final linear layer.

    Returns:
        The same ``net`` object after modification.
    """
    channels, height, width = input_shape[0], input_shape[1], input_shape[2]

    # Replace the stem convolution for non-RGB inputs. Reuse the existing
    # layer's geometry so the downstream layers in ``features`` still line up.
    if channels != 3:
        stem = net.features[0]
        net.features[0] = nn.Conv2d(
            channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
        )

    # Zero-pad small inputs up to at least 64 px per side. When the padding
    # is odd, the extra pixel goes on the right/bottom.
    # NOTE(review): 64 appears to be a safe floor so the strided conv/pool
    # stack in torchvision's AlexNet does not collapse the feature map.
    min_side = 64
    pad_h = max(0, min_side - height)
    pad_w = max(0, min_side - width)
    if pad_h or pad_w:
        # ZeroPad2d expects (left, right, top, bottom): width padding first,
        # then height padding.
        net.features = nn.Sequential(
            OrderedDict(
                [
                    ("input_padding", nn.ZeroPad2d((
                        pad_w // 2, pad_w - pad_w // 2,
                        pad_h // 2, pad_h - pad_h // 2))),
                    ("features", net.features),
                ]
            )
        )

    # Swap the output layer only when the class count differs from the
    # 1000-way default, keeping its input width.
    if num_classes != 1000:
        head = net.classifier[6]
        net.classifier[6] = nn.Linear(head.in_features, num_classes)

    return net


def AlexNet(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision AlexNet and reshape it for the given input and class count.

    Args:
        input_shape: ``(channels, height, width)`` of one input sample.
        num_classes: Size of the classifier output.

    Returns:
        A ``torchvision.models.alexnet()`` instance (default constructor
        arguments) passed through ``reshape_alexnet``.
    """
    model = alexnet()
    return reshape_alexnet(model, input_shape=input_shape, num_classes=num_classes)