import torch.nn as nn

from .unet_parts import DoubleConv, Down, OutConv, Up


class UNet(nn.Module):
    """Configurable-depth U-Net assembled from the blocks in ``unet_parts``.

    Channel widths form a pyramid ``init_features * 2**level`` for
    ``level`` in ``0..depth``. Encoder stage ``i`` (a ``Down``) maps width
    level ``i`` to ``i + 1``. The matching decoder stage (an ``Up``) maps
    level ``i + 1`` back to ``i`` and is fed the tensor that entered encoder
    stage ``i`` as its skip connection. How spatial resolution changes per
    stage is defined by ``Down``/``Up`` in ``unet_parts``; it is not visible
    here.

    Args:
        in_channels: Number of input channels, forwarded to the first
            ``DoubleConv`` (``inc``).
        out_channels: Number of channels produced by the final ``OutConv``
            (``head``).
        depth: Number of ``Down`` stages; the decoder has the same number of
            ``Up`` stages. ``0`` (or a negative value) builds only ``inc``
            followed by ``head``.
        init_features: Channel width produced by ``inc``, and the width fed to
            ``head``.
        norm_module: Normalization layer class forwarded to ``DoubleConv``,
            ``Down`` and ``Up``; how it is applied is defined in ``unet_parts``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        depth: int = 4,
        init_features: int = 64,
        norm_module=nn.BatchNorm2d,
    ):
        super().__init__()

        # Submodules are created in the order inc -> downs -> ups -> head so
        # parameter initialisation consumes the RNG in a fixed order.
        self.inc = DoubleConv(in_channels, init_features, norm_module=norm_module)

        # widths[level] is the channel count at that pyramid level; each
        # consecutive pair is one encoder transition (narrow -> wide).
        widths = [init_features * 2**level for level in range(depth + 1)]
        transitions = list(zip(widths, widths[1:]))

        self.downs = nn.ModuleList(
            Down(narrow, wide, norm_module=norm_module)
            for narrow, wide in transitions
        )

        # The decoder walks the same transitions deepest-first, wide -> narrow.
        self.ups = nn.ModuleList(
            Up(wide, narrow, bilinear=False, norm_module=norm_module)
            for narrow, wide in reversed(transitions)
        )

        # After the decoder the width is back at level 0, i.e. init_features.
        self.head = OutConv(init_features, out_channels)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the head.

        Args:
            x: Input batch. Presumably ``(N, in_channels, H, W)`` given the
                ``BatchNorm2d`` default; TODO confirm against ``unet_parts``.

        Returns:
            The output of ``head`` applied to the final decoder activation.
        """
        feat = self.inc(x)

        # Remember what entered each encoder stage; the decoder consumes these
        # in reverse (deepest skip first).
        skips = []
        for down in self.downs:
            skips.append(feat)
            feat = down(feat)

        for up, skip in zip(self.ups, reversed(skips)):
            feat = up(feat, skip)

        return self.head(feat)
