from collections import OrderedDict
import torch.nn as nn
from torchvision.models import densenet121, densenet161, densenet169, densenet201


def reshape_densenet(net, input_shape=(3, 224, 224), num_classes=1000):
    """Adapt a DenseNet-style network to a new input shape and class count.

    The network is modified in place and also returned.

    Args:
        net: A DenseNet-like module exposing ``net.features`` (an
            ``nn.Sequential`` with a ``conv0`` child) and ``net.classifier``
            (an ``nn.Linear``).
        input_shape: ``(channels, height, width)`` of a single input sample.
            If ``channels != 3``, the stem convolution ``conv0`` is rebuilt
            for the new channel count. If height or width is below 32, a
            zero-padding layer is prepended to ``net.features``.
        num_classes: Number of output classes. Any value other than 1000
            replaces ``net.classifier`` with a new ``nn.Linear``.

    Returns:
        The same ``net`` object, after modification.
    """
    channels, height, width = input_shape[:3]

    if channels != 3:
        # Rebuild the stem convolution for the new channel count. Reuse the
        # existing layer's geometry so only the input channels change.
        old_conv = net.features.conv0
        net.features.conv0 = nn.Conv2d(
            channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=old_conv.bias is not None,
        )
        nn.init.kaiming_normal_(net.features.conv0.weight)

    # Zero-pad each spatial side that is shorter than 32 up to exactly 32.
    # NOTE(review): 32 presumably matches DenseNet's overall downsampling
    # factor (stem conv + pooling + transition layers) — confirm before
    # changing.
    pad_h = max(0, 32 - height)
    pad_w = max(0, 32 - width)
    if pad_h > 0 or pad_w > 0:
        # nn.ZeroPad2d takes (left, right, top, bottom). The first pair pads
        # the width axis and the second pair pads the height axis. Any odd
        # remainder goes to the right/bottom side.
        padding = (
            pad_w // 2, pad_w - pad_w // 2,
            pad_h // 2, pad_h - pad_h // 2,
        )
        net.features = nn.Sequential(
            OrderedDict(
                [
                    ("input_padding", nn.ZeroPad2d(padding)),
                    ("features", net.features),
                ]
            )
        )

    if num_classes != 1000:
        net.classifier = nn.Linear(net.classifier.in_features, num_classes)
        # Zero the bias. The weight keeps nn.Linear's default initialisation.
        nn.init.constant_(net.classifier.bias, 0)

    return net


def DenseNet121(input_shape=(3, 224, 224), num_classes=1000):
    """Return a DenseNet-121 reshaped for ``input_shape`` and ``num_classes``.

    The base model is built with torchvision's ``densenet121()`` using its
    default constructor arguments, then passed through ``reshape_densenet``.
    """
    base_model = densenet121()
    return reshape_densenet(base_model, input_shape=input_shape, num_classes=num_classes)

def DenseNet161(input_shape=(3, 224, 224), num_classes=1000):
    """Return a DenseNet-161 reshaped for ``input_shape`` and ``num_classes``.

    The base model is built with torchvision's ``densenet161()`` using its
    default constructor arguments, then passed through ``reshape_densenet``.
    """
    base_model = densenet161()
    return reshape_densenet(base_model, input_shape=input_shape, num_classes=num_classes)

def DenseNet169(input_shape=(3, 224, 224), num_classes=1000):
    """Return a DenseNet-169 reshaped for ``input_shape`` and ``num_classes``.

    The base model is built with torchvision's ``densenet169()`` using its
    default constructor arguments, then passed through ``reshape_densenet``.
    """
    base_model = densenet169()
    return reshape_densenet(base_model, input_shape=input_shape, num_classes=num_classes)

def DenseNet201(input_shape=(3, 224, 224), num_classes=1000):
    """Return a DenseNet-201 reshaped for ``input_shape`` and ``num_classes``.

    The base model is built with torchvision's ``densenet201()`` using its
    default constructor arguments, then passed through ``reshape_densenet``.
    """
    base_model = densenet201()
    return reshape_densenet(base_model, input_shape=input_shape, num_classes=num_classes)