import torch.nn as nn
from torchvision.models import resnet18
import torchvision.models as models
import torch
import torch.nn.functional as F
from models.resnet import ResNet18
from models.lenet import LeNet
from models.cnn import CNN
from models.vgg import VGG32, VGG224

def get_model(args, num_classes):
    match args.model:
        case 'ResNet18':
            model = ResNet18(num_classes)
            return model

        case 'LeNet':
            model = LeNet(num_classes)
            return model

        case 'CNN':
            model = CNN(num_classes)
            return model

        case 'VGG':
            if args.dataset_list == 'DIGIT10' or args.dataset_list == 'CIFAR10' or args.dataset_list == 'CIFAR100':
                model = VGG32(num_classes)
            elif args.dataset_list == 'OFFICE31' or args.dataset_list == 'OFFICE_CALTECH_10':
                model = VGG224(num_classes)
            return model
