from torchvision import models
import torch.nn as nn

class ResNet18(nn.Module):
    """torchvision ResNet-18 backbone with a replaced final linear layer.

    The backbone is ``torchvision.models.resnet18``. Its fully connected
    head (``fc``) is swapped for a freshly initialised ``nn.Linear`` that
    maps the backbone's feature size to ``dim_output``. Only the layers
    before ``fc`` carry pretrained weights; the new head starts from
    random initialisation.

    Args:
        dim_output: Number of output features produced by the new head.
        pretrained: If True (the default, matching the previous behaviour),
            load the ImageNet-1k (``IMAGENET1K_V1``) backbone weights, which
            torchvision may need to download. If False, the whole network is
            randomly initialised.
    """

    def __init__(self, dim_output: int, *, pretrained: bool = True):
        super().__init__()
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13: the ``weights`` enum replaces the deprecated
            # ``pretrained`` flag. IMAGENET1K_V1 is exactly what
            # ``pretrained=True`` used to load.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.resnet = models.resnet18(weights=weights)
        else:
            # Older torchvision only understands the boolean flag.
            self.resnet = models.resnet18(pretrained=pretrained)
        # Replace the original head so the output size matches dim_output.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, dim_output)

    def forward(self, x):
        """Run the backbone and the replaced head on ``x``.

        Args:
            x: Input batch. This presumably follows torchvision's ResNet
                convention of ``(N, 3, H, W)`` float images; TODO confirm
                against callers.

        Returns:
            The output of the final ``nn.Linear``, with no activation
            applied. Its last dimension is ``dim_output``.
        """
        return self.resnet(x)