import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """Residual MLP block with a ReLU applied before each linear layer.

    Computes ``x + W2(relu(W1(relu(x))))`` where ``W1`` and ``W2`` are
    ``dim -> dim`` linear layers, so the output has the same shape as the
    input. The last dimension of the input must equal ``dim``.

    Args:
        dim: feature width of the input, the hidden layer and the output.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        # Two identical ReLU -> Linear stages. Inside ``block`` the layout is
        # 0: ReLU, 1: Linear, 2: ReLU, 3: Linear, so the linear layers'
        # state_dict keys are ``block.1.*`` and ``block.3.*``.
        stages = []
        for _ in range(2):
            stages.append(nn.ReLU())
            stages.append(nn.Linear(dim, dim))
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the residual branch."""
        # nn.ReLU() defaults to out-of-place, so ``x`` is still intact here
        # for the skip connection.
        residual = self.block(x)
        return x + residual

class Encoder(nn.Module):
    """MLP encoder: linear projection, a stack of residual blocks, final ReLU.

    Args:
        in_dim: size of the last dimension of the input.
        hidden_dim: width of the projection and of every residual block;
            also the size of the last dimension of the output.
        n_res_blocks: number of ``ResBlock`` layers; 0 yields an empty stack
            (identity).
    """

    def __init__(self, in_dim: int, hidden_dim: int, n_res_blocks: int) -> None:
        super().__init__()
        # Submodules are registered fc -> res_blocks -> out; this fixes the
        # order of parameters() and the RNG draws made during initialization.
        self.fc = nn.Linear(in_dim, hidden_dim)
        self.res_blocks = nn.Sequential(
            *(ResBlock(hidden_dim) for _ in range(n_res_blocks))
        )
        self.out = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` to non-negative features of width ``hidden_dim``."""
        hidden = self.res_blocks(self.fc(x))
        return self.out(hidden)

class Decoder(nn.Module):
    """MLP decoder: linear projection, a stack of residual blocks, linear output.

    Args:
        in_dim: size of the last dimension of the input.
        hidden_dim: width of the projection and of every residual block.
        out_dim: size of the last dimension of the output.
        n_res_blocks: number of ``ResBlock`` layers; 0 yields an empty stack
            (identity).
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, n_res_blocks: int) -> None:
        super().__init__()
        # Submodules are registered fc -> res_blocks -> out; this fixes the
        # order of parameters() and the RNG draws made during initialization.
        self.fc = nn.Linear(in_dim, hidden_dim)
        self.res_blocks = nn.Sequential(
            *(ResBlock(hidden_dim) for _ in range(n_res_blocks))
        )
        self.out = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` to unconstrained outputs of width ``out_dim``."""
        hidden = self.res_blocks(self.fc(x))
        # NOTE(review): unlike Encoder, no ReLU sits between the residual
        # stack and the output projection; with n_res_blocks == 0 the whole
        # decoder is a composition of two linear maps. Confirm intentional.
        return self.out(hidden)
