import torch
import torch.nn as nn

from src.layers.residual import ResidualBlock


class PatchCNN(nn.Module):
    """Residual CNN that encodes an input patch into a fixed-size vector.

    The network is a stack of ``ResidualBlock`` layers, followed by global
    average pooling, flattening, and a linear projection to ``out_dim``.

    Args:
        out_dim: Size of the output feature vector.
        in_channels: Number of channels of the input tensor. Defaults to 3
            (the value previously hard-coded).
        channels: Width of every residual block; also the input size of the
            final linear layer. Defaults to 32.
        num_blocks: Number of residual blocks. Defaults to 3.

    Raises:
        ValueError: If ``num_blocks`` is less than 1.
    """

    def __init__(self, out_dim, *, in_channels=3, channels=32, num_blocks=3):
        super().__init__()
        if num_blocks < 1:
            # With zero blocks the Linear input size (channels) would no
            # longer match the tensor width reaching it (in_channels).
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

        # Residual blocks: the first maps in_channels -> channels, the rest
        # keep the width at channels (assumes ResidualBlock(in, out) has that
        # meaning — confirm against src.layers.residual). A flat Sequential
        # keeps parameter names (res_blocks.0, res_blocks.1, ...) identical
        # to the original hard-coded layout, so saved checkpoints still load.
        blocks = [ResidualBlock(in_channels, channels)]
        blocks.extend(
            ResidualBlock(channels, channels) for _ in range(num_blocks - 1)
        )
        self.res_blocks = nn.Sequential(*blocks)

        # Final processing: pool each channel to one value, flatten, and
        # project to out_dim. The Linear width is derived from `channels`
        # so it cannot drift out of sync with the residual stack.
        self.final = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, out_dim),
        )

    def forward(self, x):
        """Encode ``x`` into an ``out_dim``-sized feature vector.

        NOTE(review): assumes ``x`` is (batch, in_channels, H, W) and that
        ResidualBlock keeps a 4-D layout, in which case the result is
        (batch, out_dim) — confirm with callers.
        """
        x = self.res_blocks(x)
        return self.final(x)
