import torch.nn as nn
import torchvision.models as models


def resnet50(config, num_classes, **kwargs):
    """Build a pretrained ResNet-50 with a frozen backbone and a new trainable head.

    Every parameter of the pretrained network is frozen; the final ``fc``
    layer is replaced by a freshly initialised ``nn.Linear`` that is the only
    trainable part of the model.

    Args:
        config: Unused here; kept so the signature matches ``resnet18``.
        num_classes: Number of output features of the replacement ``fc`` layer.
        **kwargs: Unused; accepted for interface compatibility.

    Returns:
        The torchvision ResNet-50 module with its ``fc`` replaced.
    """
    # NOTE: ``pretrained=`` is deprecated in torchvision>=0.13 in favour of
    # ``weights=``; it is kept here for compatibility with older releases.
    model = models.resnet50(pretrained=True)

    # Freeze the entire pretrained network; only the new head will train.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the model instead of hard-coding 2048.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # ``model.fc.requires_grad = True`` would only set a plain attribute on the
    # Module; ``requires_grad_`` sets the flag on each of the head's parameters.
    model.fc.requires_grad_(True)

    return model


def resnet18(config, num_classes, **kwargs):
    """Build a pretrained ResNet-18 with a frozen backbone and a new trainable head.

    Every parameter of the pretrained network is frozen; the final ``fc``
    layer is replaced by a freshly initialised ``nn.Linear`` that is the only
    trainable part of the model.

    Args:
        config: Unused here; kept so the signature matches ``resnet50``.
        num_classes: Number of output features of the replacement ``fc`` layer.
        **kwargs: Unused; accepted for interface compatibility.

    Returns:
        The torchvision ResNet-18 module with its ``fc`` replaced.
    """
    # NOTE: ``pretrained=`` is deprecated in torchvision>=0.13 in favour of
    # ``weights=``; it is kept here for compatibility with older releases.
    model = models.resnet18(pretrained=True)

    # Freeze the entire pretrained network; only the new head will train.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the model instead of hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # ``model.fc.requires_grad = True`` would only set a plain attribute on the
    # Module; ``requires_grad_`` sets the flag on each of the head's parameters.
    model.fc.requires_grad_(True)

    return model

