import torch
import torch.nn as nn

from . import register_component, get_component
from .utils import CONV_TYPES

@register_component("UpBlock")
class UpBlock(nn.Module):
    """
    Upsampling block with dimension parameter.

    Upsamples ``x1`` with a kernel-2 / stride-2 transposed convolution,
    concatenates the result with the skip tensor ``x2`` along the channel
    axis, and runs the concatenation through a registered conv block.

    Channel contract: the transposed conv emits ``in_channels // 2``
    channels and the conv block is built with ``in_channels`` input
    channels, so ``x2`` must carry ``in_channels - in_channels // 2``
    channels for the concatenation to line up.

    Args:
        dimension: Dimension for convolution operations (1, 2, or 3)
        in_channels: Number of input channels
        out_channels: Number of output channels
        block_class: Class name of the block to use for convolution
        **kwargs: Additional keyword arguments, forwarded to the conv block;
            ``bias`` (default True) is also applied to the transposed conv.

    Raises:
        ValueError: If ``dimension`` is not a supported value.
    """

    # Transposed-convolution class per spatial dimensionality. Kept at class
    # level so it is not rebuilt for every instance.
    _CONV_TRANSPOSE = {
        1: nn.ConvTranspose1d,
        2: nn.ConvTranspose2d,
        3: nn.ConvTranspose3d,
    }

    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            block_class: str = "ResidualBlock",
            **kwargs
    ):
        super().__init__()

        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if dimension not in CONV_TYPES or dimension not in self._CONV_TRANSPOSE:
            raise ValueError("Dimension must be 1, 2, or 3")
        conv_transpose = self._CONV_TRANSPOSE[dimension]

        # kernel_size=2, stride=2 (no padding) doubles every spatial extent.
        self.up = conv_transpose(
            in_channels,
            in_channels // 2,
            kernel_size=2,
            stride=2,
            bias=kwargs.get('bias', True)
        )

        self.conv = get_component(block_class)(
            dimension=dimension,
            in_channels=in_channels,
            out_channels=out_channels,
            **kwargs
        )

    @staticmethod
    def _match_spatial(x, ref):
        """Pad (or crop) ``x`` so its spatial dims equal those of ``ref``.

        Dims 0 and 1 are treated as batch and channel and left alone; every
        trailing dim is adjusted, splitting any difference as evenly as
        possible between the leading and trailing side. Negative padding
        crops, so an upsampled tensor larger than ``ref`` is trimmed.
        """
        diffs = [r - s for r, s in zip(ref.shape[2:], x.shape[2:])]
        if not any(diffs):
            return x
        # F.pad expects (last_dim_lo, last_dim_hi, prev_dim_lo, ...) order.
        pad = []
        for d in reversed(diffs):
            pad.extend((d // 2, d - d // 2))
        return nn.functional.pad(x, pad)

    def forward(self, x1, x2):
        """Upsample ``x1``, fuse it with skip tensor ``x2``, and convolve.

        Args:
            x1: Lower-resolution feature map fed to the transposed conv.
            x2: Skip-connection feature map; the upsampled ``x1`` is
                padded/cropped to ``x2``'s spatial size so inputs with odd
                extents (e.g. after floor pooling) still concatenate.

        Returns:
            Output of the configured conv block applied to ``cat([x2, up(x1)])``.
        """
        x = self.up(x1)
        x = self._match_spatial(x, x2)
        x = self.conv(torch.cat([x2, x], dim=1))
        return x