"""ResNet-18/34 models for fine-tuning, built on pretrained torchvision backbones (PyTorch)."""
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models


def ResNet18(num_classes=10):
    model = models.resnet18(pretrained=True)
    # Freeze all the pre-trained layers except the last fc layer and the last conv layer
    for param in model.named_parameters():
        if param[0].startswith("layer4"):
            continue
        param[1].requires_grad = False
    # change the last fc layer
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # make fc trainable
    for param in model.fc.parameters():
        param.requires_grad = True
    return model

def ResNet34(num_classes=10):
    model = models.resnet34(pretrained=True)
    # Freeze all the pre-trained layers except the last fc layer and the last conv layer
    for param in model.named_parameters():
        if param[0].startswith("layer4"):
            continue
        param[1].requires_grad = False
    # change the last fc layer
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # make fc trainable
    for param in model.fc.parameters():
        param.requires_grad = True
    return model

def test_resnet():
    """Smoke-test ResNet18 with one forward pass on a random 1x3x224x224 input.

    Prints the output size and checks it is ``(1, 10)``. ``ResNet18()``
    defaults to 10 classes.
    """
    net = ResNet18()
    # Inference mode: BatchNorm uses its running statistics instead of
    # per-batch statistics.
    net.eval()
    # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4, so
    # a plain tensor is passed. no_grad() avoids building an autograd graph
    # for this check.
    with torch.no_grad():
        y = net(torch.randn(1, 3, 224, 224))
    print(y.size())
    assert tuple(y.shape) == (1, 10)


def _main():
    """Script entry point: run the ResNet18 forward-pass smoke test."""
    test_resnet()


if __name__ == "__main__":
    _main()