import math
import pdb
import random

import torch
from torch import nn
import torch.nn.functional as F


import math
import pdb
import random

import torch
from torch import nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Stack of fully connected layers mapping ``dim_list[0]`` to ``dim_list[-1]``.

    Every step except the last is ``Linear -> LayerNorm -> ReLU``; the last
    step is ``Linear -> LayerNorm`` with no activation.

    Args:
        dim_list: Feature widths ``[d_in, d_hidden..., d_out]``. One layer is
            built per consecutive pair of widths, so a list with fewer than
            two entries yields an empty stack and ``forward`` returns its
            input unchanged.
    """

    def __init__(self, dim_list):
        super().__init__()

        num_layers = len(dim_list) - 1
        # Consecutive width pairs: (d0, d1), (d1, d2), ...
        width_pairs = zip(dim_list[:-1], dim_list[1:])
        # Layers are constructed in order, first to last, so seeded parameter
        # initialisation matches building them one by one in a loop.
        self.encoder = nn.ModuleList(
            self.build_layer(dim_in, dim_out, idx == 0, idx == num_layers - 1)
            for idx, (dim_in, dim_out) in enumerate(width_pairs)
        )

    def build_layer(self, dim_in, dim_out, if_start, if_last):
        """Build one encoder step as an ``nn.Sequential``.

        Args:
            dim_in: Input width of the ``nn.Linear``.
            dim_out: Output width of the ``nn.Linear``; also the
                ``nn.LayerNorm`` width.
            if_start: True for the first layer. Not used by this
                implementation; kept for interface compatibility.
            if_last: True for the final layer, which omits the ReLU.

        Returns:
            ``nn.Sequential(Linear, LayerNorm)`` when ``if_last`` is true,
            otherwise ``nn.Sequential(Linear, LayerNorm, ReLU)``.
        """
        modules = [nn.Linear(dim_in, dim_out), nn.LayerNorm(dim_out)]
        if not if_last:
            modules.append(nn.ReLU())
        return nn.Sequential(*modules)

    def forward(self, x):
        """Apply every layer in order and return the output of the last one."""
        for layer in self.encoder:
            x = layer(x)
        return x


class Decoder(nn.Module):
    """Stack of fully connected layers mapping ``dim_list[0]`` to ``dim_list[-1]``.

    Every step except the last is ``Linear -> LayerNorm -> ReLU``; the last
    step is a single ``nn.Linear`` with no normalisation or activation.

    Args:
        dim_list: Feature widths ``[d_in, d_hidden..., d_out]``. One layer is
            built per consecutive pair of widths, so a list with fewer than
            two entries yields an empty stack and ``forward`` returns its
            input unchanged.
    """

    def __init__(self, dim_list):
        super().__init__()

        num_layers = len(dim_list) - 1
        # Consecutive width pairs: (d0, d1), (d1, d2), ...
        width_pairs = zip(dim_list[:-1], dim_list[1:])
        # Layers are constructed in order, first to last, so seeded parameter
        # initialisation matches building them one by one in a loop.
        self.decoder = nn.ModuleList(
            self.build_layer(dim_in, dim_out, idx == 0, idx == num_layers - 1)
            for idx, (dim_in, dim_out) in enumerate(width_pairs)
        )

    def build_layer(self, dim_in, dim_out, if_start, if_last):
        """Build one decoder step.

        Args:
            dim_in: Input width of the ``nn.Linear``.
            dim_out: Output width of the ``nn.Linear``; also the
                ``nn.LayerNorm`` width for non-final layers.
            if_start: True for the first layer. Not used by this
                implementation; kept for interface compatibility.
            if_last: True for the final layer.

        Returns:
            A bare ``nn.Linear`` when ``if_last`` is true, otherwise
            ``nn.Sequential(Linear, LayerNorm, ReLU)``.
        """
        linear = nn.Linear(dim_in, dim_out)
        if if_last:
            # Returned unwrapped, so its parameters are keyed
            # ``decoder.<N>.weight`` / ``decoder.<N>.bias`` in the state dict.
            return linear
        return nn.Sequential(linear, nn.LayerNorm(dim_out), nn.ReLU())

    def forward(self, x):
        """Apply every layer in order and return the output of the last one."""
        for layer in self.decoder:
            x = layer(x)
        return x
