import torch
import torch.nn as nn
from torch.nn import functional as F

# Base Model
class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(channels),
            nn.GELU(),
            nn.Conv1d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(channels)
        )

    def forward(self, x):
        return x + self.block(x)

def create_blocks(in_channels, out_channels, select):
    """Build a resampling stage followed by a residual block.

    Args:
        in_channels: channels entering the resampling layer.
        out_channels: channels leaving the stage (also the ResBlock width).
        select: ``'down'`` for a stride-2 ``nn.Conv1d``, ``'up'`` for a
            stride-2 ``nn.ConvTranspose1d``. Any other value adds no
            resampling layer, so the result holds only
            ``ResBlock(out_channels)`` and ``in_channels`` goes unused.

    Returns:
        ``nn.Sequential`` with the optional resampling layer at index 0,
        followed by the ``ResBlock``.
    """
    if select == 'down':
        # kernel 3 / stride 2 / padding 1: L_out = floor((L_in - 1) / 2) + 1,
        # i.e. an even length is halved.
        resampler = nn.Conv1d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1
        )
    elif select == 'up':
        # L_out = (L_in - 1) * stride - 2 * padding + (kernel_size - 1)
        #         + output_padding + 1 = 2 * L_in for these settings.
        resampler = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
    else:
        resampler = None

    stages = [] if resampler is None else [resampler]
    stages.append(ResBlock(out_channels))
    return nn.Sequential(*stages)

class Base_decoder(nn.Module):
    """Map a latent vector to a length-128 output through a down/up conv stack.

    Pipeline (N = number of rows after the reshape in ``forward``):
        z --fc + GELU--> (N, 512) --view--> (N, 4, 128)
        three stride-2 'down' stages: (N, 64, 64) -> (N, 128, 32) -> (N, 256, 16)
        three stride-2 'up' stages:   (N, 128, 32) -> (N, 64, 64) -> (N, 32, 128)
        output projection to one channel, squeezed to (N, 128), times ``scale``.

    The first two 'up' outputs are summed with the matching 'down' outputs
    (additive skip connections).

    Args:
        latent_dim: size of the last dimension of the input ``z``.
    """

    def __init__(self, latent_dim=4):
        super().__init__()

        # The linear layer's output is reinterpreted in forward() as
        # reshape_channels x reshape_length (4 * 128 = 512 features).
        self.reshape_channels = 4
        self.reshape_length = 128
        self.fc = nn.Linear(latent_dim, self.reshape_channels * self.reshape_length)

        # sampling blocks: (attribute name, in_channels, out_channels, mode).
        # Registered in this order so module/parameter ordering is stable.
        stage_specs = (
            ('down1', self.reshape_channels, 64, 'down'),
            ('down2', 64, 128, 'down'),
            ('down3', 128, 256, 'down'),
            ('up1', 256, 128, 'up'),
            ('up2', 128, 64, 'up'),
            ('up3', 64, 32, 'up'),
        )
        for attr_name, c_in, c_out, mode in stage_specs:
            setattr(self, attr_name, create_blocks(c_in, c_out, mode))

        # Two width-3 conv/BN/GELU stages, then a 1x1 conv down to one channel.
        projection_layers = [
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Conv1d(128, 1, kernel_size=1),
        ]
        self.output_proj = nn.Sequential(*projection_layers)

        # Learnable output multiplier, initialised to 128.0 (the original
        # note ties this value to an 8-bit signed range).
        self.scale = nn.Parameter(torch.tensor(128.0))

        self._init_weights()

    def _init_weights(self):
        """Re-initialise Conv1d/Linear weights and reset BatchNorm1d affine terms.

        Conv1d and Linear weights get Kaiming-normal init with
        ``mode='fan_out'``; their biases are left untouched. BatchNorm1d
        weights are set to 1 and biases to 0.

        NOTE(review): ``nn.ConvTranspose1d`` is not an ``nn.Conv1d`` subclass,
        so the transposed convolutions in the 'up' stages keep PyTorch's
        default initialisation -- confirm this is intended.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')

    def forward(self, z):
        """Decode ``z`` into a tensor of shape (N, 128).

        Args:
            z: tensor whose last dimension equals ``latent_dim``; the linear
                layer's output is flattened into N rows of 512 features.

        Returns:
            Tensor of shape (N, 128): the single-channel projection with its
            channel axis removed, multiplied by ``self.scale``.
        """
        seed = F.gelu(self.fc(z))
        feats = seed.view(-1, self.reshape_channels, self.reshape_length)

        # downsampling: keep the first two outputs for the skip connections
        skip_high = self.down1(feats)
        skip_low = self.down2(skip_high)
        bottleneck = self.down3(skip_low)

        # upsampling: each skip is added after the GELU of the up stage
        merged = F.gelu(self.up1(bottleneck)) + skip_low
        merged = F.gelu(self.up2(merged)) + skip_high
        decoded = F.gelu(self.up3(merged))

        signal = self.output_proj(decoded).squeeze(1)
        return signal * self.scale