import torch
import torch.nn as nn

from . import register_component, get_activation
from .utils import CONV_TYPES

@register_component("ConvNextBlock")
class ConvNextBlock(nn.Module):
    """
    Residual ConvNext-style block.

    The main branch applies, in order: a convolution with the requested
    kernel size, a first normalization, a 1x1 convolution expanding to
    ``4 * out_channels``, the activation, a second normalization, a 1x1
    convolution projecting back to ``out_channels`` and dropout. The block
    input is then added back, going through a bias-free 1x1 convolution
    when ``in_channels`` and ``out_channels`` differ.

    Args:
        dimension: Key into ``CONV_TYPES`` selecting the convolution class
            (1, 2, or 3 according to the assertion message)
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Kernel size of the first convolution
        padding_mode: Padding mode of the first convolution
        norm: If True, use ``GroupNorm`` with a single group followed by
            ``GRN``; if False, both normalizations are identities
        dropout_rate: Dropout probability; 0.0 disables dropout
        bias: Whether the three main-branch convolutions carry a bias
        activation: Name handed to ``get_activation`` (e.g., "relu", "gelu")
        **kwargs: Extra keyword arguments forwarded to ``get_activation``
    """
    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 7,
            padding_mode: str = "circular",
            norm: bool = True,
            dropout_rate: float = 0.0,
            bias: bool = True,
            activation: str = "gelu",
            **kwargs
        ):
        super().__init__()

        assert dimension in CONV_TYPES, "Dimension must be 1, 2, or 3"
        conv_cls = CONV_TYPES[dimension]

        # Half of (kernel_size - 1), truncated: keeps the spatial size for
        # odd kernel sizes.
        half_width = (kernel_size - 1) / 2
        expanded_channels = out_channels * 4
        pointwise_args = dict(kernel_size=1, bias=bias)

        # Sub-modules are created in a fixed order so that parameter
        # registration (and weight initialisation) order stays stable.
        self.conv1 = conv_cls(
            in_channels,
            out_channels,
            kernel_size,
            padding=int(half_width),
            padding_mode=padding_mode,
            bias=bias,
        )
        self.pointwise_conv1 = conv_cls(
            out_channels, expanded_channels, **pointwise_args
        )
        self.pointwise_conv2 = conv_cls(
            expanded_channels, out_channels, **pointwise_args
        )

        if norm:
            self.norm1 = nn.GroupNorm(1, out_channels)
            self.norm2 = GRN(expanded_channels)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        self.activation = get_activation(activation, **kwargs)

        self.dropout = (
            nn.Identity() if dropout_rate == 0.0 else nn.Dropout(dropout_rate)
        )

        # A projection on the skip path is only needed when the channel
        # count changes.
        if in_channels != out_channels:
            self.shortcut = conv_cls(
                in_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Run the main branch on ``x`` and add the shortcut of ``x``."""
        main_branch = (
            self.conv1,
            self.norm1,
            self.pointwise_conv1,
            self.activation,
            self.norm2,
            self.pointwise_conv2,
            self.dropout,
        )
        out = x
        for layer in main_branch:
            out = layer(out)
        return out + self.shortcut(x)

class GRN(nn.Module):
    """
    Global Response Normalization for channels-first inputs.

    For every sample and channel, the L2 norm of the feature map over all
    axes after the channel axis is computed. Each channel's norm is then
    divided by the mean of these norms across channels, and the resulting
    factor rescales the input inside a residual connection:

        out = x + gamma * (x * Nx) + beta

    The input is expected to carry the batch on axis 0 and the channels on
    axis 1, with ``in_channels`` entries on axis 1.

    Args:
        in_channels: Number of channels (size of axis 1 of the input)
    """
    def __init__(
            self,
            in_channels: int
        ):
        super().__init__()

        self.in_channels = in_channels

        # Per-channel scale and shift, broadcast over batch and spatial axes.
        self.gamma = nn.Parameter(torch.ones(in_channels))
        self.beta = nn.Parameter(torch.zeros(in_channels))

    def forward(self, x):
        """
        Apply global response normalization.

        Args:
            x: Tensor laid out as (batch, channels, *spatial)

        Returns:
            Tensor with the same shape as ``x``
        """
        shape = [1, self.in_channels] + [1] * (x.dim() - 2)
        spatial_dims = tuple(range(2, x.dim()))

        # Aggregate each channel over its spatial extent. Reducing over the
        # channel axis here would make the division below collapse to ~1,
        # which removes any cross-channel normalization.
        if spatial_dims:
            Gx = torch.linalg.vector_norm(
                x, ord=2, dim=spatial_dims, keepdim=True
            )
        else:
            # No spatial axes: the norm of a single value is its magnitude.
            Gx = x.abs()

        # Response of each channel relative to the average channel response.
        Nx = Gx / (torch.mean(Gx, dim=1, keepdim=True) + 1e-6)

        return x + self.gamma.view(*shape) * (x * Nx) + self.beta.view(*shape)