from libs import *
import torchvision.models as models
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models import (
    mobilenet_v2,
    mobilenet_v3_large,
    MobileNet_V2_Weights,
    MobileNet_V3_Large_Weights,
)
from torchvision.models import shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights
from torchvision.models import vgg16, vgg11, VGG16_Weights, VGG11_Weights
from torchvision.models.resnet import BasicBlock


class client_model(nn.Module):
    """Client-side classifier built from an ImageNet-pretrained torchvision backbone.

    Supported ``name`` values:

    * ``"cifar10_resnet18"``: ResNet-18 with a new 10-way ``fc`` head. Only
      that head is trainable.
    * ``"cifar100_VGG16"``: VGG-16 with a new 100-way ``classifier[6]`` head.
      Only that layer is trainable.
    * ``"tinyimagenet_MobileNet"``: MobileNetV3-Large with a new 200-way final
      classifier layer. The whole ``classifier`` sub-module is trainable.

    Args:
        name: Which backbone/dataset configuration to build (see above).
        args: Unused; kept only so existing call sites keep working.

    Raises:
        ValueError: If ``name`` is not one of the supported configurations.
            Previously an unknown name silently produced an identity module.
    """

    def __init__(self, name, args=True):
        super().__init__()
        self.name = name

        if name == "cifar10_resnet18":
            self.model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
            # Replace the classification head with a 10-class one.
            self.model.fc = nn.Linear(self.model.fc.in_features, 10)
            trainable = self.model.fc
        elif name == "cifar100_VGG16":
            self.model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
            # Replace the classification head with a 100-class one.
            self.model.classifier[6] = nn.Linear(
                self.model.classifier[6].in_features, 100
            )
            trainable = self.model.classifier[6]
        elif name == "tinyimagenet_MobileNet":
            self.model = mobilenet_v3_large(
                weights=MobileNet_V3_Large_Weights.IMAGENET1K_V2
            )
            # Replace the last classifier layer with a 200-class one.
            self.model.classifier[-1] = nn.Linear(
                self.model.classifier[-1].in_features, 200
            )
            # Unlike the other configurations, the entire classifier sub-module
            # is fine-tuned here, not just its final Linear layer.
            # NOTE: to also fine-tune the last feature blocks, additionally
            # unfreeze ``self.model.features[-3:]``.
            trainable = self.model.classifier
        else:
            raise ValueError(f"Unknown client model name: {name!r}")

        self._freeze_all_but(trainable)

    def _freeze_all_but(self, trainable):
        """Freeze every parameter of ``self.model`` except those of ``trainable``."""
        for param in self.model.parameters():
            param.requires_grad = False
        for param in trainable.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its logits."""
        return self.model(x)
