# test.py

import os
import torch
import torch.nn as nn
from data.generate_data import generate_toy_data
from models import AntiSymm21Model, PermEquiv21Model, SimpleMLP21

# Select the compute device: use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Evaluation settings
n = 8  # passed to generate_toy_data and to each model constructor
num_test_samples = 1000
seed = 42
model_dir = "saved_models"  # directory the checkpoints are read from

# Loss used to score predictions against the targets
criterion = nn.MSELoss()

# Generate the test set and move both tensors onto the selected device
X_test, y_test = (
    tensor.to(device)
    for tensor in generate_toy_data(num_test_samples, n, seed)
)

# Pair each freshly constructed model with the label used both for
# reporting and as the stem of its checkpoint filename.
model_classes = (
    ("AntiSymm21Model", AntiSymm21Model),
    ("PermEquiv21Model", PermEquiv21Model),
    ("SimpleMLP21", SimpleMLP21),
)
models_info = [
    (model_cls(n), f"{label}({n})")
    for label, model_cls in model_classes
]

# Evaluate each model: load its saved weights, run one forward pass over the
# whole test set, and report the mean-squared error.
for model, name in models_info:
    filepath = os.path.join(model_dir, name + "_diag.pth")
    # map_location remaps tensors that were saved on a different device
    # (e.g. a CUDA checkpoint opened on a CPU-only machine); without it
    # torch.load tries to restore them onto the original device and fails.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints; consider weights_only=True if the installed torch
    # version supports it.
    state_dict = torch.load(filepath, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # evaluation mode for inference

    # No gradients are needed for evaluation.
    with torch.no_grad():
        y_pred = model(X_test)
        loss = criterion(y_pred, y_test).item()

    print(f"[{name}] Test MSE: {loss:.4f}")
