import torch.nn as nn

from . import register_architecture
from ..components import get_component

@register_architecture("UNet")
class UNet(nn.Module):
    """
    Encoder/decoder UNet whose convolution dimensionality is configurable.

    The input is lifted to ``hidden_channels`` channels, passed through
    ``depth`` down blocks (each doubling the channel count), then through
    ``depth`` up blocks (each halving it again), and finally projected to
    ``out_channels``. The activation entering each down block is saved and
    handed, last-in-first-out, to the up blocks as their second argument
    (the skip connection).

    Args:
        dimension: Dimension for convolution operations (1, 2, or 3); passed
            to every sub-component.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        hidden_channels: Channel width produced by the lifting layer and
            consumed by the projection layer.
        block_class: Class name of the block used inside each down/up block.
        depth: Number of down blocks (and, equally, of up blocks).
        bias: Whether to include bias in convolutions.
        activation: Name of the activation function used by down/up blocks.
        **kwargs: Extra keyword arguments forwarded to every DownBlock and
            UpBlock (not to the lifting or projection layers).
    """
    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            hidden_channels: int,
            block_class: str = "ResidualBlock",
            depth: int = 4,
            bias: bool = True,
            activation: str = "gelu",
            **kwargs
    ):
        super().__init__()

        def _make_stage(kind, c_in, c_out):
            # Build one DownBlock/UpBlock; the shared settings and **kwargs
            # are forwarded identically to every stage.
            return get_component(kind)(
                dimension=dimension,
                in_channels=c_in,
                out_channels=c_out,
                block_class=block_class,
                bias=bias,
                activation=activation,
                **kwargs
            )

        self.lift = get_component("LiftingLayer")(
            dimension=dimension,
            in_channels=in_channels,
            out_channels=hidden_channels,
            bias=bias
        )

        # Channel widths on the way down: c, 2c, 4c, ... (depth + 1 entries).
        rising = [hidden_channels]
        for _ in range(depth):
            rising.append(rising[-1] * 2)
        self.down = nn.ModuleList(
            [_make_stage("DownBlock", a, b) for a, b in zip(rising, rising[1:])]
        )

        # Widths on the way back up, halving from the widest level with floor
        # division; for integer widths this retraces `rising` in reverse.
        falling = [rising[-1]]
        for _ in range(depth):
            falling.append(falling[-1] // 2)
        self.up = nn.ModuleList(
            [_make_stage("UpBlock", a, b) for a, b in zip(falling, falling[1:])]
        )

        self.proj = get_component("ProjectionLayer")(
            dimension=dimension,
            in_channels=falling[-1],
            out_channels=out_channels,
            bias=bias
        )

    def forward(self, x):
        """
        Run lift -> down blocks -> up blocks (with skips) -> projection.

        Args:
            x: Input accepted by the lifting layer (presumably a
                channels-first tensor; exact layout depends on the
                registered components -- TODO confirm).

        Returns:
            The output of the projection layer.
        """
        features = self.lift(x)

        # Stack of activations captured *before* each down block.
        stack = []
        for encode in self.down:
            stack.append(features)
            features = encode(features)

        # Up blocks consume the saved activations deepest-first.
        for decode in self.up:
            features = decode(features, stack.pop())

        return self.proj(features)