import torch
import torch.nn as nn


class ResNetBlock(nn.Module):
    """Pre-norm residual MLP block: (LayerNorm -> GELU -> Linear) twice, plus a skip.

    Works on the last axis of its input, which must have size ``hidden_dim``;
    the inner layer is ``2 * hidden_dim`` wide. The output has the same shape
    as the input.

    Args:
        hidden_dim: feature size of the input and output.
        inp_cond_dim: stored as ``self.inp_cond_dim`` but not used by this block.
    """

    def __init__(self, hidden_dim=512, inp_cond_dim=None):
        super().__init__()
        expanded = hidden_dim * 2
        self.inp_cond_dim = inp_cond_dim

        # Creation order (ln1, mlp1, ln2, mlp2, act) is deliberate: it fixes the
        # parameter ordering and the RNG draws used for Linear initialisation.
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.mlp1 = nn.Linear(hidden_dim, expanded)
        self.ln2 = nn.LayerNorm(expanded)
        self.mlp2 = nn.Linear(expanded, hidden_dim)
        self.act = nn.GELU()

    def forward(self, x, cond=None):
        """Run both pre-norm layers on ``x`` and add ``x`` back.

        ``cond`` is accepted for interface compatibility and ignored.
        """
        out = x
        for norm, linear in ((self.ln1, self.mlp1), (self.ln2, self.mlp2)):
            out = linear(self.act(norm(out)))
        # `out` is a freshly computed tensor at this point, so the in-place
        # skip connection never modifies the caller's `x`.
        out += x
        return out


class Encoder(nn.Module):
    """Encoder with ResNet architecture. Maps input to a low-dim latent representation.

    The input ``x`` is flattened to ``vocab_size * seq_len`` features per batch
    item and projected to ``hidden_dim``. When the encoder is ``conditional``
    and a ``cond`` tensor is passed to ``forward``, a second linear projection
    of the flattened ``cond`` is added to the input projection. The sum is then
    passed through ``num_blocks`` residual MLP blocks.

    Args:
        vocab_size: per-position feature count of ``x`` (presumably
            ``x`` is ``(batch, seq_len, vocab_size)``, matching the Decoder
            output layout -- confirm against callers).
        seq_len: number of positions in ``x``.
        hidden_dim: width of the latent output.
        num_blocks: number of ``ResNetBlock`` layers applied after the
            input projection.
        conditional: if True, build ``mlp_xp`` so ``forward`` can use ``cond``.
        n_tokens_cond: if given, ``cond`` must hold ``n_tokens_cond * seq_len``
            elements per batch item; otherwise ``vocab_size * seq_len``.
    """

    def __init__(
        self,
        vocab_size=20,
        seq_len=28,
        hidden_dim=512,
        num_blocks=1,
        conditional: bool = True,
        n_tokens_cond=None,  # per-position width of `cond` (defaults to vocab_size)
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len

        # init embedding (creation order mlp_x -> mlp_xp -> blocks is kept so
        # random initialisation is unchanged)
        self.mlp_x = nn.Linear(self.vocab_size * self.seq_len, hidden_dim)
        self.conditional = conditional
        if self.conditional:
            cond_width = vocab_size if n_tokens_cond is None else n_tokens_cond
            self.mlp_xp = nn.Linear(cond_width * self.seq_len, hidden_dim)

        # resnet blocks
        self.blocks = nn.ModuleList([ResNetBlock(hidden_dim) for _ in range(num_blocks)])

    def forward(self, x, cond=None):
        """Encode ``x``, optionally conditioned on ``cond``, into ``(batch, hidden_dim)``.

        Args:
            x: tensor with ``vocab_size * seq_len`` elements per batch item.
            cond: optional conditioning tensor, ignored unless the encoder was
                built with ``conditional=True``. Any layout whose per-item
                element count equals ``mlp_xp.in_features`` is accepted, e.g.
                ``(batch, seq_len, n_tokens)`` or an already-flat
                ``(batch, n_tokens * seq_len)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``.
        """
        # TODO: add positional embedding etc.
        # flatten embedding
        x_flat = x.reshape(x.shape[0], self.vocab_size * self.seq_len)
        out = self.mlp_x(x_flat)
        if self.conditional and cond is not None:
            # Flatten to exactly the width mlp_xp was built for. Unlike indexing
            # cond.shape[1] * cond.shape[2], this also accepts pre-flattened
            # (2-D) cond and still works for an empty batch.
            cond_flat = cond.reshape(cond.shape[0], self.mlp_xp.in_features)
            # forward individual encoders and add them
            out = out + self.mlp_xp(cond_flat)

        for block in self.blocks:
            out = block(out)
        return out

class Decoder(nn.Module):
    """Decoder with ResNet architecture.

    Maps a latent low-dim representation back to full-sized input: the latent
    is projected from ``enc_hidden_dim`` to ``hidden_dim``, passed through
    ``num_blocks`` residual MLP blocks, then projected to
    ``vocab_size * seq_len`` values and reshaped to
    ``(batch, seq_len, vocab_size)`` logits.

    Args:
        vocab_size: size of the last axis of the returned logits.
        seq_len: size of the middle axis of the returned logits.
        enc_hidden_dim: feature size of the latent input ``z``.
        hidden_dim: internal width of the decoder.
        num_blocks: number of ``ResNetBlock`` layers.
        inp_cond_dim: forwarded to every ``ResNetBlock`` (which stores it
            but does not use it).
    """

    def __init__(
        self,
        vocab_size=20,
        seq_len=28,
        enc_hidden_dim=256,
        hidden_dim=256,
        num_blocks=1,
        inp_cond_dim=None,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len

        # Submodules are created in the order mlp -> blocks -> proj so that
        # parameter ordering and random initialisation are fixed.
        # latent width -> decoder width
        self.mlp = nn.Linear(enc_hidden_dim, hidden_dim)
        self.blocks = nn.ModuleList(
            [ResNetBlock(hidden_dim, inp_cond_dim=inp_cond_dim) for _ in range(num_blocks)]
        )
        # decoder width -> one value per (position, vocab entry)
        self.proj = nn.Linear(hidden_dim, vocab_size * seq_len)

    def forward(self, z):
        """Decode latent ``z`` into logits of shape ``(batch, seq_len, vocab_size)``."""
        hidden = self.mlp(z)
        for layer in self.blocks:
            hidden = layer(hidden)
        flat_logits = self.proj(hidden)
        batch = flat_logits.shape[0]
        return flat_logits.reshape(batch, self.seq_len, self.vocab_size)

