import torch
import torch.nn as nn

class SimpleUNet(nn.Module):
    """Time-conditioned MLP mapping ``(x, t)`` to a vector the size of ``x``.

    Despite the name, this is not a U-Net: it is a three-layer fully connected
    network (no convolutions, no skip connections). The timestep is appended
    to each input sample as one extra feature before the first layer.

    Args:
        data_dim: Number of features per sample in ``x`` and in the output.
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, data_dim=2, hidden_dim=64):
        super().__init__()
        # Keep the (Linear, ReLU, Linear, ReLU, Linear) layout stable: the
        # state_dict keys are net.0 / net.2 / net.4, so inserting or reordering
        # layers would break previously saved checkpoints.
        self.net = nn.Sequential(
            nn.Linear(data_dim + 1, hidden_dim),  # data features + 1 time feature
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim),
        )

    def forward(self, x, t):
        """Run the network on ``x`` conditioned on timestep(s) ``t``.

        Args:
            x: Floating-point tensor of shape ``(..., data_dim)``.
            t: Timesteps. This can be a tensor broadcastable to
                ``x.shape[:-1]`` (e.g. shape ``(batch,)``), a 0-d tensor, or a
                Python number. Integer timesteps are accepted; they are cast
                to ``x.dtype``.

        Returns:
            Tensor of shape ``(..., data_dim)``.
        """
        # Cast explicitly rather than relying on torch.cat's dtype promotion
        # (rejected by older PyTorch), and follow x's device so CPU timesteps
        # can be combined with an input on another device.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        # Broadcast a scalar / shared timestep across the batch, then add the
        # trailing feature axis so it can be concatenated with x.
        t = t.expand(x.shape[:-1]).unsqueeze(-1)
        # NOTE(review): t is fed in unscaled (integers up to 999 in the demo),
        # which is large relative to typical input magnitudes; consider
        # normalizing t or using a sinusoidal embedding — TODO confirm against
        # the training code before changing, as it alters model outputs.
        return self.net(torch.cat([x, t], dim=-1))


if __name__ == "__main__":
    # Smoke test: one forward pass on a random batch, then show the output shape.
    n_samples = 8
    points = torch.randn(n_samples, 2)              # random 2-D inputs
    steps = torch.randint(0, 1000, (n_samples,))    # integer timesteps in [0, 1000)
    net = SimpleUNet()
    prediction = net(points, steps)
    print(prediction.shape)  # torch.Size([8, 2])