import torch.nn as nn

from . import register_component, get_activation
from .utils import CONV_TYPES

@register_component("ResidualBlock")
class ResidualBlock(nn.Module):
    """
    Residual block with two conv -> norm -> activation stages and a skip path.

    Main path: ``conv1 -> norm1 -> activation -> dropout -> conv2 -> norm2 ->
    activation``. Its output is added to the input, which is first projected
    by a 1x1 convolution when ``in_channels != out_channels``.

    Both main-path convolutions use ``padding="same"`` (stride 1), so the
    spatial size is preserved for odd and even kernel sizes alike, and for
    per-axis tuple kernel sizes.

    Args:
        dimension: Dimension for convolution operations; must be a key of
            ``CONV_TYPES`` (expected to be 1, 2, or 3)
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Size of the convolutional kernel (int or per-axis tuple)
        padding_mode: Padding mode passed through to the convolution class
        norm: If True, apply ``GroupNorm(1, out_channels)`` after each main
            convolution; otherwise use identity
        dropout_rate: Dropout probability applied after the first activation;
            0.0 disables dropout entirely
        bias: Whether to include bias in the two main-path convolutions
            (the 1x1 shortcut projection never has a bias)
        activation: Name of activation function (e.g., "relu", "gelu")
        **kwargs: Extra keyword arguments forwarded to ``get_activation``

    Raises:
        ValueError: If ``dimension`` is not a key of ``CONV_TYPES``.
    """
    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            padding_mode: str = "circular",
            norm: bool = True,
            dropout_rate: float = 0.0,
            bias: bool = True,
            activation: str = "gelu",
            **kwargs
        ):
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dimension not in CONV_TYPES:
            raise ValueError(f"Dimension must be 1, 2, or 3, got {dimension!r}")
        Conv = CONV_TYPES[dimension]

        # "same" keeps the spatial size for any kernel size (asymmetric padding
        # for even kernels) and also works with non-"zeros" padding modes.
        # int((k - 1) / 2) would shrink the output for even kernels and break
        # the residual addition in forward().
        padding = "same"

        self.conv1 = Conv(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            padding_mode=padding_mode,
            bias=bias
        )

        self.conv2 = Conv(
            out_channels,
            out_channels,
            kernel_size,
            padding=padding,
            padding_mode=padding_mode,
            bias=bias
        )

        if not norm:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()
        else:
            # A single group normalizes over all channels and spatial positions.
            self.norm1 = nn.GroupNorm(1, out_channels)
            self.norm2 = nn.GroupNorm(1, out_channels)

        # Extra **kwargs are forwarded to get_activation.
        # NOTE(review): this one module is applied at both activation points
        # in forward(); if get_activation can return a module with learnable
        # parameters (e.g. PReLU) those parameters are shared — confirm intended.
        self.activation = get_activation(activation, **kwargs)

        if dropout_rate == 0.0:
            self.dropout = nn.Identity()
        else:
            self.dropout = nn.Dropout(dropout_rate)

        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            # 1x1 projection so the skip path matches out_channels.
            self.shortcut = Conv(
                in_channels,
                out_channels,
                kernel_size=1,
                bias=False
            )

    def forward(self, x):
        """
        Apply the residual block.

        Args:
            x: Input tensor; presumably (batch, in_channels, *spatial) with
                ``dimension`` spatial axes — TODO confirm against CONV_TYPES.

        Returns:
            The main-path output plus the (possibly projected) input. It has
            ``out_channels`` channels and the same spatial size as ``x``.
        """
        skip = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.conv2(x)
        x = self.norm2(x)
        # Activation is applied before the residual addition, not after.
        x = self.activation(x)
        x = x + self.shortcut(skip)
        return x