import torch.nn as nn
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152


def reshape_resnet(net, input_shape=(3, 224, 224), num_classes=1000):
    """Adapt a ResNet-style network to an input channel count and class count.

    The network is modified in place and also returned. Only
    ``input_shape[0]`` (the channel count) is read; the remaining entries of
    ``input_shape`` are not used.

    Args:
        net: Module exposing a ``conv1`` attribute (an ``nn.Conv2d`` stem) and
            an ``fc`` attribute (an ``nn.Linear`` head).
        input_shape: Shape of the expected input, channel count first.
        num_classes: Desired number of outputs of the final linear layer.

    Returns:
        The same ``net`` object. ``conv1`` and/or ``fc`` are replaced by newly
        created layers when the requested size differs from the size of the
        layer currently installed; otherwise they are left untouched.
    """
    in_channels = input_shape[0]
    stem = net.conv1
    # Compare with the layer that is actually installed rather than a literal
    # 3, so a network whose stem was already changed is handled correctly.
    if in_channels != stem.in_channels:
        # Inherit the geometry of the stem being replaced instead of assuming
        # fixed kernel/stride/padding values.
        net.conv1 = nn.Conv2d(
            in_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=stem.bias is not None,
        )
        nn.init.kaiming_normal_(net.conv1.weight, mode="fan_out", nonlinearity="relu")

    head = net.fc
    # Same reasoning as above: test the head's real width, not a literal 1000.
    if num_classes != head.out_features:
        net.fc = nn.Linear(head.in_features, num_classes)

    return net


def ResNet18(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision ``resnet18`` and adapt it via ``reshape_resnet``.

    The backbone is created with ``resnet18()`` (no arguments) and then passed,
    together with ``input_shape`` and ``num_classes``, to ``reshape_resnet``,
    whose return value is returned unchanged.
    """
    backbone = resnet18()
    return reshape_resnet(backbone, input_shape=input_shape, num_classes=num_classes)

def ResNet34(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision ``resnet34`` and adapt it via ``reshape_resnet``.

    The backbone is created with ``resnet34()`` (no arguments) and then passed,
    together with ``input_shape`` and ``num_classes``, to ``reshape_resnet``,
    whose return value is returned unchanged.
    """
    backbone = resnet34()
    return reshape_resnet(backbone, input_shape=input_shape, num_classes=num_classes)

def ResNet50(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision ``resnet50`` and adapt it via ``reshape_resnet``.

    The backbone is created with ``resnet50()`` (no arguments) and then passed,
    together with ``input_shape`` and ``num_classes``, to ``reshape_resnet``,
    whose return value is returned unchanged.
    """
    backbone = resnet50()
    return reshape_resnet(backbone, input_shape=input_shape, num_classes=num_classes)

def ResNet101(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision ``resnet101`` and adapt it via ``reshape_resnet``.

    The backbone is created with ``resnet101()`` (no arguments) and then
    passed, together with ``input_shape`` and ``num_classes``, to
    ``reshape_resnet``, whose return value is returned unchanged.
    """
    backbone = resnet101()
    return reshape_resnet(backbone, input_shape=input_shape, num_classes=num_classes)

def ResNet152(input_shape=(3, 224, 224), num_classes=1000):
    """Build a torchvision ``resnet152`` and adapt it via ``reshape_resnet``.

    The backbone is created with ``resnet152()`` (no arguments) and then
    passed, together with ``input_shape`` and ``num_classes``, to
    ``reshape_resnet``, whose return value is returned unchanged.
    """
    backbone = resnet152()
    return reshape_resnet(backbone, input_shape=input_shape, num_classes=num_classes)