import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by an additive skip connection.

    The skip path carries the raw input unless the block changes the stride
    or the channel count; in that case a strided 1x1 convolution followed by
    batch norm projects the input so it can be added to the main path.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the block.
        stride: Stride of the first 3x3 convolution (and of the 1x1
            projection when one is created).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        needs_projection = not (stride == 1 and in_channels == out_channels)

        # Main path: only the first convolution applies the stride.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        # A single in-place ReLU instance is reused for both activations.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: identity by default, 1x1 projection when shapes differ.
        self.downsample = None
        if needs_projection:
            projection = [
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            ]
            self.downsample = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + skip(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)


class AutoEncoder(nn.Module):

    def __init__(
        self,
        nu=2.1,
        input_shape=1,
        latent_dim=32,
        reconstruction="l1",
        mode="cnn",
        device="mlp" if torch.cuda.is_available() else "cpu",
        normalization="batchnorm",
        activation="relu",
        dropout_rate=0.0,
    ):
        super(AutoEncoder, self).__init__()
        self.name = "AutoEncoder"

        self.nu = nu
        self.in_channels = input_shape[0]
        try:
            self.height = input_shape[1]
            self.width = input_shape[2]
        except Exception as e:
            self.height = 1
            self.width = 1
        self.input_dim = self.in_channels * self.height * self.width

        self.latent_dim = latent_dim
        self.device = device
        self.normalization = normalization
        self.activation = activation
        self.dropout_rate = dropout_rate

        if reconstruction == "mae" or reconstruction == "l1":
            self.reconstruction = F.l1_loss
        elif reconstruction == "mse" or reconstruction == "l2":
            self.reconstruction = F.mse_loss
        else:
            raise ValueError(
                f"Unsupported reconstruction loss: {reconstruction}"
            )

        # Set activation function
        if activation == "relu":
            self.activation_fn = nn.ReLU()
        elif activation == "leaky_relu":
            self.activation_fn = nn.LeakyReLU()
        elif activation == "elu":
            self.activation_fn = nn.ELU()
        elif activation == "selu":
            self.activation_fn = nn.SELU()
        elif activation == "gelu":
            self.activation_fn = nn.GELU()
        elif activation == "swish":
            self.activation_fn = nn.SiLU()
        elif activation == "tanh":
            self.activation_fn = nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.fc1 = nn.Linear(2 * latent_dim, latent_dim)
        self.fc2 = nn.Linear(2 * latent_dim, latent_dim)

        if mode == "cnn":
            self._build_cnn_model()
        elif mode == "mlp":
            self._build_mlp_model()
        else:
            raise ValueError(f"Unsupported mode: {mode}")

    def _get_normalization_layer(self, num_features, is_conv=True):
        """Create normalization layer based on the specified type"""
        if self.normalization == "batchnorm":
            if is_conv:
                return nn.BatchNorm2d(num_features)
            else:
                return nn.BatchNorm1d(num_features)
        elif self.normalization == "layernorm":
            return nn.LayerNorm(num_features)
        elif self.normalization == "instancenorm":
            if is_conv:
                return nn.InstanceNorm2d(num_features)
            else:
                return nn.InstanceNorm1d(num_features)
        elif self.normalization == "groupnorm":
            # Use 8 groups by default, adjust as needed
            num_groups = min(8, num_features)
            return nn.GroupNorm(num_groups, num_features)
        elif self.normalization == "none":
            return nn.Identity()
        else:
            raise ValueError(
                f"Unsupported normalization: {self.normalization}"
            )

    def _get_dropout_layer(self):
        """Return ``nn.Dropout`` for a positive rate, else ``nn.Identity``."""
        rate = self.dropout_rate
        return nn.Dropout(rate) if rate > 0 else nn.Identity()

    def forward(self, x):
        z, _ = self.encode(x)
        # Decoder
        x_hat = self.decode(z)
        return z, x_hat, (None, None)

    def encode(self, x):
        x = self.encoder(x)
        z = self.enc_to_latent1(x)
        return z, (None, None)

    def decode(self, z):
        return self.decoder(z)

    def regularization(self, derivatives):
        return torch.tensor(0.0).to(self.device)

    def generate(self, n_gen=64):
        """Generate samples by decoding random latent vectors"""
        with torch.no_grad():
            z = torch.randn(n_gen, self.latent_dim).to(self.device)
            generated = self.decode(z)
        return generated

    def _build_cnn_model(self):
        """Build the convolutional encoder, latent projections and decoder.

        The architecture is chosen from ``self.height``/``self.width``:
        64x64 and 128x128 inputs get dedicated stacks, and every other size
        falls through to the last branch, whose inline comments assume a
        32x32 input. Each branch sets ``self.encoder`` (ending in
        ``nn.Flatten``), ``self.enc_to_latent1``, ``self.enc_to_latent2``
        and ``self.decoder``.
        """
        # Norm layers are created once up front and handed to the decoder
        # classes, which look them up by key. "norm_1024" and "norm_16" are
        # not referenced anywhere in this method.
        norm_layers = {
            "norm_1024": self._get_normalization_layer(1024, is_conv=True),
            "norm_512": self._get_normalization_layer(512, is_conv=True),
            "norm_256": self._get_normalization_layer(256, is_conv=True),
            "norm_128": self._get_normalization_layer(128, is_conv=True),
            "norm_64": self._get_normalization_layer(64, is_conv=True),
            "norm_32": self._get_normalization_layer(32, is_conv=True),
            "norm_16": self._get_normalization_layer(16, is_conv=True),
        }

        if self.height == 64 and self.width == 64:
            # Stem conv followed by residual blocks; each stride-2 block
            # halves the spatial size.
            self.encoder = nn.Sequential(
                nn.Conv2d(self.in_channels, 32, 3, stride=1, padding=1),
                self._get_normalization_layer(32, is_conv=True),
                self.activation_fn,
                self._get_dropout_layer(),
                ResidualBlock(32, 32),
                ResidualBlock(32, 64, stride=2),  # 64x64 -> 32x32
                ResidualBlock(64, 64),
                ResidualBlock(64, 128, stride=2),  # 32x32 -> 16x16
                ResidualBlock(128, 128),
                ResidualBlock(128, 256, stride=2),  # 16x16 -> 8x8
                ResidualBlock(256, 256),
                ResidualBlock(256, 256, stride=2),  # 8x8 -> 4x4
                nn.Flatten(),
            )
            # encode() uses enc_to_latent1 only; enc_to_latent2 is created
            # with the same shape but is not used by the methods shown here.
            self.enc_to_latent1 = nn.Linear(256 * 4 * 4, self.latent_dim)
            self.enc_to_latent2 = nn.Linear(256 * 4 * 4, self.latent_dim)

            class Decoder64x64(nn.Module):
                """Decoder mirroring the 64x64 encoder.

                Projects the latent vector to a 256x4x4 map, then alternates
                residual blocks with stride-2 transposed convolutions and
                ends with a conv to ``in_channels`` plus a sigmoid.
                """

                def __init__(
                    self, latent_dim, in_channels, norm_layers, activation_fn
                ):
                    super().__init__()
                    self.latent_to_dec = nn.Linear(latent_dim, 256 * 4 * 4)
                    self.activation_fn = activation_fn
                    self.decoder_blocks = nn.Sequential(
                        ResidualBlock(256, 256),
                        nn.ConvTranspose2d(
                            256, 256, 3, stride=2, padding=1, output_padding=1
                        ),  # 4x4 -> 8x8
                        norm_layers["norm_256"],
                        activation_fn,
                        ResidualBlock(256, 256),
                        nn.ConvTranspose2d(
                            256, 128, 3, stride=2, padding=1, output_padding=1
                        ),  # 8x8 -> 16x16
                        norm_layers["norm_128"],
                        activation_fn,
                        ResidualBlock(128, 128),
                        nn.ConvTranspose2d(
                            128, 64, 3, stride=2, padding=1, output_padding=1
                        ),  # 16x16 -> 32x32
                        norm_layers["norm_64"],
                        activation_fn,
                        ResidualBlock(64, 64),
                        nn.ConvTranspose2d(
                            64, 32, 3, stride=2, padding=1, output_padding=1
                        ),  # 32x32 -> 64x64
                        norm_layers["norm_32"],
                        activation_fn,
                        ResidualBlock(32, 32),
                        nn.Conv2d(32, in_channels, 3, stride=1, padding=1),
                        nn.Sigmoid(),
                    )

                def forward(self, z):
                    """Project ``z``, reshape to (batch, 256, 4, 4), decode."""
                    batch_size = z.shape[0]
                    out = self.latent_to_dec(z)
                    out = out.view(batch_size, 256, 4, 4)
                    return self.decoder_blocks(out)

            self.decoder = Decoder64x64(
                self.latent_dim,
                self.in_channels,
                norm_layers,
                self.activation_fn,
            )

        elif self.height == 128 and self.width == 128:
            print("Building 128x128 AutoEncoder model...")
            # Same layout as the 64x64 encoder with one extra stride-2 stage
            # and up to 512 channels.
            self.encoder = nn.Sequential(
                nn.Conv2d(self.in_channels, 32, 3, stride=1, padding=1),
                self._get_normalization_layer(32, is_conv=True),
                self.activation_fn,
                self._get_dropout_layer(),
                ResidualBlock(32, 32),
                ResidualBlock(32, 64, stride=2),  # 128x128 -> 64x64
                ResidualBlock(64, 64),
                ResidualBlock(64, 128, stride=2),  # 64x64 -> 32x32
                ResidualBlock(128, 128),
                ResidualBlock(128, 256, stride=2),  # 32x32 -> 16x16
                ResidualBlock(256, 256),
                ResidualBlock(256, 512, stride=2),  # 16x16 -> 8x8
                ResidualBlock(512, 512),
                ResidualBlock(512, 512, stride=2),  # 8x8 -> 4x4
                nn.Flatten(),
            )
            self.enc_to_latent1 = nn.Linear(512 * 4 * 4, self.latent_dim)
            self.enc_to_latent2 = nn.Linear(512 * 4 * 4, self.latent_dim)

            class Decoder128x128(nn.Module):
                """Decoder mirroring the 128x128 encoder.

                Projects the latent vector to a 512x4x4 map, then alternates
                residual blocks with stride-2 transposed convolutions and
                ends with a conv to ``in_channels`` plus a sigmoid.
                """

                def __init__(
                    self, latent_dim, in_channels, norm_layers, activation_fn
                ):
                    super().__init__()
                    self.latent_to_dec = nn.Linear(latent_dim, 512 * 4 * 4)
                    self.activation_fn = activation_fn
                    self.decoder_blocks = nn.Sequential(
                        ResidualBlock(512, 512),
                        nn.ConvTranspose2d(
                            512, 512, 3, stride=2, padding=1, output_padding=1
                        ),  # 4x4 -> 8x8
                        norm_layers["norm_512"],
                        activation_fn,
                        ResidualBlock(512, 512),
                        nn.ConvTranspose2d(
                            512, 256, 3, stride=2, padding=1, output_padding=1
                        ),  # 8x8 -> 16x16
                        norm_layers["norm_256"],
                        activation_fn,
                        ResidualBlock(256, 256),
                        nn.ConvTranspose2d(
                            256, 128, 3, stride=2, padding=1, output_padding=1
                        ),  # 16x16 -> 32x32
                        norm_layers["norm_128"],
                        activation_fn,
                        ResidualBlock(128, 128),
                        nn.ConvTranspose2d(
                            128, 64, 3, stride=2, padding=1, output_padding=1
                        ),  # 32x32 -> 64x64
                        norm_layers["norm_64"],
                        activation_fn,
                        ResidualBlock(64, 64),
                        nn.ConvTranspose2d(
                            64, 32, 3, stride=2, padding=1, output_padding=1
                        ),  # 64x64 -> 128x128
                        norm_layers["norm_32"],
                        activation_fn,
                        ResidualBlock(32, 32),
                        nn.Conv2d(32, in_channels, 3, stride=1, padding=1),
                        nn.Sigmoid(),
                    )

                def forward(self, z):
                    """Project ``z``, reshape to (batch, 512, 4, 4), decode."""
                    batch_size = z.shape[0]
                    out = self.latent_to_dec(z)
                    out = out.view(batch_size, 512, 4, 4)
                    return self.decoder_blocks(out)

            self.decoder = Decoder128x128(
                self.latent_dim,
                self.in_channels,
                norm_layers,
                self.activation_fn,
            )
        else:
            # Fallback for every other height/width. The Linear sizes below
            # hard-code a 128x4x4 encoder output, which matches the 32x32
            # progression noted in the inline comments.
            # NOTE(review): no size check is performed here — confirm callers
            # only reach this branch with compatible inputs.
            self.encoder = nn.Sequential(
                nn.Conv2d(self.in_channels, 32, 3, stride=1, padding=1),
                self._get_normalization_layer(32, is_conv=True),
                self.activation_fn,
                self._get_dropout_layer(),
                ResidualBlock(32, 32),
                ResidualBlock(32, 64, stride=2),  # 32x32 -> 16x16
                ResidualBlock(64, 64),
                ResidualBlock(64, 128, stride=2),  # 16x16 -> 8x8
                ResidualBlock(128, 128),
                ResidualBlock(128, 128, stride=2),  # 8x8 -> 4x4
                nn.Flatten(),
            )
            self.enc_to_latent1 = nn.Linear(128 * 4 * 4, self.latent_dim)
            self.enc_to_latent2 = nn.Linear(128 * 4 * 4, self.latent_dim)

            class Decoder32x32(nn.Module):
                """Decoder for the fallback (32x32) encoder.

                Projects the latent vector to a 128x4x4 map, then alternates
                residual blocks with stride-2 transposed convolutions and
                ends with a conv to ``in_channels`` plus a sigmoid.
                """

                # NOTE(review): unlike Decoder64x64/Decoder128x128, this
                # decoder takes no activation_fn and applies no activation
                # after each norm layer — confirm this is intentional.
                def __init__(self, latent_dim, in_channels, norm_layers):
                    super().__init__()
                    self.latent_to_dec = nn.Linear(latent_dim, 128 * 4 * 4)
                    self.decoder_blocks = nn.Sequential(
                        ResidualBlock(128, 128),
                        nn.ConvTranspose2d(
                            128,
                            128,
                            3,
                            stride=2,
                            padding=1,
                            output_padding=1,
                        ),  # 4x4 -> 8x8
                        norm_layers["norm_128"],
                        ResidualBlock(128, 128),
                        nn.ConvTranspose2d(
                            128,
                            64,
                            3,
                            stride=2,
                            padding=1,
                            output_padding=1,
                        ),  # 8x8 -> 16x16
                        norm_layers["norm_64"],
                        ResidualBlock(64, 64),
                        nn.ConvTranspose2d(
                            64,
                            32,
                            3,
                            stride=2,
                            padding=1,
                            output_padding=1,
                        ),  # 16x16 -> 32x32
                        norm_layers["norm_32"],
                        ResidualBlock(32, 32),
                        nn.Conv2d(32, in_channels, 3, stride=1, padding=1),
                        nn.Sigmoid(),
                    )

                def forward(self, z):
                    """Project ``z``, reshape to (batch, 128, 4, 4), decode."""
                    batch_size = z.shape[0]
                    out = self.latent_to_dec(z)
                    out = out.view(batch_size, 128, 4, 4)
                    return self.decoder_blocks(out)

            self.decoder = Decoder32x32(
                self.latent_dim, self.in_channels, norm_layers
            )

    def _build_mlp_model(self):
        # Encoder: simple MLP for text with configurable normalization
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, self.latent_dim * 4),
            self._get_normalization_layer(self.latent_dim * 4, is_conv=False),
            self.activation_fn,
            self._get_dropout_layer(),
            nn.Linear(self.latent_dim * 4, self.latent_dim * 4),
            self._get_normalization_layer(self.latent_dim * 4, is_conv=False),
        )

        self.enc_to_latent1 = nn.Linear(self.latent_dim * 4, self.latent_dim)
        self.enc_to_latent2 = nn.Linear(self.latent_dim * 4, self.latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim * 4),
            self._get_normalization_layer(self.latent_dim * 4, is_conv=False),
            self.activation_fn,
            self._get_dropout_layer(),
            nn.Linear(self.latent_dim * 4, self.input_dim),
        )
