import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class Identity(nn.Module):
    """Pass-through module that returns its input unchanged.

    Any positional or keyword constructor arguments are accepted and
    ignored, so an instance can be dropped in wherever an existing layer
    needs to be disabled (``torch.nn.Identity`` behaves the same way).
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input):
        # No computation at all: the very same object is handed back.
        output = input
        return output


class ResNet18Backbone(nn.Module):
    """ResNet-18 feature extractor followed by a 1-D batch norm.

    The torchvision ResNet-18 has both its ``fc`` and ``avgpool`` layers
    replaced with ``Identity``. As a result, ``self.net`` yields the
    (flattened, per torchvision's ``ResNet.forward``) final feature map
    rather than class scores. That vector is then normalised by ``self.bn``.

    Args:
        pretrained: Forwarded unchanged to ``torchvision.models.resnet18``.
        cifar: Selects ``output_size``: 512 when True, 2048 otherwise.

    Attributes:
        net: The modified ResNet-18.
        output_size: Feature count expected by ``bn``.
        bn: ``nn.BatchNorm1d(output_size)`` applied to the features.
    """

    def __init__(self, pretrained=False, cifar=True):
        super().__init__()
        # NOTE(review): newer torchvision deprecates ``pretrained=`` in favour
        # of ``weights=``; kept as-is for compatibility with older installs.
        backbone = models.resnet18(pretrained=pretrained)
        backbone.fc = Identity()
        backbone.avgpool = Identity()
        self.net = backbone

        # NOTE(review): with avgpool removed, the flattened feature count
        # depends on the input resolution (512 channels x final H x W).
        # 512 matches a 1x1 final map and 2048 a 2x2 one. Confirm this
        # against the image size actually fed in.
        if cifar:
            self.output_size = 512
        else:
            self.output_size = 2048
        self.bn = nn.BatchNorm1d(self.output_size)

    def forward(self, x):
        features = self.net(x)
        return self.bn(features)


class ResNet18Classification(nn.Module):
    """Binary classification head for features from ``ResNet18Backbone``.

    A single linear layer maps each feature vector to one logit. The logit
    is then passed through a sigmoid to give a probability.

    Args:
        cifar: Selects the expected feature size: 512 when True, 2048
            otherwise (the same rule ``ResNet18Backbone`` uses for
            ``output_size``).

    Attributes:
        net: The ``nn.Linear(input_size, 1)`` layer.
        input_size: Number of input features the head expects.
    """

    def __init__(self, cifar=True):
        super(ResNet18Classification, self).__init__()
        size = 512 if cifar else 2048
        self.net = nn.Linear(size, 1)
        self.input_size = size

    def forward(self, x):
        """Return per-sample probabilities.

        Args:
            x: Features whose last dimension is ``input_size``.

        Returns:
            Tensor of sigmoid probabilities in [0, 1] whose last
            dimension is 1.
        """
        logits = self.net(x)
        # ``F.sigmoid`` is deprecated (some releases warn on every call);
        # ``Tensor.sigmoid`` is the same op and needs no extra import.
        return logits.sigmoid()

