import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence, Union, List

class MLPResBlock(nn.Module):
    '''
    Pre-norm residual MLP block.

    Computes ``x + lin2(drop(act(lin1(norm(x)))))``: LayerNorm, a Linear that
    widens ``dim`` to ``dim * expansion``, SiLU, Dropout, and a Linear that
    projects back to ``dim``. The result is added to the unmodified input.
    '''
    def __init__(self, dim: int, dropout: float = 0.1, expansion: int = 4):
        super().__init__()
        inner_dim = dim * expansion
        self.norm = nn.LayerNorm(dim)
        self.lin1 = nn.Linear(dim, inner_dim)
        self.act = nn.SiLU()  # plain SiLU; there is no separate gate branch
        self.lin2 = nn.Linear(inner_dim, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        # Transform branch, one stage per line; `x` itself is left untouched
        # so it can serve as the skip path.
        out = self.norm(x)
        out = self.lin1(out)
        out = self.act(out)
        out = self.drop(out)
        out = self.lin2(out)
        return x + out


class AE(nn.Module):
    '''
    MLP autoencoder with an optional weight-tied output layer.

    The encoder maps ``in_dim`` through each width in ``hidden_dims`` and then
    to ``latent``; the decoder mirrors it through the reversed widths back to
    ``in_dim``. Every hidden stage is Linear -> LayerNorm -> SiLU -> Dropout
    followed by ``residual_depth`` ``MLPResBlock`` modules.

    Args:
        in_dim: Size of the last dimension of the input.
        latent: Size of the latent code.
        hidden_dims: One width or a sequence of widths for the hidden stages.
            An empty sequence gives a single-Linear encoder.
        dropout: Dropout probability used in every hidden stage.
        tie_weights: If True, the output projection reuses the transposed
            weight of the first encoder Linear plus a separate learned bias
            instead of owning its own weight matrix.
        residual_depth: Number of residual blocks appended to each stage.
        latent_l2: Coefficient applied by ``reg_loss``; values <= 0 disable it.
    '''
    def __init__(
        self,
        in_dim: int,
        latent: int = 256,
        hidden_dims: Union[int, Sequence[int]] = (512, 512),
        dropout: float = 0.1,
        tie_weights: bool = False,
        residual_depth: int = 1,
        latent_l2: float = 1e-4,
    ):
        super().__init__()

        if isinstance(hidden_dims, int):
            hidden_dims = [hidden_dims]
        hidden_dims = list(hidden_dims)

        # Encoder: in_dim -> hidden_dims... -> latent
        enc_layers, prev = self._build_stack(
            in_dim, hidden_dims, dropout, residual_depth
        )
        self.enc_to_latent = nn.Linear(prev, latent, bias=True)
        self.encoder = nn.Sequential(*enc_layers, self.enc_to_latent)

        # Decoder: latent -> reversed(hidden_dims)...; the projection back to
        # in_dim is handled separately because it may be weight-tied.
        dec_layers, prev = self._build_stack(
            latent, hidden_dims[::-1], dropout, residual_depth
        )
        self.decoder_hidden = nn.Sequential(*dec_layers)

        self.tie_weights = tie_weights
        if not tie_weights:
            self.decoder_out = nn.Linear(prev, in_dim, bias=True)
        else:
            # Only `.bias` of this layer is used. It stays an nn.Linear so the
            # attribute path and state-dict keys are unchanged; its weight never
            # takes part in forward(), so it is frozen to keep it out of
            # gradient bookkeeping.
            self.decoder_bias = nn.Linear(latent, in_dim, bias=True)
            self.decoder_bias.weight.requires_grad_(False)

        self.latent_l2 = latent_l2  # latent regularisation weight

    @staticmethod
    def _build_stack(in_features: int, widths: Sequence[int], dropout: float,
                     residual_depth: int):
        '''Build the hidden stages for `widths`; return (layers, output width).'''
        layers: List[nn.Module] = []
        prev = in_features
        for width in widths:
            layers.extend(
                [
                    nn.Linear(prev, width),
                    nn.LayerNorm(width),
                    nn.SiLU(),
                    nn.Dropout(dropout),
                ]
            )
            layers.extend(MLPResBlock(width, dropout) for _ in range(residual_depth))
            prev = width
        return layers, prev

    def forward(self, x: torch.Tensor, *, normalise_latent: bool = True):
        '''
        Encode `x` and reconstruct it.

        Args:
            x: Input whose last dimension is ``in_dim``.
            normalise_latent: If True, scale each latent vector to unit L2
                norm (with a 1e-8 stabiliser) before decoding.

        Returns:
            Tuple ``(z, x_hat)``: the latent code that was fed to the decoder
            (normalised when requested) and the reconstruction.
        '''
        z = self.encoder(x)
        if normalise_latent:
            z = z / (z.norm(dim=-1, keepdim=True) + 1e-8)

        h = self.decoder_hidden(z)

        if self.tie_weights:
            # Tie to the FIRST encoder Linear: its weight has shape
            # (width_of_h, in_dim), so its transpose maps h back to in_dim.
            # With no hidden stages, encoder[0] is enc_to_latent itself.
            tied = self.encoder[0]
            x_hat = F.linear(h, tied.weight.t(), self.decoder_bias.bias)
        else:
            x_hat = self.decoder_out(h)

        return z, x_hat

    def reg_loss(self) -> torch.Tensor:
        '''
        Return ``latent_l2 * mean(enc_to_latent.weight ** 2)``.

        When ``latent_l2 <= 0`` the result is a zero with the same 0-dim
        shape, device and dtype as the enabled case.
        '''
        weight = self.enc_to_latent.weight
        if self.latent_l2 <= 0:
            return weight.new_zeros(())
        return self.latent_l2 * weight.pow(2).mean()
