import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18


class my_backbone(nn.Module):
    """ResNet-18 feature extractor preceded by a learned channel adapter.

    A 3x3 convolution first maps the input to 3 channels; the result is fed
    through every child module of a torchvision ResNet-18 except the last one
    (the ``fc`` classification layer), and the output is flattened to one
    feature vector per sample.

    Args:
        pretrained: If true (the default, matching the previous behaviour),
            build the ResNet-18 with its ImageNet weights. Pass ``False`` to
            get a randomly initialised trunk.
        in_channels: Number of channels of the input images (default 1).

    Attributes:
        output_dim: Length of each feature vector returned by ``forward``,
            read from ``in_features`` of the discarded ``fc`` layer.
    """

    def __init__(self, *, pretrained: bool = True, in_channels: int = 1):
        super().__init__()
        try:
            # torchvision >= 0.13 selects weights through ``weights=``; the
            # ``pretrained=`` keyword is deprecated there and emits a warning.
            from torchvision.models.resnet import ResNet18_Weights
        except ImportError:
            # Older torchvision only understands the legacy keyword.
            model = resnet18(pretrained=pretrained)
        else:
            # IMAGENET1K_V1 is the weight set ``pretrained=True`` resolves to.
            weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            model = resnet18(weights=weights)
        self.output_dim = model.fc.in_features
        # Keep everything but the final child so the module emits features
        # instead of class logits.
        self.backbone = nn.Sequential(*list(model.children())[:-1])
        # Adapter: same spatial size (3x3, stride 1, padding 1), 3 channels out.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=3,
                               kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one flat feature vector per sample of ``x``.

        Args:
            x: Batch of images; the convolution requires a 4-D tensor whose
                second dimension equals ``in_channels``.

        Returns:
            A 2-D tensor whose first dimension is the batch size.
        """
        x = self.conv1(x)
        # flatten(start_dim=1) equals reshape(batch, -1) for non-empty
        # batches and, unlike reshape with -1, also works for an empty batch.
        return torch.flatten(self.backbone(x), 1)


if __name__ == '__main__':
    # Smoke test: push a random batch of ten 1-channel 28x28 images through
    # the network and print the feature shape next to the advertised width.
    net = my_backbone()
    batch = torch.rand(10, 1, 28, 28)
    features = net(batch)
    print(features.shape, net.output_dim)

