import torch
from torch import nn, Tensor

# Pretrained model for toy datasets
class ToyMLP(nn.Module):
    """Small MLP mapping a token sequence (plus optional time input) to per-position logits.

    The input tokens are divided by ``vocab_size``. When ``time_dim > 0``, ``time_dim``
    time features are appended to the last axis. The result goes through a 4-hidden-layer
    ReLU MLP and is reshaped to ``(-1, length, vocab_size)``.

    Args:
        vocab_size: Number of token values; also the divisor applied to ``x``.
        hidden_dim: Width of the hidden linear layers.
        length: Number of positions per sequence (last dimension of ``x``).
        time_dim: Number of time features concatenated to the input; 0 disables
            time conditioning.

    Raises:
        ValueError: If ``time_dim`` is negative.
    """

    def __init__(
        self,
        vocab_size: int = 16,
        hidden_dim: int = 256,
        length: int = 2,
        time_dim: int = 1,
    ):
        super().__init__()
        if time_dim < 0:
            # A negative value would build a first Linear whose in_features does not
            # match any input forward() can produce.
            raise ValueError(f"time_dim must be non-negative, got {time_dim}")
        self.length = length
        self.time_dim = time_dim
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        # NOTE: the leading ReLU is kept on purpose. nn.Sequential names its children
        # by index, so removing it would shift every state_dict key and break loading
        # of existing checkpoints. It is a no-op for non-negative inputs.
        self.block = nn.Sequential(
            nn.ReLU(),
            nn.Linear(length + self.time_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, vocab_size * length),
        )

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        """Compute per-position logits.

        Args:
            x: Token tensor whose last dimension is ``length``. It is presumably integer
                ids in ``[0, vocab_size)``; TODO confirm against callers.
            t: Time input, used only when ``time_dim > 0``. It may be a scalar (tensor or
                Python number, shared by every sample), a tensor of shape
                ``x.shape[:-1]``, or a tensor of shape ``(*x.shape[:-1], time_dim)``.
                It is cast to ``x``'s dtype and device.

        Returns:
            Tensor of shape ``(-1, length, vocab_size)``.
        """
        x = x / self.vocab_size
        if self.time_dim > 0:
            lead = x.shape[:-1]
            # Match x's dtype/device so cat() does not promote to a dtype the Linear
            # weights cannot consume.
            t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
            if t.dim() == 0:
                # One shared time value for every sample and time feature.
                t = t.expand(*lead, self.time_dim)
            # Accepts (*lead,) or (*lead, time_dim); raises on a size mismatch.
            t = t.reshape(*lead, self.time_dim)
            h = torch.cat([x, t], dim=-1)
        else:
            h = x
        h = self.block(h)
        return h.view(-1, self.length, self.vocab_size)
