import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from data.generate_data import generate_toy_data
from models import AntiSymm21Model
import numpy as np  # For std dev

# Device-agnostic setup: use the GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Config
# n values to evaluate at; a fresh model is built for each one.
target_ns = [8, 16, 24, 32, 64]
# n of the checkpoint whose weights are copied into every target model.
source_n = 8
num_test_tuples = 2_000     # test examples generated per n
batch_size = 20             # evaluation batch size
model_dir = "saved_models"  # directory holding the checkpoint files
criterion = nn.MSELoss(reduction="mean")  # "mean" is PyTorch's default reduction

def evaluate_model(model, dataloader, *, loss_fn=None, eval_device=None):
    """Evaluate ``model`` over every batch of ``dataloader`` without gradients.

    Args:
        model: Module to evaluate. It is moved to the evaluation device and
            switched to eval mode; both changes persist after the call.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs.
        loss_fn: Per-batch loss. Defaults to the module-level ``criterion``.
        eval_device: Device to run on. Defaults to the module-level ``device``.

    Returns:
        ``(avg_loss, rmse, std_dev)``:
        - ``avg_loss``: per-batch losses weighted by batch size
          (``inputs.size(0)``). This is the per-sample mean when each batch
          loss is itself a mean over its samples, as with ``nn.MSELoss()``.
        - ``rmse``: square root of ``avg_loss``.
        - ``std_dev``: population standard deviation of the unweighted
          per-batch losses.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    import math

    # Conditional expressions only touch the globals when no override is given.
    loss_fn = criterion if loss_fn is None else loss_fn
    eval_device = device if eval_device is None else eval_device

    model.to(eval_device)
    model.eval()
    total_loss = 0.0
    total_samples = 0
    batch_losses = []
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(eval_device), targets.to(eval_device)
            batch_loss = loss_fn(model(inputs), targets).item()
            n_samples = inputs.size(0)
            batch_losses.append(batch_loss)
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += batch_loss * n_samples
            total_samples += n_samples

    if total_samples == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("evaluate_model: dataloader yielded no samples")

    avg_loss = total_loss / total_samples
    # math.sqrt keeps float64 precision (a float32 tensor round-trip truncated it).
    rmse = math.sqrt(avg_loss)
    std_dev = float(np.std(batch_losses))
    return avg_loss, rmse, std_dev

def load_weights_from_source_n(target_model, model_class_name, *, weights_dir=None, from_n=None):
    """Load weights from the n=source_n version of the same model, only where names and shapes match.

    The checkpoint is read from
    ``<weights_dir>/<model_class_name lowercased>(<from_n>)_diag.pth``.
    A tensor is copied only when its key exists in ``target_model.state_dict()``
    with an identical shape. Every other parameter or buffer keeps its current value.

    Args:
        target_model: Module that is updated in place.
        model_class_name: Name used to build the checkpoint filename.
        weights_dir: Checkpoint directory. Defaults to the module-level ``model_dir``.
        from_n: ``n`` of the checkpoint to load. Defaults to the module-level ``source_n``.

    Returns:
        Number of tensors copied into ``target_model``. This is 0 when the
        checkpoint is missing or nothing matched.
    """
    weights_dir = model_dir if weights_dir is None else weights_dir
    from_n = source_n if from_n is None else from_n

    source_path = os.path.join(weights_dir, f"{model_class_name.lower()}({from_n})_diag.pth")
    if not os.path.exists(source_path):
        print(f"⚠️ Source weights for {model_class_name}({from_n}) not found.")
        return 0

    # map_location="cpu": a checkpoint saved during a CUDA run must still load on a
    # CPU-only machine; load_state_dict then copies values onto the model's device.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints
    # (consider weights_only=True once the minimum supported torch allows it).
    source_state = torch.load(source_path, map_location="cpu")
    target_state = target_model.state_dict()

    # Copy only tensor entries whose name and shape match the target model.
    compatible_weights = {
        k: v
        for k, v in source_state.items()
        if isinstance(v, torch.Tensor) and k in target_state and v.shape == target_state[k].shape
    }
    if not compatible_weights:
        print(f"⚠️ No compatible weights in {source_path}; keeping initial weights.")
        return 0

    untouched = [k for k in target_state if k not in compatible_weights]
    target_state.update(compatible_weights)
    target_model.load_state_dict(target_state)
    print(f"✅ Loaded weights from {model_class_name}({from_n})")
    if untouched:
        print(f"   {len(untouched)} target tensor(s) kept initial values (name/shape mismatch).")
    return len(compatible_weights)

# Test models across different n: for each target n, build a fresh test set
# and model, copy in the n=source_n weights, and report the test metrics.
for n in target_ns:
    print(f"\n--- Evaluating models for n = {n} ---")

    # The same seed (42) is passed for every n.
    X_test, y_test = generate_toy_data(num_test_tuples, n, seed=42)
    X_test = X_test.to(device)
    y_test = y_test.to(device)
    test_dataset = TensorDataset(X_test, y_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Evaluate AntiSymm21Model
    model = AntiSymm21Model(n)
    load_weights_from_source_n(model, "AntiSymm21Model")
    loss, rmse, std_dev = evaluate_model(model, test_loader)
    summary = (
        f"AntiSymm21Model({n}) - Test MSE: {loss:.6f}, "
        f"RMSE: {rmse:.6f}, Std Dev (MSE): {std_dev:.6f}"
    )
    print(summary)
