import torch
import torch.nn as nn

def get_num_groups(channels):
    """Pick a GroupNorm group count that evenly divides ``channels``.

    Candidates are tried from largest to smallest (32, 16, 8, 4, 2, 1); the
    first one that divides ``channels`` with no remainder is returned. If none
    does, 1 is returned.
    """
    candidates = (32, 16, 8, 4, 2, 1)
    return next((groups for groups in candidates if channels % groups == 0), 1)

def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count)."""
    return not (val is None)

        
class ConvG(nn.Module):
    """Single stage: Conv3x3 (no bias) => GroupNorm(1 group) => GELU.

    Only one convolution is applied. Spatial size is preserved (kernel 3,
    padding 1).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        mid_channels: Unused; accepted only for call-site compatibility.
        num_groups: Unused; the normalization always uses a single group
            (``GroupNorm(1, out_channels)``). Accepted for call-site
            compatibility.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, num_groups=None):
        super().__init__()
        # `mid_channels` and `num_groups` are intentionally ignored (see docstring).
        # The attribute name `single_conv` and the Sequential layout are kept
        # so existing state_dict keys (single_conv.0.*, single_conv.1.*) still load.
        self.single_conv = nn.Sequential(
            # bias=False: presumably because the following GroupNorm's affine
            # shift makes a conv bias redundant.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(num_groups=1, num_channels=out_channels),
            nn.GELU(),
        )

    def forward(self, x):
        """Apply conv => norm => GELU; presumably x is (N, in_channels, H, W)."""
        return self.single_conv(x)

class SimpleGate(nn.Module):
    """Split the input into two halves along dim 1 and return their product.

    The output therefore has half as many channels as the input. With an odd
    channel count ``torch.chunk`` yields unequal halves, so an even count is
    expected.
    """

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half
        

class RecB(nn.Module):  # RB
    """Residual block made of two gated branches, each blended in by a learnable scale.

    Spatial branch:
        GroupNorm(1) -> 1x1 conv -> depthwise 3x3 conv -> 1x1 conv (dim -> 2*dim)
        -> SimpleGate (back to dim) -> multiply by a pooled 1x1-conv descriptor
        of the block input -> 3x3 conv; added to the input scaled by ``beta``.
    Pointwise branch:
        GroupNorm(1) -> 1x1 conv (dim -> 2*dim) -> SimpleGate -> 1x1 conv;
        added to the previous result scaled by ``gamma``.

    ``beta`` and ``gamma`` are zero-initialized, so the block starts out as an
    identity mapping. Input is presumably (N, dim, H, W); output has the same shape.
    """

    def __init__(self, dim):
        super().__init__()
        # Registration order is unchanged: it fixes parameters() order (optimizer
        # state) and the order in which random initialization is drawn.

        # Spatial branch.
        self.norm1 = nn.GroupNorm(1, dim)
        self.pw1 = nn.Conv2d(dim, dim, 1)
        self.dw = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)  # depthwise (groups=dim)
        self.pw2 = nn.Conv2d(dim, dim * 2, 1)  # widened so SimpleGate halves back to dim
        self.sg1 = SimpleGate()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, dim, 1),
        )
        self.conv3 = nn.Conv2d(dim, dim, 3, padding=1)
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))

        # Pointwise branch.
        self.norm2 = nn.GroupNorm(1, dim)
        self.pw3 = nn.Conv2d(dim, dim * 2, 1)
        self.sg2 = SimpleGate()
        self.pw4 = nn.Conv2d(dim, dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def _spatial_branch(self, x):
        """Gated norm/conv stack, modulated by the pooled descriptor of ``x``."""
        gated = self.sg1(self.pw2(self.dw(self.pw1(self.norm1(x)))))
        # NOTE(review): the descriptor is computed from the raw block input `x`,
        # not from the gated features it multiplies — confirm this is intended.
        return self.conv3(gated * self.se(x))

    def _pointwise_branch(self, y):
        """Norm -> expand -> gate -> project, all with 1x1 convs."""
        return self.pw4(self.sg2(self.pw3(self.norm2(y))))

    def forward(self, x):
        mid = x + self.beta * self._spatial_branch(x)
        return mid + self.gamma * self._pointwise_branch(mid)

class DoubleConvLN(nn.Module):
    """Conv3x3 -> GroupNorm(1) -> GELU -> RecB -> Conv3x3 -> GroupNorm(1) -> GELU.

    ``GroupNorm`` with a single group normalizes each sample over all of its
    channels and spatial positions, which explains the "LN" in the name.
    Spatial size is preserved (kernel 3, padding 1).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels (also the width of the inner RecB).
        num_groups: Accepted but unused; both norms always use one group.
    """

    def __init__(self, in_ch, out_ch, num_groups=None):
        super().__init__()
        # Registration order kept identical to preserve parameters() order and
        # the random-init sequence.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(1, out_ch)
        self.act = nn.GELU()  # stateless, so one instance serves both stages

        self.rec = RecB(out_ch)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(1, out_ch)

    def forward(self, x):
        stem = self.act(self.norm1(self.conv1(x)))
        refined = self.rec(stem)
        return self.act(self.norm2(self.conv2(refined)))