import torch.nn as nn
import torch
import torch.nn.functional as F

class MaskNet_K1(nn.Module):
    """Tiny fully-convolutional net that emits ``k`` sigmoid-squashed maps.

    Architecture: 3x3 conv -> ReLU -> 3x3 conv -> ReLU -> 1x1 conv -> sigmoid.
    All convolutions use stride 1. The 3x3 ones use padding 1 and the 1x1 one
    uses padding 0, so the spatial size of the input is preserved.

    Args:
        k: Number of output channels, i.e. the number of maps produced.
            Presumably each is used as a soft mask; confirm with callers.
        in_planes: Number of input channels.
        n_channel: Width (channel count) of the two hidden conv layers.

    Forward input is a tensor accepted by ``nn.Conv2d`` with ``in_planes``
    channels, e.g. ``(N, in_planes, H, W)``. The output has ``k`` channels and
    the same spatial size. Values lie in [0, 1] because of the final sigmoid.
    """

    def __init__(self, k, in_planes=2, n_channel=4):
        super().__init__()
        # Attribute names and creation order are kept fixed: they define the
        # state_dict keys and the order in which weights draw from the RNG.
        # Positional Conv2d args: (in, out, kernel_size, stride, padding).
        self.conv1 = nn.Conv2d(in_planes, n_channel, 3, 1, 1)
        self.conv2 = nn.Conv2d(n_channel, n_channel, 3, 1, 1)
        # 1x1 projection from the hidden width down to k output channels.
        self.conv3 = nn.Conv2d(n_channel, k, 1, 1, 0)

    def forward(self, x):
        """Map ``x`` to ``k`` per-pixel values in [0, 1]."""
        hidden = x
        # Both 3x3 feature layers share the same conv -> ReLU pattern.
        for layer in (self.conv1, self.conv2):
            hidden = F.relu(layer(hidden))
        return self.conv3(hidden).sigmoid()
