from torch import nn, Tensor


class ResidualLayer(nn.Module):
    """Residual block computing ``shortcut(x) + conv1x1(relu(conv3x3(x)))``.

    Both convolutions are bias-free. The 3x3 conv uses ``padding=1`` and the
    second conv is 1x1, so spatial size is preserved.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count of the output tensor. When it differs from
            ``in_channels``, the skip path is a bias-free 1x1 convolution so the
            two branches can be summed. Otherwise the skip path is the identity.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.resblock = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(True),  # in-place is safe: it acts on the conv output, not on the input
            nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
        )
        # nn.Identity holds no parameters, so the state_dict keys (and RNG
        # consumption during init) are unchanged for the equal-channel case.
        # Previously, unequal channels built fine but crashed in forward().
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, input: Tensor) -> Tensor:
        """Return the skip path plus the residual branch applied to ``input``."""
        return self.shortcut(input) + self.resblock(input)


# Encoder module
class ResnetEncoder(nn.Module):
    """Convolutional encoder: strided downsampling, residual refinement, 1x1 projection.

    Args:
        in_channels: channel count of the input tensor.
        hidden_dims: output widths of the successive stride-2 stages. The last
            entry (or ``in_channels`` when empty) is the width of the residual
            stack.
        embedding_dim: channel count produced by the final 1x1 projection.
        residual_blocks: number of ``ResidualLayer`` blocks in the stack.

    Each stride-2 stage (kernel 4, padding 1) maps a spatial size ``H`` to
    ``floor(H / 2)``. All later convolutions preserve spatial size. The layer
    order, and therefore the ``state_dict`` key layout, is fixed.
    """

    def __init__(self, in_channels, hidden_dims, embedding_dim, residual_blocks):
        super().__init__()
        channels = in_channels
        stages = []

        # Downsampling stages: each one halves H and W and changes the width.
        for width in hidden_dims:
            stage = nn.Sequential(
                nn.Conv2d(channels, width, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(),
            )
            stages.append(stage)
            channels = width

        # Same-resolution 3x3 refinement before the residual stack.
        refine = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        )
        stages.append(refine)

        # Residual stack at constant width, followed by a bare activation.
        stages.extend(ResidualLayer(channels, channels) for _ in range(residual_blocks))
        stages.append(nn.LeakyReLU())

        # 1x1 projection to the embedding width.
        project = nn.Sequential(
            nn.Conv2d(channels, embedding_dim, kernel_size=1, stride=1),
            nn.LeakyReLU(),
        )
        stages.append(project)

        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the encoder stack and return the resulting feature map."""
        return self.encoder(x)

# Decoder module
class ResnetDecoder(nn.Module):
    """Decoder whose layout mirrors ``ResnetEncoder``: residual stack, then upsampling.

    Args:
        embedding_dim: channel count of the input feature map.
        hidden_dims: stage widths in *encoder* order. They are consumed in
            reverse, so ``hidden_dims[-1]`` is the width of the residual stack
            and ``hidden_dims[0]`` feeds the final output layer.
        residual_blocks: number of ``ResidualLayer`` blocks in the stack.
        out_channels: channel count of the reconstructed output. The default
            of 3 matches the previously hard-coded value.

    Raises:
        ValueError: if ``hidden_dims`` is empty.

    There are ``len(hidden_dims)`` upsampling stages. Each
    ``ConvTranspose2d(kernel_size=4, stride=2, padding=1)`` maps a spatial size
    ``H`` to ``2 * H``. The final stage ends in ``Tanh``, so outputs lie in
    ``[-1, 1]``.
    """

    def __init__(self, embedding_dim, hidden_dims, residual_blocks, out_channels=3):
        super().__init__()
        reversed_dims = list(hidden_dims)[::-1]
        if not reversed_dims:
            # Give a clear error instead of an IndexError on hidden_dims[-1].
            raise ValueError("hidden_dims must contain at least one channel width")
        top_dim = reversed_dims[0]  # == hidden_dims[-1]

        # Module order is kept identical to the original so state_dict keys
        # (and init-time RNG consumption) are unchanged.
        modules = [nn.Sequential(
            nn.Conv2d(embedding_dim, top_dim, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        )]
        modules.extend(ResidualLayer(top_dim, top_dim) for _ in range(residual_blocks))
        modules.append(nn.LeakyReLU())

        # Upsampling between consecutive widths (len(hidden_dims) - 1 stages).
        for in_dim, out_dim in zip(reversed_dims, reversed_dims[1:]):
            modules.append(nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(),
            ))

        # Final upsampling stage to the output channels, squashed by Tanh.
        modules.append(nn.Sequential(
            nn.ConvTranspose2d(reversed_dims[-1], out_channels,
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ))

        self.decoder = nn.Sequential(*modules)

    def forward(self, x):
        """Decode the feature map ``x`` into the output tensor."""
        return self.decoder(x)