import torch
import torch.nn as nn
import torch.nn.functional as F

class Embedding(nn.Module):
    """Project raw input features into the model's embedding space.

    A single learned affine map over the last dimension:
    ``(..., inp_dim) -> (..., emb_dim)``.
    """

    def __init__(self, inp_dim, emb_dim):
        super(Embedding, self).__init__()
        # Attribute name is part of the state_dict key ("linear.weight").
        self.linear = nn.Linear(inp_dim, emb_dim)

    def forward(self, x):
        projected = self.linear(x)
        return projected

class Unembedding(nn.Module):
    """Map embeddings back to the output feature space.

    A single learned affine map over the last dimension:
    ``(..., emb_dim) -> (..., out_dim)``.
    """

    def __init__(self, emb_dim, out_dim):
        super(Unembedding, self).__init__()
        # Attribute name is part of the state_dict key ("linear.weight").
        self.linear = nn.Linear(emb_dim, out_dim)

    def forward(self, x):
        features = self.linear(x)
        return features

class Multiplyer(nn.Module):
    """Expand one vector per item into a sequence of ``out_seq_len`` embeddings.

    A single linear layer maps the last dimension from ``in_dim`` to
    ``emb_dim * out_seq_len``. The result is then reshaped so that the final
    two dimensions are ``(out_seq_len, emb_dim)``.

    Args:
        in_dim: size of the input's last dimension.
        emb_dim: size of each produced embedding.
        out_seq_len: number of embeddings produced per input vector.
    """

    def __init__(self, in_dim, emb_dim, out_seq_len):
        super().__init__()
        self.linear = nn.Linear(in_dim, emb_dim * out_seq_len)
        self.out_seq_len = out_seq_len
        self.emb_dim = emb_dim

    def forward(self, x):
        """Map ``(..., in_dim)`` to ``(..., out_seq_len, emb_dim)``.

        Any leading dimensions (typically a single batch dimension) are
        preserved unchanged.
        """
        leading = x.shape[:-1]
        x = self.linear(x)
        # Split the flat last dim into (out_seq_len, emb_dim); reshape is
        # used so that a non-contiguous linear output cannot make this fail.
        return x.reshape(*leading, self.out_seq_len, self.emb_dim)

class FeedForward(nn.Module):
    """Two-layer ReLU MLP with a residual connection.

    Computes ``fc2(relu(fc1(x))) + x``; the output has the same shape as ``x``.
    """

    def __init__(self, emb_dim, mlp_dim):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(emb_dim, mlp_dim)  # expand to the hidden width
        self.fc2 = nn.Linear(mlp_dim, emb_dim)  # project back so the skip add lines up

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        update = self.fc2(hidden)
        return update + x

class WeightedMultiheadAttention(nn.Module):
    """Self-attention wrapper with a slot for per-position weights.

    Placeholder: ``weights`` is accepted so callers can pass it through, but it
    is currently IGNORED and plain ``nn.MultiheadAttention`` is used. Implement
    weighted attention here if your use-case requires it.

    The inner module is built with ``batch_first=True``, so ``x`` is laid out as
    ``(batch, seq, emb_dim)``.
    """

    def __init__(self, emb_dim, num_heads, dropout=0.1):
        super(WeightedMultiheadAttention, self).__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=emb_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, x, weights=None):
        # Query, key and value are all ``x`` (self-attention). The second
        # return value (attention weights) is discarded.
        return self.attn(x, x, x)[0]

class EncoderBlock(nn.Module):
    """Post-norm encoder layer.

    Pipeline: self-attention with a residual add, LayerNorm, FeedForward,
    LayerNorm. No extra skip is added around ``ff`` here because FeedForward
    already adds its own residual. ``weights`` is forwarded to the attention
    sub-module.
    """

    def __init__(self, emb_dim, num_heads, mlp_dim, dropout):
        super(EncoderBlock, self).__init__()
        # Keep this construction order: it fixes parameter registration order
        # (state_dict / optimizer layout) and RNG consumption during init.
        self.attn = WeightedMultiheadAttention(emb_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.ff = FeedForward(emb_dim, mlp_dim)
        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x, weights=None):
        normed = self.norm1(self.attn(x, weights) + x)
        return self.norm2(self.ff(normed))

class DecoderBlock(nn.Module):
    """Post-norm decoder layer (same layout as the encoder layer, no weights).

    Pipeline: self-attention with a residual add, LayerNorm, FeedForward,
    LayerNorm. FeedForward already contains its own residual, so no extra
    skip is added around ``ff``.
    """

    def __init__(self, emb_dim, num_heads, mlp_dim, dropout):
        super(DecoderBlock, self).__init__()
        # Keep this construction order: it fixes parameter registration order
        # (state_dict / optimizer layout) and RNG consumption during init.
        self.attn = WeightedMultiheadAttention(emb_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.ff = FeedForward(emb_dim, mlp_dim)
        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        after_attn = self.norm1(self.attn(x) + x)
        after_ff = self.ff(after_attn)
        return self.norm2(after_ff)

class EncoderModel(nn.Module):
    """Encode a sequence into one pooled vector per batch item.

    Pipeline: Embedding, a stack of EncoderBlocks (each receives ``weights``),
    pooling over dim 1, then a residual FeedForward (``mlp_out``).

    Input ``x`` is presumably ``(batch, seq, inp_dim)``; pooling reduces
    dim 1. ``weights``, when given, is ``(batch, seq)`` per-position weights.
    """

    def __init__(self, inp_dim, emb_dim, num_layers, num_heads, mlp_dim, dropout):
        super().__init__()
        self.embedding = Embedding(inp_dim, emb_dim)
        self.blocks = nn.ModuleList([
            EncoderBlock(emb_dim, num_heads, mlp_dim, dropout) for _ in range(num_layers)
        ])
        self.mlp_out = FeedForward(emb_dim, mlp_dim)

    def forward(self, x, weights=None):
        """Return the pooled encoding, shape ``(batch, emb_dim)``.

        With ``weights`` this is the weighted mean ``sum(x * w) / sum(w)`` over
        the sequence; otherwise a plain mean. A row whose weights sum to zero
        (e.g. a fully masked sequence) pools to the zero vector instead of NaN.
        """
        x = self.embedding(x)
        for block in self.blocks:
            x = block(x, weights)
        if weights is not None:
            # Cast to the activation dtype. Bool/int masks then work, and
            # float64 weights cannot promote the result past mlp_out's dtype.
            ws = weights.unsqueeze(-1).to(x.dtype)  # (batch, seq, 1)
            denom = ws.sum(dim=1)                   # (batch, 1)
            # Only an exactly-zero denominator is replaced; its numerator is 0
            # too, so that row becomes 0 rather than 0/0 = NaN.
            denom = denom.masked_fill(denom == 0, 1)
            x = (x * ws).sum(dim=1) / denom
        else:
            x = x.mean(dim=1)
        return self.mlp_out(x)

class DecoderModel(nn.Module):
    """Decode a pooled vector back into an output sequence.

    Pipeline: Multiplyer (vector to ``out_seq_len`` embeddings), a stack of
    DecoderBlocks, a residual FeedForward, LayerNorm, then Unembedding to
    ``out_dim`` features per position.
    """

    def __init__(self, emb_dim, out_dim, out_seq_len, num_layers, num_heads, mlp_dim, dropout):
        super(DecoderModel, self).__init__()
        # Keep this construction order: it fixes parameter registration order
        # (state_dict / optimizer layout) and RNG consumption during init.
        self.multiplyer = Multiplyer(emb_dim, emb_dim, out_seq_len)
        self.blocks = nn.ModuleList(
            [DecoderBlock(emb_dim, num_heads, mlp_dim, dropout) for _layer in range(num_layers)]
        )
        self.mlp_out = FeedForward(emb_dim, mlp_dim)
        self.norm = nn.LayerNorm(emb_dim)
        self.unembedding = Unembedding(emb_dim, out_dim)

    def forward(self, x):
        hidden = self.multiplyer(x)
        for layer in self.blocks:
            hidden = layer(hidden)
        hidden = self.norm(self.mlp_out(hidden))
        return self.unembedding(hidden)


class TransformerAutoencoder(nn.Module):
    """Transformer autoencoder: input sequence, pooled code, reconstructed sequence.

    Reads from ``config``: ``emb_dim``, ``num_layers``, ``num_heads``,
    ``mlp_dim``, ``attention_dropout_rate``, plus the optional ``scale_out``
    (default True), ``min_val`` (default -1) and ``max_val`` (default 1).
    The same depth, width and dropout settings are used for encoder and decoder.
    """

    def __init__(self, config, seq_len, inp_dim):
        super().__init__()
        self.encoder = EncoderModel(
            inp_dim=inp_dim,
            emb_dim=config.emb_dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            mlp_dim=config.mlp_dim,
            dropout=config.attention_dropout_rate
        )
        self.decoder = DecoderModel(
            emb_dim=config.emb_dim,
            out_dim=inp_dim,
            out_seq_len=seq_len,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            mlp_dim=config.mlp_dim,
            dropout=config.attention_dropout_rate
        )
        # Optional output squashing; see forward().
        self.scale_out = getattr(config, "scale_out", True)
        self.min_val = getattr(config, "min_val", -1)
        self.max_val = getattr(config, "max_val", 1)

    def forward(self, x, weights=None):
        """Return ``(enc, dec)``: the encoder output and the decoder reconstruction.

        ``weights`` is forwarded to the encoder only. When ``scale_out`` is set,
        ``dec`` is passed through a sigmoid and rescaled into the open interval
        ``(min_val, max_val)``.
        """
        enc = self.encoder(x, weights)
        dec = self.decoder(enc)
        if self.scale_out:
            dec = torch.sigmoid(dec) * (self.max_val - self.min_val) + self.min_val
        return enc, dec

    def compute_loss(self, x, weights=None, x_pair=None, enc_loss_func=None, dec_loss_func=None,
                     alpha=1.0, beta=1.0, weights_pair=None):
        """Combine an encoder loss and a reconstruction loss.

        Returns ``(total_loss, enc_loss, dec_loss)`` where
        ``total_loss = alpha * enc_loss + beta * dec_loss``. All three are tensors.

        Args:
            x: input batch; also the reconstruction target.
            weights: optional per-position weights forwarded to the encoder.
            x_pair: optional paired input. Only its encoding is computed.
            enc_loss_func: callable(enc, enc_pair), expected to return a scalar
                tensor. The encoder term is computed only when both this and
                ``x_pair`` are given; otherwise ``enc_loss`` is a zero tensor.
            dec_loss_func: callable(dec, x), i.e. (prediction, target) in the
                same order as ``F.mse_loss``, which is the default.
            alpha: weight of the encoder term.
            beta: weight of the reconstruction term.
            weights_pair: optional weights for ``x_pair``; defaults to ``weights``.
        """
        enc, dec = self.forward(x, weights)
        # Zero tensor (not a Python float) so callers can always .item()/.detach().
        enc_loss = dec.new_zeros(())
        if enc_loss_func is not None and x_pair is not None:
            pair_weights = weights if weights_pair is None else weights_pair
            # forward() returns the encoder output unchanged, so running only
            # the encoder yields the same enc_pair without a wasted decoder pass.
            enc_pair = self.encoder(x_pair, pair_weights)
            enc_loss = enc_loss_func(enc, enc_pair)
        if dec_loss_func is not None:
            dec_loss = dec_loss_func(dec, x)
        else:
            dec_loss = F.mse_loss(dec, x)
        total_loss = alpha * enc_loss + beta * dec_loss
        return total_loss, enc_loss, dec_loss
