import torch
from torch import nn

class ConvBlock(nn.Module):
    """Pre-activation residual block of two circular 1D convolutions.

    Computes ``conv2(dropout(act(norm2(conv1(act(norm1(x))))))) + shortcut(x)``
    where ``act`` is GELU and both norms are single-group GroupNorms. Both
    convolutions use kernel 3, padding 1 (circular), stride 1, so the sequence
    length is preserved. The shortcut is a bias-free 1x1 convolution when the
    channel count changes and the identity otherwise.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        dropout: If truthy, apply dropout between the two convolutions;
            otherwise that slot is an ``nn.Identity``.
        dropout_rate: Dropout probability used when ``dropout`` is truthy.
            Defaults to 0.1, the rate that was previously hard-coded here.
    """

    def __init__(self, in_channels, out_channels, dropout=False, *, dropout_rate=0.1):
        super().__init__()
        # Submodules are created in the same order as before so that seeded
        # initialisation and state_dict layout are unchanged.
        self.norm1 = nn.GroupNorm(1, in_channels)
        self.norm2 = nn.GroupNorm(1, out_channels)

        self.activation = nn.GELU()

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1, padding_mode='circular')
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1, padding_mode='circular')

        # nn.Dropout validates the probability itself (must lie in [0, 1]).
        self.dropout = nn.Dropout(dropout_rate) if dropout else nn.Identity()

        # Project the residual only when the channel count actually changes.
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Apply the residual block.

        Args:
            x: Tensor laid out as (batch, in_channels, length).

        Returns:
            Tensor of shape (batch, out_channels, length).
        """
        h = self.norm1(x)
        h = self.activation(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = self.activation(h)
        h = self.dropout(h)
        h = self.conv2(h)
        return h + self.shortcut(x)
    
class Down(nn.Module):
    """Downsample by two with a learned convolution, then refine with a ConvBlock.

    The downsampling layer is a kernel-3, stride-2, padding-1 circular
    convolution that keeps ``in_channels`` and maps a length ``L`` to
    ``ceil(L / 2)``. The following ConvBlock changes the width from
    ``in_channels`` to ``out_channels`` at that reduced length.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Learned stride-2 reduction; channel count is untouched here.
        self.down = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            padding_mode='circular',
        )
        # Feature refinement and channel change at the coarser resolution.
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        """Map (batch, in_channels, L) to (batch, out_channels, ceil(L / 2))."""
        return self.conv(self.down(x))

class Up(nn.Module):
    """Upsample by two, merge a skip connection, then refine with a ConvBlock.

    ``x1`` is upsampled by a kernel-2, stride-2 transposed convolution that
    maps ``in_channels -> in_channels // 2`` channels and doubles the length.
    The result is concatenated *after* the skip tensor ``x2`` on the channel
    axis. ``x2`` must therefore carry ``in_channels - in_channels // 2``
    channels for the concatenation to match the ConvBlock's ``in_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose1d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with ``x2``, concatenate, and refine.

        Args:
            x1: Coarse features of shape (batch, in_channels, L).
            x2: Skip features at the finer resolution, of length ``2 * L`` or
                ``2 * L - 1``.

        Returns:
            Tensor of shape (batch, out_channels, x2.shape[-1]).
        """
        x = self.up(x1)
        # A stride-2 / kernel-3 / padding-1 downsampling of an odd length L
        # yields ceil(L / 2) samples, so doubling it overshoots the skip by
        # one. Drop the surplus trailing sample so torch.cat lines up instead
        # of raising a size-mismatch error.
        target_len = x2.shape[-1]
        if x.shape[-1] > target_len:
            x = x[..., :target_len]
        x = torch.cat([x2, x], dim=1)
        return self.conv(x)

class UNet1D_2(nn.Module):
    """1D U-Net with circular padding throughout.

    Layout:

    - ``lift`` maps ``in_channels -> hidden_channels``.
    - ``depth`` Down stages each halve the length (rounding up) and double
      the width.
    - ``depth`` Up stages mirror them, consuming the stored skip tensors in
      reverse order.
    - ``proj`` maps ``hidden_channels -> out_channels``.

    Args:
        in_channels: Channels of the input signal.
        out_channels: Channels of the output signal.
        hidden_channels: Width after the lifting convolution. The bottleneck
            has ``hidden_channels * 2**depth`` channels.
        depth: Number of down/up stage pairs. A non-positive value builds
            no stages.

    The input is expected as (batch, in_channels, length). Lengths divisible
    by ``2**depth`` keep every skip connection aligned.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, depth=4):
        super().__init__()
        self.lift = nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode='circular')

        # Channel width at each resolution level, finest first:
        # hidden, 2*hidden, ..., 2**depth * hidden.
        widths = [hidden_channels * 2 ** level for level in range(max(depth, 0) + 1)]
        narrow, wide = widths[:-1], widths[1:]

        # Encoder: widen level by level (modules are built before the decoder,
        # as in the original construction order).
        self.down = nn.ModuleList(Down(c_in, c_out) for c_in, c_out in zip(narrow, wide))
        # Decoder: narrow back down, starting from the bottleneck.
        self.up = nn.ModuleList(Up(c_in, c_out) for c_in, c_out in zip(reversed(wide), reversed(narrow)))

        self.proj = nn.Conv1d(widths[0], out_channels, kernel_size=3, padding=1, padding_mode='circular')

    def forward(self, x):
        """Map (batch, in_channels, length) to (batch, out_channels, length)."""
        x = self.lift(x)

        # Activation entering each Down stage, finest resolution first.
        skips = []
        for down_stage in self.down:
            skips.append(x)
            x = down_stage(x)

        # Decoder stages pair with skips from the coarsest level upward.
        for up_stage, skip in zip(self.up, reversed(skips)):
            x = up_stage(x, skip)

        return self.proj(x)
    
    