from collections import OrderedDict
import torch.nn as nn
from torchvision.models import convnext_tiny, convnext_small, convnext_base, convnext_large


def reshape_convnext(net, input_shape=(3, 224, 224), num_classes=1000):
    """Adapt a ConvNeXt-style network to a custom input shape and class count.

    The network is modified in place and is also returned.

    Args:
        net: Model exposing ``features`` (where ``features[0][0]`` is the stem
            convolution) and ``classifier`` (where ``classifier[2]`` is the
            final linear layer).
        input_shape: Per-sample input shape, interpreted as
            ``(channels, height, width)`` so that it lines up with the
            ``(N, C, H, W)`` tensors fed to the model.
        num_classes: Number of output classes.

    Returns:
        The same ``net`` object after modification.
    """
    in_channels = input_shape[0]
    height, width = input_shape[1], input_shape[2]

    # Swap the stem convolution when the channel count is not the default 3.
    if in_channels != 3:
        stem_out = net.features[0][0].out_channels
        net.features[0][0] = nn.Conv2d(
            in_channels, stem_out, kernel_size=4, stride=4, padding=0, bias=True
        )
        nn.init.trunc_normal_(net.features[0][0].weight, std=0.02)
        nn.init.zeros_(net.features[0][0].bias)

    # Inputs smaller than 32 pixels on a side are zero-padded up to 32.
    # NOTE(review): 32 is presumably the network's total downsampling
    # factor -- confirm against the architecture in use.
    min_size = 32
    pad_h = max(0, min_size - height)
    pad_w = max(0, min_size - width)
    if pad_h > 0 or pad_w > 0:
        # ZeroPad2d expects (left, right, top, bottom): left/right act on the
        # last tensor dimension (width), top/bottom on the one before it
        # (height). Any odd remainder goes to the right/bottom side.
        left, top = pad_w // 2, pad_h // 2
        padding = nn.ZeroPad2d((left, pad_w - left, top, pad_h - top))
        net.features = nn.Sequential(
            OrderedDict(
                [
                    ("input_padding", padding),
                    ("features", net.features),
                ]
            )
        )

    # Swap the classification head when the class count is not the default.
    if num_classes != 1000:
        head_in = net.classifier[2].in_features
        net.classifier[2] = nn.Linear(head_in, num_classes)
        nn.init.trunc_normal_(net.classifier[2].weight, std=0.02)
        nn.init.zeros_(net.classifier[2].bias)

    return net


def ConvNeXt_Tiny(input_shape=(3, 224, 224), num_classes=1000):
    """Build a ``convnext_tiny`` model and adapt it with ``reshape_convnext``.

    ``convnext_tiny`` is called with no arguments; ``input_shape`` and
    ``num_classes`` are forwarded unchanged to ``reshape_convnext``.
    """
    backbone = convnext_tiny()
    return reshape_convnext(
        backbone, input_shape=input_shape, num_classes=num_classes
    )


def ConvNeXt_Small(input_shape=(3, 224, 224), num_classes=1000):
    """Build a ``convnext_small`` model and adapt it with ``reshape_convnext``.

    ``convnext_small`` is called with no arguments; ``input_shape`` and
    ``num_classes`` are forwarded unchanged to ``reshape_convnext``.
    """
    backbone = convnext_small()
    return reshape_convnext(
        backbone, input_shape=input_shape, num_classes=num_classes
    )


def ConvNeXt_Base(input_shape=(3, 224, 224), num_classes=1000):
    """Build a ``convnext_base`` model and adapt it with ``reshape_convnext``.

    ``convnext_base`` is called with no arguments; ``input_shape`` and
    ``num_classes`` are forwarded unchanged to ``reshape_convnext``.
    """
    backbone = convnext_base()
    return reshape_convnext(
        backbone, input_shape=input_shape, num_classes=num_classes
    )


def ConvNeXt_Large(input_shape=(3, 224, 224), num_classes=1000):
    """Build a ``convnext_large`` model and adapt it with ``reshape_convnext``.

    ``convnext_large`` is called with no arguments; ``input_shape`` and
    ``num_classes`` are forwarded unchanged to ``reshape_convnext``.
    """
    backbone = convnext_large()
    return reshape_convnext(
        backbone, input_shape=input_shape, num_classes=num_classes
    )
