import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet18


class ImportanceResnetImage(nn.Module):
    """ResNet-18 backbone that maps an image to per-output scores in (0, 1).

    The torchvision ResNet-18 classification head is replaced with a linear
    layer producing ``out_dim`` logits. These are squashed with a sigmoid and
    given a trailing singleton dimension.

    Args:
        in_dim: Number of input image channels. ResNet-18's stem convolution
            expects 3 channels; for any other value it is rebuilt with
            ``in_dim`` input channels (same kernel size, stride, padding and
            bias setting).
        out_dim: Number of scores produced per image.
    """

    def __init__(self, in_dim: int = 3, out_dim: int = 196):
        super().__init__()
        self.model = resnet18()

        # Previously ``in_dim`` was ignored, so non-RGB inputs crashed in the
        # first convolution. Only replace the stem when the channel count
        # actually differs, keeping the default (3-channel) model untouched.
        stem = self.model.conv1
        if in_dim != stem.in_channels:
            self.model.conv1 = nn.Conv2d(
                in_dim,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )

        self.model.fc = nn.Linear(self.model.fc.in_features, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid scores with a trailing singleton dimension.

        Args:
            x: Image batch fed to the ResNet-18 backbone. Presumably shaped
                (batch, in_dim, H, W); confirm against callers.

        Returns:
            ``sigmoid(backbone(x))`` with an extra last axis of size 1. For a
            (batch, out_dim) backbone output this is (batch, out_dim, 1).
        """
        logits = self.model(x)
        # ``F.sigmoid`` is deprecated; ``torch.sigmoid`` is the supported call.
        return torch.sigmoid(logits).unsqueeze(-1)
