import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class BasicBlock(nn.Module):
    """
    Two-convolution residual block, the building unit of ResNet-18/34.

    The main branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its result is
    added to a shortcut branch and the sum goes through a final ReLU.
    """
    # Output-channel multiplier read by ResNet when chaining blocks
    # (1 for this block; bottleneck-style blocks typically use 4).
    mul = 1

    def __init__(self, in_planes, planes, stride=1):
        """
        Args:
            in_planes: Number of input channels.
            planes: Number of output channels of both 3x3 convolutions.
            stride: Stride of the first convolution and, when a projection
                shortcut is built, of its 1x1 convolution.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # The shortcut is an identity (empty Sequential) unless the block
        # changes resolution or channel count; then a 1x1 conv + BN projects
        # the input so both branches can be summed.
        changes_shape = stride != 1 or in_planes != planes
        if changes_shape:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ReLU(main_branch(x) + shortcut(x))."""
        residual = self.shortcut(x)
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        main += residual
        return F.relu(main)


class ResNet(nn.Module):
    """
    Hand-built ResNet backbone (e.g. ResNet-18 when combined with BasicBlock).

    Layout: 7x7 stem convolution -> BN -> ReLU -> max-pool -> four residual
    stages -> global average pooling -> one linear classifier. In addition to
    the regular forward pass, the network can be evaluated in two halves
    (edge / cloud) split after the second residual stage.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        """
        Args:
            block: Residual block class. It is called as
                ``block(in_planes, planes, stride)`` and must expose a ``mul``
                channel multiplier (e.g. BasicBlock).
            num_blocks: Sequence with the block count of each of the four stages.
            num_classes: Output size of the final fully connected layer.
        """
        super().__init__()
        stem_width = 64
        # Running channel count; _make_layer reads and updates it.
        self.in_planes = stem_width

        self.input_conv = nn.Conv2d(3, stem_width, kernel_size=7, stride=2, padding=3)
        self.input_bn = nn.BatchNorm2d(stem_width)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (width, stride of the first block) for layer1 .. layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, num_blocks[idx], stride=first_stride)
            setattr(self, f"layer{idx + 1}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Linear(512 * block.mul, num_classes)

    def _make_layer(self, block, out_planes, num_blocks, stride):
        """
        Build one residual stage as an nn.Sequential of blocks.

        Only the first block uses ``stride``; the following ones use stride 1.
        ``self.in_planes`` is advanced after every block so the next block
        (and the next stage) sees the right input width.
        """
        stride_plan = [stride, *([1] * (num_blocks - 1))]
        stage_blocks = []
        for block_stride in stride_plan:
            stage_blocks.append(block(self.in_planes, out_planes, block_stride))
            self.in_planes = out_planes * block.mul
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """
        Full forward pass: the edge half followed by the cloud half.
        """
        return self.split_cloud_output(self.split_edge_output(x))

    def split_edge_output(self, x):
        """
        Run the stem plus layer1 and layer2 and return the feature map.
        """
        stem = F.relu(self.input_bn(self.input_conv(x)))
        pooled = self.maxpool(stem)
        return self.layer2(self.layer1(pooled))

    def split_cloud_output(self, x):
        """
        Continue from the output of split_edge_output: layer3, layer4,
        global average pooling, flatten and the final linear layer.
        """
        deep = self.layer4(self.layer3(x))
        pooled = torch.flatten(self.avgpool(deep), 1)
        return self.linear1(pooled)


class ResNet34Custom(nn.Module):
    """
    torchvision ResNet-34 with an optional dataset-specific head and
    edge/cloud split inference.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" or "CIFAR10" replaces the fc layer with a
                fresh 100- or 10-way linear layer; any other value keeps the
                model's original head.
        """
        super().__init__()
        weights = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.resnet34(weights=weights)

        head_sizes = {"CIFAR100": 100, "CIFAR10": 10}
        if data_name in head_sizes:
            backbone.fc = nn.Linear(backbone.fc.in_features, head_sizes[data_name])

        self.model = backbone
        # Plain Python lists over the backbone's top-level children, used for
        # index-based partial execution: everything but the last child, and
        # the last child (fc) on its own.
        children = list(self.model.children())
        self.layers = children[:-1]
        self.linear = children[-1:]

    def forward(self, x):
        """
        Run the complete wrapped model.
        """
        return self.model(x)

    def split_edge_output(self, x, target_layer):
        """
        Apply self.layers in order and stop right after index target_layer.
        If the index is never reached, the output of the last layer is returned.
        """
        out = x
        for idx, stage in enumerate(self.layers):
            out = stage(out)
            if idx == target_layer:
                break
        return out

    def split_cloud_output(self, x, target_layer):
        """
        Resume from the output of split_edge_output: run the layers after
        target_layer, flatten to (batch, -1) and apply the fc layer.
        """
        remaining = self.layers[target_layer + 1:]
        for stage in remaining:
            x = stage(x)
        x = x.view(x.shape[0], -1)
        for head in self.linear:
            x = head(x)
        return x


class ResNet50Custom(nn.Module):
    """
    torchvision ResNet-50 with an optional dataset-specific head and
    edge/cloud split inference.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" swap the fc layer for a new
                100- / 10-way linear layer; other values keep the default head.
        """
        super().__init__()
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        net = models.resnet50(weights=weights)

        for dataset, n_classes in (("CIFAR100", 100), ("CIFAR10", 10)):
            if data_name == dataset:
                net.fc = nn.Linear(net.fc.in_features, n_classes)
                break

        self.model = net
        top_level = list(self.model.children())
        self.layers = top_level[:-1]  # every child except the final fc
        self.linear = top_level[-1:]  # [fc]

    def forward(self, x):
        """Run the complete wrapped model."""
        return self.model(x)

    def split_edge_output(self, x, target_layer):
        """
        Apply self.layers in order, returning right after index target_layer
        (or after the last layer if that index is never reached).
        """
        feats = x
        for position, module in enumerate(self.layers):
            feats = module(feats)
            if position == target_layer:
                break
        return feats

    def split_umap(self, x, target_layer):
        """
        Run the layers after target_layer and return the result without
        flattening or applying fc. Every remaining entry of self.layers is
        applied, so the output has gone through whatever modules that list
        ends with.
        """
        tail = self.layers[target_layer + 1:]
        for module in tail:
            x = module(x)
        return x

    def split_cloud_output(self, x, target_layer):
        """
        Complete the forward pass from the output of split_edge_output:
        remaining layers, flatten to (batch, -1), then fc.
        """
        tail = self.layers[target_layer + 1:]
        for module in tail:
            x = module(x)
        x = x.view(x.shape[0], -1)
        for fc in self.linear:
            x = fc(x)
        return x


class ResNet101Custom(nn.Module):
    """
    torchvision ResNet-101 with an optional dataset-specific head and
    edge/cloud split inference.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" swap the fc layer for a new
                100- / 10-way linear layer; other values keep the default head.
        """
        super().__init__()
        chosen_weights = None
        if pretrained:
            chosen_weights = models.ResNet101_Weights.IMAGENET1K_V1
        resnet = models.resnet101(weights=chosen_weights)

        class_counts = {"CIFAR100": 100, "CIFAR10": 10}
        if data_name in class_counts:
            resnet.fc = nn.Linear(resnet.fc.in_features, class_counts[data_name])

        self.model = resnet
        modules = list(self.model.children())
        # [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool]
        self.layers = modules[:-1]
        # [fc]
        self.linear = modules[-1:]

    def forward(self, x):
        """Run the complete wrapped model."""
        return self.model(x)

    def split_edge_output(self, x, target_layer):
        """
        Apply self.layers in order and return right after index target_layer
        (e.g. target_layer=3 yields the max-pool output). If the index is
        never reached, the output of the last layer is returned.
        """
        activations = x
        for layer_idx, layer in enumerate(self.layers):
            activations = layer(activations)
            if layer_idx == target_layer:
                break
        return activations

    def split_cloud_output(self, x, target_layer):
        """
        Take the edge output, run layers target_layer+1 onward, flatten to
        (batch, -1) and apply the fc layer.
        """
        for layer in self.layers[target_layer + 1:]:
            x = layer(x)
        flat = x.view(x.shape[0], -1)
        for fc in self.linear:
            flat = fc(flat)
        return flat


class ResNet152Custom(nn.Module):
    """
    torchvision ResNet-152 with an optional dataset-specific head and
    edge/cloud split inference.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V2 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" swap the fc layer for a new
                100- / 10-way linear layer; other values keep the default head.
        """
        super().__init__()
        weights = models.ResNet152_Weights.IMAGENET1K_V2 if pretrained else None
        base = models.resnet152(weights=weights)

        n_outputs = None
        if data_name == "CIFAR100":
            n_outputs = 100
        elif data_name == "CIFAR10":
            n_outputs = 10
        if n_outputs is not None:
            base.fc = nn.Linear(base.fc.in_features, n_outputs)

        self.model = base
        parts = list(self.model.children())
        self.layers = parts[:-1]  # every child except the final fc
        self.linear = parts[-1:]  # [fc]

    def forward(self, x):
        """Run the complete wrapped model."""
        return self.model(x)

    def split_edge_output(self, x, target_layer):
        """
        Apply self.layers in order and return right after index target_layer
        (or after the last layer if that index is never reached).
        """
        result = x
        for n, part in enumerate(self.layers):
            result = part(result)
            if n == target_layer:
                break
        return result

    def split_cloud_output(self, x, target_layer):
        """
        Run layers target_layer+1 onward, flatten to (batch, -1), apply fc.
        """
        rest = self.layers[target_layer + 1:]
        for part in rest:
            x = part(x)
        x = x.view(x.shape[0], -1)
        for fc in self.linear:
            x = fc(x)
        return x

class MobileNetV2Custom(nn.Module):
    """
    torchvision MobileNetV2 with an optional dataset-specific classifier and
    edge/cloud split inference over the blocks of ``model.features``.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace the last classifier
                layer with a new 100- / 10-way linear layer; other values
                (e.g. "ImageNet") keep the default head.
        """
        super().__init__()
        # (1) Load the backbone.
        weights = models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        net = models.mobilenet_v2(weights=weights)

        # (2) Adapt the final classifier layer to the dataset. The last entry
        # of net.classifier is the nn.Linear that produces the logits.
        class_counts = {"CIFAR100": 100, "CIFAR10": 10}
        if data_name in class_counts:
            old_head = net.classifier[-1]
            net.classifier[-1] = nn.Linear(old_head.in_features, class_counts[data_name])

        # (3) Keep the whole model plus direct handles on its two parts.
        self.model = net
        self.features = self.model.features
        self.classifier = self.model.classifier

        # Unpack the feature extractor into a list so every block can be
        # addressed by index when splitting.
        self.layers = list(self.features.children())

    def forward(self, x):
        """Run the complete wrapped model."""
        return self.model(x)

    def split_edge_output(self, x, target_layer):
        """
        Apply the feature blocks in order and return right after index
        target_layer. If the index is never reached, the output of the last
        block is returned.
        """
        out = x
        for idx, block in enumerate(self.layers):
            out = block(out)
            if idx == target_layer:
                break
        return out

    def split_cloud_output(self, x, target_layer):
        """
        Resume from the edge output: run the remaining feature blocks, global
        average pool to 1x1, flatten and apply the classifier.
        """
        for block in self.layers[target_layer + 1:]:
            x = block(x)
        pooled = torch.flatten(F.adaptive_avg_pool2d(x, (1, 1)), 1)
        return self.classifier(pooled)


def MobNetV2(pretrained=True, data_name="ImageNet"):
    """
    Factory for MobileNetV2Custom; both arguments are forwarded unchanged.
    """
    model = MobileNetV2Custom(pretrained=pretrained, data_name=data_name)
    return model


# Factory functions for constructing models.
def Res18(num_classes=10):
    """
    Factory for the hand-built ResNet-18: BasicBlock with two blocks in each
    of the four stages.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)


def Res34(pretrained=True, data_name="ImageNet"):
    """
    Factory for ResNet34Custom; both arguments are forwarded unchanged.
    """
    model = ResNet34Custom(pretrained=pretrained, data_name=data_name)
    return model


def Res50(pretrained=True, data_name="ImageNet"):
    """
    Factory for ResNet50Custom; both arguments are forwarded unchanged.
    """
    model = ResNet50Custom(pretrained=pretrained, data_name=data_name)
    return model


def Res101(pretrained=True, data_name="ImageNet"):
    """
    Factory for ResNet101Custom; both arguments are forwarded unchanged.
    """
    model = ResNet101Custom(pretrained=pretrained, data_name=data_name)
    return model


def Res152(pretrained=True, data_name="ImageNet"):
    """
    Factory for ResNet152Custom; both arguments are forwarded unchanged.
    """
    model = ResNet152Custom(pretrained=pretrained, data_name=data_name)
    return model



import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class ViTB16Custom(nn.Module):
    """
    A custom wrapper around torchvision's ViT-B/16 model.
    Provides:
      - Pretrained loading
      - Dataset-specific classifier adjustment
      - Split-output methods (similar to the ResNet wrappers)

    The split methods cannot simply chain the model's top-level children:
    between ``conv_proj`` and ``encoder`` the ViT forward pass flattens the
    patch grid into a token sequence and prepends the class token, and
    between ``encoder`` and ``heads`` it keeps only the class-token row.
    Those steps are reproduced here so that
    ``split_cloud_output(split_edge_output(x, t), t)`` matches ``forward(x)``.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace ``model.heads.head``
                with a new 100- / 10-way linear layer; other values keep the
                default head.
        """
        super(ViTB16Custom, self).__init__()
        # 1) Load a ViT-B/16 model from torchvision.
        if pretrained:
            model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
        else:
            model = models.vit_b_16(weights=None)  # random init

        # 2) Adjust the final classifier (model.heads.head) for the dataset.
        if data_name == "CIFAR100":
            in_features = model.heads.head.in_features
            model.heads.head = nn.Linear(in_features, 100)
        elif data_name == "CIFAR10":
            in_features = model.heads.head.in_features
            model.heads.head = nn.Linear(in_features, 10)
        # else: keep the default head.

        self.model = model

        # 3) Kept for backward compatibility with code that inspects these
        #    lists. The split methods below address the sub-modules by name.
        self.layers = list(self.model.children())[:-1]  # all children but the last
        self.classifier = list(self.model.children())[-1:]  # last child (heads)

    def forward(self, x):
        """
        Standard forward pass through the entire ViT-B/16.
        """
        return self.model(x)

    def _patches_to_tokens(self, patches):
        """Turn a conv_proj output (N, C, H', W') into encoder input (N, 1 + H'*W', C)."""
        batch = patches.shape[0]
        tokens = patches.flatten(2).transpose(1, 2)
        class_token = self.model.class_token.expand(batch, -1, -1)
        return torch.cat([class_token, tokens], dim=1)

    def split_edge_output(self, x, target_layer):
        """
        Run the first part of the model on an image batch.

        Args:
            x: Input image batch.
            target_layer: 0 returns the raw patch embedding produced by
                ``conv_proj``; any other value also builds the token sequence
                and runs the encoder, returning the encoded tokens.
        """
        x = self.model.conv_proj(x)
        if target_layer == 0:
            return x
        return self.model.encoder(self._patches_to_tokens(x))

    def split_cloud_output(self, x, target_layer):
        """
        Finish the forward pass from the output of split_edge_output.

        Args:
            x: Tensor returned by split_edge_output for the same target_layer.
            target_layer: 0 means ``x`` is the conv_proj output, so the
                encoder still has to run; otherwise ``x`` is treated as the
                encoder output.

        Returns:
            Class logits computed from the class-token row.
        """
        if target_layer == 0:
            x = self.model.encoder(self._patches_to_tokens(x))
        # Only the class token (sequence position 0) is fed to the heads.
        x = x[:, 0]
        return self.model.heads(x)


def ViTB16(pretrained=True, data_name="ImageNet"):
    """
    Factory for ViTB16Custom; both arguments are forwarded unchanged.
    """
    model = ViTB16Custom(pretrained=pretrained, data_name=data_name)
    return model


class SwinTCustom(nn.Module):
    """
    torchvision Swin Transformer (Tiny) with an optional dataset-specific
    head and edge/cloud split inference over the model's top-level children.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace ``model.head`` with a
                new 100- / 10-way linear layer; other values keep the default.
        """
        super().__init__()
        weights = models.Swin_T_Weights.IMAGENET1K_V1 if pretrained else None
        swin = models.swin_t(weights=weights)

        # Swap the final linear layer for the requested dataset.
        class_counts = {"CIFAR100": 100, "CIFAR10": 10}
        if data_name in class_counts:
            swin.head = nn.Linear(swin.head.in_features, class_counts[data_name])

        self.model = swin
        children = list(self.model.children())
        # Every top-level child except the last one, in registration order.
        # NOTE(review): chaining these reproduces forward() only if the
        # installed torchvision registers each forward step (norm, permute,
        # pooling, flatten) as its own child module - confirm for your version.
        self.layers = children[:-1]
        self.classifier = children[-1:]  # last child (head)

    def forward(self, x):
        """Run the complete wrapped model."""
        return self.model(x)

    def split_edge_output(self, x, target_layer):
        """
        Apply self.layers in order and return right after index target_layer
        (or after the last layer if that index is never reached).

        The transformer stages all live inside the first child, so splitting
        between stages would require iterating over ``model.features`` instead.
        """
        out = x
        for idx, module in enumerate(self.layers):
            out = module(out)
            if idx == target_layer:
                break
        return out

    def split_cloud_output(self, x, target_layer):
        """
        Run the layers after target_layer, then the head, and return logits.
        """
        remaining = self.layers[target_layer + 1:]
        for module in remaining:
            x = module(x)
        head = self.classifier[0]  # model.head
        return head(x)


def SwinT(pretrained=True, data_name="ImageNet"):
    """Factory for SwinTCustom; both arguments are forwarded unchanged."""
    model = SwinTCustom(pretrained=pretrained, data_name=data_name)
    return model



class VGG16Custom(nn.Module):
    """
    torchvision VGG16 re-exposed as three sub-modules (features, avgpool,
    classifier) so the convolutional part can be split between edge and cloud.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace ``classifier[6]`` (the
                last linear layer) with a new 100- / 10-way layer; other
                values keep the default head.
        """
        super().__init__()
        weights = models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        vgg = models.vgg16(weights=weights)

        class_counts = {"CIFAR100": 100, "CIFAR10": 10}
        if data_name in class_counts:
            last_fc = vgg.classifier[6]
            vgg.classifier[6] = nn.Linear(last_fc.in_features, class_counts[data_name])

        # Registered directly on this wrapper; the vgg object itself is not kept.
        self.features = vgg.features      # convolutional stack
        self.avgpool = vgg.avgpool        # pooling between features and classifier
        self.classifier = vgg.classifier  # fully connected stack

    def forward(self, x):
        """Full forward pass: features -> avgpool -> flatten -> classifier."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

    def split_edge_output(self, x, target_layer):
        """
        Apply the entries of ``features`` in order and return right after
        index target_layer. If the index is never reached, the output of the
        whole feature stack is returned (no pooling or classification).
        """
        out = x
        for idx, layer in enumerate(self.features):
            out = layer(out)
            if idx == target_layer:
                break
        return out

    def split_cloud_output(self, x, target_layer):
        """
        Resume from the edge output: apply the feature layers with index
        greater than target_layer, then avgpool, flatten and the classifier.
        """
        for idx, layer in enumerate(self.features):
            if idx > target_layer:
                x = layer(x)
        pooled = self.avgpool(x)
        return self.classifier(torch.flatten(pooled, 1))

def VGG16(pretrained=True, data_name="ImageNet"):
    """Factory for VGG16Custom; both arguments are forwarded unchanged."""
    model = VGG16Custom(pretrained=pretrained, data_name=data_name)
    return model

###############################################################################
# 2) DenseNet121
###############################################################################
class DenseNet121Custom(nn.Module):
    """
    A wrapper around torchvision.models.densenet121 with an optional
    dataset-specific classifier and edge/cloud split inference.

    ``features`` is the model's nn.Sequential feature extractor and
    ``classifier`` its final linear layer. Between the two, the forward pass
    applies ReLU, global average pooling and flatten.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace the classifier with a
                new 100- / 10-way linear layer; other values keep the default.
        """
        super(DenseNet121Custom, self).__init__()
        if pretrained:
            model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        else:
            model = models.densenet121(weights=None)

        # Adjust final classifier
        if data_name == "CIFAR100":
            in_features = model.classifier.in_features
            model.classifier = nn.Linear(in_features, 100)
        elif data_name == "CIFAR10":
            in_features = model.classifier.in_features
            model.classifier = nn.Linear(in_features, 10)

        self.model = model
        # Direct handles on the two parts used by the split methods.
        self.features = model.features
        self.classifier = model.classifier

    def _pool_and_classify(self, x):
        """Apply ReLU, global average pooling, flatten and the classifier."""
        # Out-of-place ReLU: x may be a tensor owned by the caller (the edge
        # output passed to split_cloud_output) and must not be overwritten.
        x = F.relu(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def forward(self, x):
        """Full forward pass: features followed by the pooling/classifier head."""
        return self._pool_and_classify(self.features(x))

    def split_edge_output(self, x, target_layer):
        """
        Apply the entries of ``features`` in order and return right after
        index target_layer (or after the last entry if that index is never
        reached).
        """
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i == target_layer:
                return x
        return x

    def split_cloud_output(self, x, target_layer):
        """
        Resume from the edge output: apply the feature layers with index
        greater than target_layer, then the pooling/classifier head.
        The input tensor is left unmodified.
        """
        for i, layer in enumerate(self.features):
            if i <= target_layer:
                continue
            x = layer(x)
        return self._pool_and_classify(x)

def DenseNet121(pretrained=True, data_name="ImageNet"):
    """Factory for DenseNet121Custom; both arguments are forwarded unchanged."""
    model = DenseNet121Custom(pretrained=pretrained, data_name=data_name)
    return model

###############################################################################
# 3) Inception v3
###############################################################################
class InceptionV3Custom(nn.Module):
    """
    A wrapper for torchvision.models.inception_v3 with an optional
    dataset-specific classifier and edge/cloud split inference.

    The auxiliary classifier is always disabled, so forward() returns plain
    logits and the auxiliary branch never appears among the split layers.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace ``model.fc`` with a new
                100- / 10-way linear layer; other values keep the default.
        """
        super(InceptionV3Custom, self).__init__()
        if pretrained:
            # The pretrained checkpoint includes the auxiliary branch, and the
            # torchvision builder insists on aux_logits=True when weights are
            # given. So load with the default and drop the branch afterwards.
            model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)
            model.AuxLogits = None
        else:
            model = models.inception_v3(weights=None, aux_logits=False)

        # Adjust classifier
        if data_name == "CIFAR100":
            in_features = model.fc.in_features
            model.fc = nn.Linear(in_features, 100)
        elif data_name == "CIFAR10":
            in_features = model.fc.in_features
            model.fc = nn.Linear(in_features, 10)

        # Make sure forward() never tries to use the auxiliary branch.
        model.aux_logits = False

        self.model = model
        # For partial forward, store each top-level child:
        self.layers = list(model.children())[:-1]  # everything except the final fc
        self.linear = list(model.children())[-1:]  # the final fc

    def forward(self, x):
        """Run the complete wrapped model."""
        return self.model(x)

    def split_edge_output(self, x, target_layer):
        """
        Apply self.layers in order, starting from a raw input batch, and
        return right after index target_layer (or after the last layer if
        that index is never reached).
        """
        # Apply the model's own input transform first, when it has one, so the
        # split path sees the same input as forward().
        # NOTE(review): _transform_input is a private torchvision hook - confirm
        # it still exists when upgrading torchvision.
        transform = getattr(self.model, "_transform_input", None)
        if transform is not None:
            x = transform(x)
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == target_layer:
                return x
        return x

    def split_cloud_output(self, x, target_layer):
        """
        Resume from the edge output: apply the layers with index greater than
        target_layer, flatten to (batch, features) and apply the final fc.
        """
        for i, layer in enumerate(self.layers):
            if i <= target_layer:
                continue
            x = layer(x)
        # The pooled map is (N, C, 1, 1); fc expects (N, C).
        x = torch.flatten(x, 1)
        x = self.linear[0](x)
        return x

def InceptionV3(pretrained=True, data_name="ImageNet"):
    """Factory for InceptionV3Custom; both arguments are forwarded unchanged."""
    model = InceptionV3Custom(pretrained=pretrained, data_name=data_name)
    return model

###############################################################################
# 4) ShuffleNet v2 (x1.0)
###############################################################################
class ShuffleNetV2Custom(nn.Module):
    """
    torchvision ShuffleNet v2 (x1.0) with an optional dataset-specific fc
    layer and edge/cloud split inference over the top-level children.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace ``model.fc`` with a new
                100- / 10-way linear layer; other values keep the default.
        """
        super().__init__()
        weights = models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1 if pretrained else None
        net = models.shufflenet_v2_x1_0(weights=weights)

        class_counts = {"CIFAR100": 100, "CIFAR10": 10}
        if data_name in class_counts:
            net.fc = nn.Linear(net.fc.in_features, class_counts[data_name])

        self.model = net
        children = list(self.model.children())
        self.layers = children[:-1]      # every top-level child except fc
        self.classifier = children[-1:]  # [fc]

    def forward(self, x):
        """Run the complete wrapped model."""
        return self.model(x)

    def split_edge_output(self, x, target_layer):
        """
        Apply self.layers in order and return right after index target_layer
        (or after the last layer if that index is never reached).
        """
        out = x
        for idx, module in enumerate(self.layers):
            out = module(out)
            if idx == target_layer:
                break
        return out

    def split_cloud_output(self, x, target_layer):
        """
        Resume from the edge output: apply the layers with index greater than
        target_layer, average over the two spatial dimensions and apply fc.
        """
        for idx, module in enumerate(self.layers):
            if idx > target_layer:
                x = module(x)
        pooled = x.mean([2, 3])  # global pooling over H and W
        for fc in self.classifier:
            pooled = fc(pooled)
        return pooled

def ShuffleNetV2(pretrained=True, data_name="ImageNet"):
    """Factory for ShuffleNetV2Custom; both arguments are forwarded unchanged."""
    model = ShuffleNetV2Custom(pretrained=pretrained, data_name=data_name)
    return model

###############################################################################
# 5) SqueezeNet 1.1
###############################################################################
class SqueezeNetCustom(nn.Module):
    """
    A wrapper for torchvision.models.squeezenet1_1 with an optional
    dataset-specific classifier and edge/cloud split inference.

    SqueezeNet classifies with a convolution: ``model.classifier`` is a
    Sequential that contains the final 1x1 convolution together with its
    activation and the global pooling. The wrapper therefore only has to
    flatten the classifier output; adding another ReLU or pooling in front of
    the classifier would change the logits.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace ``classifier[1]`` (the
                final 1x1 convolution) with a new 100- / 10-channel one;
                other values keep the default.
        """
        super(SqueezeNetCustom, self).__init__()
        if pretrained:
            model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1)
        else:
            model = models.squeezenet1_1(weights=None)

        # Replace the final classification convolution for custom datasets.
        if data_name == "CIFAR100":
            model.classifier[1] = nn.Conv2d(512, 100, kernel_size=(1,1), stride=(1,1))
        elif data_name == "CIFAR10":
            model.classifier[1] = nn.Conv2d(512, 10, kernel_size=(1,1), stride=(1,1))

        self.model = model
        # Direct handles on the two parts used by the split methods.
        self.features = self.model.features
        self.classifier = self.model.classifier

    def forward(self, x):
        """Full forward pass: features -> classifier -> flatten."""
        x = self.features(x)
        x = self.classifier(x)
        return torch.flatten(x, 1)

    def split_edge_output(self, x, target_layer):
        """
        Apply the entries of ``features`` in order and return right after
        index target_layer (or after the last entry if that index is never
        reached).
        """
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i == target_layer:
                return x
        return x

    def split_cloud_output(self, x, target_layer):
        """
        Resume from the edge output: apply the feature layers with index
        greater than target_layer, then the classifier, and flatten the
        result to (batch, classes).
        """
        for i, layer in enumerate(self.features):
            if i <= target_layer:
                continue
            x = layer(x)
        x = self.classifier(x)
        return torch.flatten(x, 1)

def SqueezeNet(pretrained=True, data_name="ImageNet"):
    """Factory for SqueezeNetCustom; both arguments are forwarded unchanged."""
    model = SqueezeNetCustom(pretrained=pretrained, data_name=data_name)
    return model

###############################################################################
# 6) EfficientNet-B0
###############################################################################
class EfficientNetB0Custom(nn.Module):
    """
    torchvision EfficientNet-B0 with an optional dataset-specific classifier
    and edge/cloud split inference over the entries of ``model.features``.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V1 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace the last classifier
                layer with a new 100- / 10-way linear layer; other values
                keep the default head.
        """
        super().__init__()
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        net = models.efficientnet_b0(weights=weights)

        # Swap the final linear layer for the requested dataset.
        class_counts = {"CIFAR100": 100, "CIFAR10": 10}
        if data_name in class_counts:
            old_head = net.classifier[-1]
            net.classifier[-1] = nn.Linear(old_head.in_features, class_counts[data_name])

        self.model = net
        # Direct handles on the three parts used by forward and the splits.
        self.features = self.model.features
        self.avgpool = self.model.avgpool
        self.classifier = self.model.classifier

    def forward(self, x):
        """Full forward pass: features -> avgpool -> flatten -> classifier."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

    def split_edge_output(self, x, target_layer):
        """
        Apply the entries of ``features`` in order and return right after
        index target_layer (or after the last entry if that index is never
        reached).
        """
        out = x
        for idx, stage in enumerate(self.features):
            out = stage(out)
            if idx == target_layer:
                break
        return out

    def split_cloud_output(self, x, target_layer):
        """
        Resume from the edge output: apply the feature stages with index
        greater than target_layer, then avgpool, flatten and the classifier.
        """
        for idx, stage in enumerate(self.features):
            if idx > target_layer:
                x = stage(x)
        pooled = self.avgpool(x)
        return self.classifier(torch.flatten(pooled, 1))

def EfficientNetB0(pretrained=True, data_name="ImageNet"):
    """Factory for EfficientNetB0Custom; both arguments are forwarded unchanged."""
    model = EfficientNetB0Custom(pretrained=pretrained, data_name=data_name)
    return model

###############################################################################
# 7) RegNet (e.g. RegNet_X_8GF)
###############################################################################
class RegNetX8GCustom(nn.Module):
    """
    torchvision RegNet_X_8GF with an optional dataset-specific fc layer and
    edge/cloud split inference over the model's top-level children.
    """

    def __init__(self, pretrained=True, data_name="ImageNet"):
        """
        Args:
            pretrained: Load the IMAGENET1K_V2 weights when True, otherwise
                use random initialisation.
            data_name: "CIFAR100" / "CIFAR10" replace ``model.fc`` with a new
                100- / 10-way linear layer; other values keep the default.
        """
        super().__init__()
        weights = models.RegNet_X_8GF_Weights.IMAGENET1K_V2 if pretrained else None
        regnet = models.regnet_x_8gf(weights=weights)

        class_counts = {"CIFAR100": 100, "CIFAR10": 10}
        if data_name in class_counts:
            regnet.fc = nn.Linear(regnet.fc.in_features, class_counts[data_name])

        self.model = regnet
        children = list(self.model.children())
        # Every top-level child except the last one, in registration order.
        # NOTE(review): split_cloud_output flattens straight into fc, so this
        # list is expected to end with the model's pooling module - confirm
        # against the installed torchvision.
        self.layers = children[:-1]
        self.linear = children[-1:]  # [fc]

    def forward(self, x):
        """Run the complete wrapped model."""
        return self.model(x)

    def split_edge_output(self, x, target_layer):
        """
        Apply self.layers in order and return right after index target_layer
        (or after the last layer if that index is never reached).
        """
        out = x
        for idx, module in enumerate(self.layers):
            out = module(out)
            if idx == target_layer:
                break
        return out

    def split_cloud_output(self, x, target_layer):
        """
        Resume from the edge output: apply the layers with index greater than
        target_layer, flatten to (batch, features) and apply fc.
        """
        for idx, module in enumerate(self.layers):
            if idx > target_layer:
                x = module(x)
        flat = torch.flatten(x, 1)
        fc = self.linear[0]
        return fc(flat)

def RegNetX8G(pretrained=True, data_name="ImageNet"):
    """Factory for RegNetX8GCustom; both arguments are forwarded unchanged."""
    model = RegNetX8GCustom(pretrained=pretrained, data_name=data_name)
    return model