from typing import Optional, Sequence

import torch
from torch import nn
from torch.nn import functional as F

class FCNHead(nn.Module):
    """Fully convolutional head that resizes its output to a target size.

    Runs a 3x3 conv -> BatchNorm -> ReLU -> Dropout -> 1x1 conv stack over the
    incoming feature map, then bilinearly interpolates the result to a target
    spatial size.

    Args:
        in_channels: Number of channels of the incoming feature map.
        channels: Number of output channels produced by the final 1x1 conv.
        input_shape: Default target size, passed as ``size`` to
            ``F.interpolate`` when ``forward`` is called without an explicit
            ``input_shape``. May be ``None`` if every ``forward`` call
            supplies its own size.
        dropout: Dropout probability applied before the final 1x1 conv.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        input_shape: Optional[Sequence[int]] = None,
        *,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Bottleneck to a quarter of the input width. Note: this is 0 when
        # in_channels < 4, which yields a degenerate intermediate layer.
        inter_channels = in_channels // 4
        self.input_shape = input_shape
        # Layer order is significant: nn.Sequential names parameters by index,
        # so state_dict keys (layers.0.*, layers.1.*, layers.4.*) depend on it.
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv2d(inter_channels, channels, 1),
        ]
        self.layers = nn.Sequential(*layers)

    def reset_input_shape(self, input_shape: Optional[Sequence[int]]) -> None:
        """Replace the default target size used by ``forward``."""
        self.input_shape = input_shape

    def forward(
        self,
        x: torch.Tensor,
        input_shape: Optional[Sequence[int]] = None,
    ) -> torch.Tensor:
        """Apply the head and resize the output.

        Args:
            x: Feature map with ``in_channels`` channels. Bilinear
                interpolation expects a 4-D ``(N, C, H, W)`` tensor.
            input_shape: Target spatial size for this call. Falls back to
                the instance's ``input_shape`` when ``None``.

        Returns:
            Tensor with ``channels`` channels resized to the target size.

        Raises:
            ValueError: If no target size is available from either the
                argument or the instance default.
        """
        size = self.input_shape if input_shape is None else input_shape
        if size is None:
            raise ValueError(
                "FCNHead needs a target size: pass input_shape to forward() "
                "or set it via the constructor / reset_input_shape()."
            )
        logits = self.layers(x)
        return F.interpolate(
            logits, size=size, mode="bilinear", align_corners=False
        )

