import math
import torch.nn as nn
from torchvision.models import efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7, efficientnet_v2_s, efficientnet_v2_m, efficientnet_v2_l


def reshape_efficientnet(net, input_shape=(3, 224, 224), num_classes=1000):
    """Adapt an EfficientNet's stem conv and final classifier to new sizes.

    The model is modified in place and also returned, so calls can be chained.
    A layer is replaced only when the requested size differs from the size of
    the layer currently in ``net``; otherwise it is left untouched.

    Args:
        net: EfficientNet-style model exposing ``net.features[0][0]`` (the stem
            ``nn.Conv2d``) and ``net.classifier[1]`` (the final ``nn.Linear``).
        input_shape: ``(channels, height, width)`` of the expected input. Only
            the channel count is used here; the spatial size is not inspected.
        num_classes: Number of output logits the classifier should produce.

    Returns:
        The same ``net`` object, after any layer replacements.

    Raises:
        ValueError: If the requested channel count or ``num_classes`` is not a
            positive integer. Validation happens before ``net`` is modified.
    """
    in_channels = input_shape[0]
    if in_channels < 1:
        raise ValueError(f"input_shape[0] must be >= 1, got {in_channels}")
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    stem = net.features[0][0]
    # Compare with the layer actually present instead of assuming the ImageNet
    # defaults (3 channels), so nets built with other sizes are handled too.
    if in_channels != stem.in_channels:
        # Copy the existing stem's geometry so the spatial downsampling seen by
        # the rest of the network is unchanged; only the input channels differ.
        new_stem = nn.Conv2d(
            in_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            dilation=stem.dilation,
            bias=stem.bias is not None,
            device=stem.weight.device,
            dtype=stem.weight.dtype,
        )
        nn.init.kaiming_normal_(new_stem.weight, mode="fan_out")
        if new_stem.bias is not None:
            nn.init.zeros_(new_stem.bias)
        net.features[0][0] = new_stem

    head = net.classifier[1]
    # Same idea as above: compare against the current head, not a fixed 1000.
    if num_classes != head.out_features:
        new_head = nn.Linear(
            head.in_features,
            num_classes,
            device=head.weight.device,
            dtype=head.weight.dtype,
        )
        # Uniform(-1/sqrt(out_features), 1/sqrt(out_features)) weights and a
        # zero bias, the scheme the original code used for the new head.
        init_range = 1.0 / math.sqrt(num_classes)
        nn.init.uniform_(new_head.weight, -init_range, init_range)
        nn.init.zeros_(new_head.bias)
        net.classifier[1] = new_head

    return net


def EfficientNet_B0(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNet-B0 adapted to the given sizes.

    The base model comes from ``efficientnet_b0()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_b0()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_B1(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNet-B1 adapted to the given sizes.

    The base model comes from ``efficientnet_b1()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_b1()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_B2(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNet-B2 adapted to the given sizes.

    The base model comes from ``efficientnet_b2()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_b2()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_B3(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNet-B3 adapted to the given sizes.

    The base model comes from ``efficientnet_b3()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_b3()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_B4(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNet-B4 adapted to the given sizes.

    The base model comes from ``efficientnet_b4()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_b4()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_B5(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNet-B5 adapted to the given sizes.

    The base model comes from ``efficientnet_b5()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_b5()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_B6(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNet-B6 adapted to the given sizes.

    The base model comes from ``efficientnet_b6()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_b6()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_B7(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNet-B7 adapted to the given sizes.

    The base model comes from ``efficientnet_b7()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_b7()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_V2_S(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNetV2-S adapted to the given sizes.

    The base model comes from ``efficientnet_v2_s()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_v2_s()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_V2_M(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNetV2-M adapted to the given sizes.

    The base model comes from ``efficientnet_v2_m()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_v2_m()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)

def EfficientNet_V2_L(input_shape=(3, 224, 224), num_classes=1000):
    """Return torchvision's EfficientNetV2-L adapted to the given sizes.

    The base model comes from ``efficientnet_v2_l()`` called with no arguments
    (no pretrained weights are requested here). ``reshape_efficientnet`` then
    adjusts it to ``input_shape`` and ``num_classes``.
    """
    base_model = efficientnet_v2_l()
    return reshape_efficientnet(base_model, input_shape=input_shape, num_classes=num_classes)