from torch import nn
from torch.nn import functional as F


class ImportanceCNN(nn.Module):
    """Score every spatial position of a feature map with a sigmoid value.

    Two 3x3 convolutions (stride 1, padding 1, so height and width are
    preserved) with a ReLU between them reduce ``embedding_size`` channels
    to a single channel. That map is flattened over its spatial positions
    and passed through a sigmoid, giving one score in [0, 1] per position.

    Args:
        embedding_size: Channel count expected on the input; also used as
            the hidden channel count of the first convolution.
        out_dim: Kept as ``self.out_dim``. It is not used by the layers or
            by ``forward`` in this class.
            NOTE(review): confirm whether callers rely on it.
    """

    def __init__(self, embedding_size, out_dim):
        super().__init__()
        self.out_dim = out_dim

        # Build the layers in the same order as they appear in the pipeline.
        # Construction order fixes the RNG draws used for weight init, and
        # Sequential position fixes the state_dict keys (final.0.*, final.2.*).
        hidden_conv = nn.Conv2d(
            in_channels=embedding_size,
            out_channels=embedding_size,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        score_conv = nn.Conv2d(
            in_channels=embedding_size,
            out_channels=1,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        layers = [hidden_conv, nn.ReLU(), score_conv, nn.Flatten()]
        self.final = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-position importance scores.

        Args:
            x: Feature map whose channel dimension equals ``embedding_size``.
                Presumably batched as ``(B, embedding_size, H, W)``; confirm
                this with the callers.

        Returns:
            Sigmoid scores with an extra trailing axis of size 1. For a
            batched ``(B, C, H, W)`` input the result is ``(B, H * W, 1)``.
        """
        logits = self.final(x)  # Flatten() leaves (B, H * W) for 4-D input
        scores = logits.sigmoid()
        return scores[..., None]
