'''Modified from https://github.com/alinlab/LfF and https://github.com/kakaoenterprise/Learning-Debiased-Disentangled'''

import torch.nn as nn
from torchvision.models import resnet50, resnet18

def get_model(model_tag, num_classes):
    """Build a randomly initialised torchvision ResNet with a new linear head.

    Args:
        model_tag: One of ``"ResNet18"``, ``"ResNet50"``,
            ``"resnet18_DISENTANGLE"`` or ``"resnet50_DISENTANGLE"``.
        num_classes: Number of output units of the replacement ``fc`` layer.

    Returns:
        The ResNet with ``model.fc`` replaced by ``nn.Linear(..., num_classes)``.
        The plain ``ResNet*`` variants are moved to the GPU via ``.cuda()``.
        The ``*_DISENTANGLE`` variants are returned without calling
        ``.cuda()``; presumably the caller handles device placement.

    Raises:
        NotImplementedError: If ``model_tag`` is not one of the tags above.
    """
    # tag -> (torchvision architecture name, uses the disentangle head?)
    specs = {
        "ResNet18": ("resnet18", False),
        "ResNet50": ("resnet50", False),
        "resnet18_DISENTANGLE": ("resnet18", True),
        "resnet50_DISENTANGLE": ("resnet50", True),
    }
    try:
        arch, disentangle = specs[model_tag]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported model_tag {model_tag!r}; expected one of {sorted(specs)}"
        ) from None
    # Resolve the constructor only after the tag has been validated.
    builder = {"resnet18": resnet18, "resnet50": resnet50}[arch]

    suffix = " disentangle..." if disentangle else " ..."
    print(f"bringing no pretrained {arch}{suffix}")
    # NOTE: `pretrained=` is deprecated in torchvision >= 0.13 in favour of
    # `weights=None`; it is kept here for compatibility with older releases.
    model = builder(pretrained=False)

    # Read the backbone's feature width (512 for resnet18, 2048 for resnet50)
    # from the existing head instead of hard-coding it.
    in_features = model.fc.in_features
    if disentangle:
        # The disentangle head takes twice the backbone width (1024 / 4096).
        # Presumably the caller concatenates two feature vectors before
        # applying `model.fc` -- TODO confirm against the training code.
        model.fc = nn.Linear(2 * in_features, num_classes)
        return model
    # Previously the ResNet50 branch hard-coded 2 outputs here; honour
    # `num_classes` for every architecture.
    model.fc = nn.Linear(in_features, num_classes)
    return model.cuda()
