from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    The first convolution applies ``stride``; the second always uses stride 1.
    Both use padding 1. When the main path changes the tensor shape
    (``stride != 1`` or ``in_channels != out_channels``) the shortcut is a 1x1
    convolution with the same stride followed by batch norm. Otherwise the
    shortcut is an empty ``nn.Sequential``, which returns its input unchanged.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the block.
        stride: Stride of the first convolution and of the 1x1 shortcut.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        # Main path. Attribute names and creation order are kept stable so
        # that state_dict keys and seeded initialisation stay the same.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The shortcut must yield the same shape as the main path so the two
        # can be summed in ``forward``.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.skip = nn.Sequential()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = self.skip(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # The final activation is applied after the residual sum.
        return F.relu(out + shortcut)


class ResNetEncoder(nn.Module):
    """Residual convolutional encoder with two linear output heads.

    Three stride-2 ``ResBlock`` stages widen the input to 128 channels, global
    average pooling collapses the spatial dimensions to 1x1, and two separate
    linear layers map the resulting 128 features to ``latent_dim`` values
    each. The heads are named ``fc_mu`` and ``fc_log_var`` (presumably the
    mean and log-variance of a VAE posterior -- confirm against the caller).

    Because of the adaptive pooling, the input's spatial size is not fixed.

    Args:
        latent_dim: Number of outputs of each head.
        in_channels: Channel count of the input images. Defaults to 1, the
            value that used to be hard-coded, so existing callers and
            checkpoints are unaffected.
    """

    def __init__(self, latent_dim: int = 20, in_channels: int = 1):
        super().__init__()
        self.conv = nn.Sequential(
            ResBlock(in_channels, 32, stride=2),
            ResBlock(32, 64, stride=2),
            ResBlock(64, 128, stride=2),
            nn.AdaptiveAvgPool2d(1),  # → (batch, 128, 1, 1)
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_log_var = nn.Linear(128, latent_dim)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode a batch of images.

        Args:
            x: Batched input of shape ``(batch, in_channels, H, W)``.

        Returns:
            ``(mu, log_var)``, each of shape ``(batch, latent_dim)``.
        """
        x = self.conv(x)
        x = torch.flatten(x, 1)  # (batch, 128, 1, 1) → (batch, 128)
        mu = self.fc_mu(x)
        log_var = self.fc_log_var(x)
        return mu, log_var


class ResNetDecoder(nn.Module):
    """Decode a latent vector into a 28x28 image with values in ``[0, 1]``.

    A linear layer expands the latent vector to a 128x7x7 feature map, two
    stride-2 transposed convolutions upsample it to 14x14 and then 28x28, and
    a final 3x3 convolution followed by a sigmoid produces the image. Despite
    the class name, this decoder contains no residual blocks.

    Args:
        latent_dim: Size of the latent vector fed to ``forward``.
        out_channels: Channel count of the generated image. Defaults to 1,
            the value that used to be hard-coded, so existing callers and
            checkpoints are unaffected.
    """

    def __init__(self, latent_dim: int = 20, out_channels: int = 1):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 7 * 7)  # Targeting 7x7 spatial map
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (128, 7, 7)),  # (batch, 128*7*7) → (batch, 128, 7, 7)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # → 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # → 28x28
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, out_channels, kernel_size=3, padding=1),  # → keep 28x28
            nn.Sigmoid(),  # squash to [0, 1]
        )

    def forward(self, z: Tensor) -> Tensor:
        """Decode a batch of latent vectors.

        Args:
            z: Latent batch of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, out_channels, 28, 28)``.
        """
        x = self.fc(z)
        x = self.decoder(x)
        return x


class JaffeEncoder(nn.Module):
    """Residual encoder for single-channel 64x64 images.

    A 3x3 convolution with ReLU lifts the input to 64 channels, then four
    stride-2 ``ResBlock`` stages halve the resolution each time
    (64 → 32 → 16 → 8 → 4) while widening to 512 channels. The flattened
    512x4x4 feature map feeds two linear heads, ``fc_mu`` and ``fc_log_var``.

    The linear heads have a fixed input width of ``512 * 4 * 4``, so the
    input must be 64x64 for the shapes to line up.

    Args:
        latent_dim: Number of outputs of each head.
    """

    def __init__(self, latent_dim: int = 10):
        super().__init__()
        # (in_channels, out_channels) per down-sampling stage; every stage
        # halves the spatial resolution.
        stage_widths = ((64, 128), (128, 256), (256, 512), (512, 512))

        stem = [
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),  # 64x64 → 64x64
            nn.ReLU(),
        ]
        stages = [ResBlock(c_in, c_out, stride=2) for c_in, c_out in stage_widths]
        self.conv = nn.Sequential(*stem, *stages)  # ends at 512 x 4 x 4

        flat_features = 512 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_log_var = nn.Linear(flat_features, latent_dim)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode a batch of images.

        Args:
            x: Batched input of shape ``(batch, 1, 64, 64)``.

        Returns:
            ``(mu, log_var)``, each of shape ``(batch, latent_dim)``.
        """
        features = torch.flatten(self.conv(x), 1)  # → (batch, 512*4*4)
        return self.fc_mu(features), self.fc_log_var(features)


class JaffeDecoder(nn.Module):
    """Decode a latent vector into a single-channel 64x64 image in ``[-1, 1]``.

    A linear layer expands the latent vector to a 512x4x4 feature map. Four
    rounds of nearest-neighbour 2x upsampling, each followed by a stride-1
    ``ResBlock``, grow it to 64x64 (4 → 8 → 16 → 32 → 64) while narrowing to
    64 channels. A final 3x3 convolution and ``tanh`` produce the image.

    Args:
        latent_dim: Size of the latent vector fed to ``forward``.
    """

    def __init__(self, latent_dim: int = 10):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)  # Targeting 4x4 spatial map

        # (in_channels, out_channels) of the ResBlock that follows each
        # upsampling step.
        stage_widths = ((512, 512), (512, 256), (256, 128), (128, 64))

        layers = []
        for c_in, c_out in stage_widths:
            layers.extend(
                (
                    nn.Upsample(scale_factor=2, mode="nearest"),  # doubles H and W
                    ResBlock(c_in, c_out),
                )
            )
        layers.append(nn.Conv2d(64, 1, kernel_size=3, padding=1))  # → 1 x 64 x 64
        layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*layers)

    def forward(self, z: Tensor) -> Tensor:
        """Decode a batch of latent vectors.

        Args:
            z: Latent batch whose last dimension is ``latent_dim``.

        Returns:
            Tensor of shape ``(batch, 1, 64, 64)``.
        """
        # Reshape the flat projection into the 512 x 4 x 4 starting map.
        seed_map = self.fc(z).view(-1, 512, 4, 4)
        return self.decoder(seed_map)
