import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Residual unit made of two dilated 3-tap 1-D convolutions.

    Each convolution is preceded by batch normalisation and a ReLU
    (BN -> ReLU -> Conv, twice), and the block input is added back onto the
    result. Channel count is unchanged, and because ``padding == dilation``
    with ``kernel_size=3`` the sequence length is unchanged as well.

    Args:
        in_channels: Number of channels of both the input and the output.
        dilation: Dilation (and matching padding) used by both convolutions.
    """

    def __init__(self, in_channels, dilation):
        super().__init__()

        # Both convolutions share the same geometry; padding == dilation keeps
        # the output length equal to the input length for a 3-tap kernel.
        conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation)

        self.conv1 = nn.Conv1d(in_channels, in_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm1d(in_channels)

        self.conv2 = nn.Conv1d(in_channels, in_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm1d(in_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` plus the output of the two BN -> ReLU -> Conv stages."""
        branch = self.conv1(self.relu(self.bn1(x)))
        branch = self.conv2(self.relu(self.bn2(branch)))

        # Identity shortcut; accumulated in place on the fresh conv output.
        branch += x
        return branch

class ResidualGroup(nn.Module):
    """A chain of ``num_blocks`` residual blocks that share one dilation.

    Args:
        in_channels: Channel count passed to every ``ResidualBlock``.
        dilation: Dilation passed to every ``ResidualBlock``.
        num_blocks: How many blocks to chain one after another.
    """

    def __init__(self, in_channels, dilation, num_blocks=4):
        super().__init__()

        # Children are registered under "0", "1", ... which is the same naming
        # nn.Sequential uses for positional modules.
        chain = nn.Sequential()
        for index in range(num_blocks):
            chain.add_module(str(index), ResidualBlock(in_channels, dilation))
        self.residual_blocks = chain

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        result = self.residual_blocks(x)
        return result
    
class encoder(nn.Module):
    """Stack of dilated residual groups with per-group skip connections.

    The input is projected to ``channels`` feature maps by a 1x1 convolution.
    Before each residual group, the current features are tapped through a
    dedicated 1x1 "skip" convolution. After the last group the features go
    through another 1x1 convolution, are summed with every skip tap, and are
    finally reduced to a single channel.

    Args:
        in_channels: Size of the last dimension of the input tensor.
        dilations: One dilation per residual group. Defaults to
            ``[1, 2, 4, 2, 1]`` when ``None``.
        num_groups: Number of residual groups; must equal ``len(dilations)``.
        channels: Width of the hidden feature maps.

    Raises:
        ValueError: If ``len(dilations) != num_groups``.
    """

    def __init__(self, in_channels, dilations=None, num_groups=5, channels=32):
        super().__init__()

        # A fresh list per instance: a list literal as the default argument
        # would be shared by every instance, so mutating ``self.dilations`` on
        # one model would silently change the default for all later ones.
        if dilations is None:
            dilations = [1, 2, 4, 2, 1]
        self.dilations = list(dilations)
        self.num_groups = num_groups

        # Explicit check instead of ``assert``: asserts vanish under
        # ``python -O`` and the zip() in forward() would then silently
        # truncate to the shorter of the two module lists.
        if len(self.dilations) != self.num_groups:
            raise ValueError(
                f"len(dilations) ({len(self.dilations)}) must equal "
                f"num_groups ({self.num_groups})"
            )

        self.conv1 = nn.Conv1d(in_channels, channels, kernel_size=1)

        self.skip_convs = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=1) for _ in range(self.num_groups)
        ])
        self.residual_groups = nn.ModuleList([
            ResidualGroup(channels, dilation) for dilation in self.dilations
        ])

        self.conv2 = nn.Conv1d(channels, channels, kernel_size=1)
        self.conv3 = nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, x):
        """Encode ``x`` into a single-channel sequence.

        Args:
            x: Tensor laid out as ``(batch, length, in_channels)``; dims 1 and
                2 are swapped before the first convolution.

        Returns:
            Tensor of shape ``(batch, 1, length)``.
        """
        # Conv1d wants channels in dim 1.
        x = torch.transpose(x, 1, 2)
        x = self.conv1(x)

        skip_connections = []
        for skip_conv, residual_group in zip(self.skip_convs, self.residual_groups):
            # Tap the features entering the group, then advance through it.
            skip_connections.append(skip_conv(x))
            x = residual_group(x)

        x = self.conv2(x)
        out = sum(skip_connections) + x
        out = self.conv3(out)

        return out