import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
import numpy as np
import pandas as pd
import math
import random

from test_core import ModelWrapper, train

def _mnist_loaders(data_root, batch_size, test_batch_size):
    """Build the (train_loader, test_loader) pair for MNIST under ``data_root``.

    The train loader shuffles; the test loader does not.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Fixed mean/std normalization; presumably the MNIST training-set
        # statistics used by the PyTorch MNIST example.
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_set = datasets.MNIST(data_root, train=True, download=True, transform=transform)
    # download=True is a no-op when the files already exist; it keeps the test
    # split loadable even if it is ever constructed first.
    test_set = datasets.MNIST(data_root, train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_batch_size)
    return train_loader, test_loader


def _evaluate_accuracy(model, device, loader):
    """Return top-1 accuracy of ``model`` over ``loader`` as a percentage.

    Puts the model in eval mode and runs without gradient tracking. The
    denominator is the number of samples actually yielded by the loader.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            seen += pred.size(0)

    if seen == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100.0 * correct / seen


def run_experiment(
    model_type,
    hidden_dim,
    seed,
    epochs=14,
    device="cpu",
    *,
    data_root="./datasets",
    batch_size=64,
    test_batch_size=1000,
    lr=0.001,
    gamma=0.7,
    num_memories=64,
):
    """Train one ``ModelWrapper`` on MNIST and return its test accuracy (%).

    Args:
        model_type: value passed to ``ModelWrapper`` as ``mode``.
        hidden_dim: hidden size; ``beta`` is set to ``1 / sqrt(hidden_dim)``.
        seed: seeds torch, numpy and Python's ``random`` before anything else.
        epochs: number of training epochs. The learning rate is multiplied by
            ``gamma`` after every epoch (``StepLR`` with ``step_size=1``).
        device: target passed to ``.to(...)`` (a string or ``torch.device``).
        data_root: directory where MNIST is stored or downloaded.
        batch_size: training batch size.
        test_batch_size: evaluation batch size.
        lr: Adam learning rate.
        gamma: per-epoch learning-rate decay factor.
        num_memories: value passed to ``ModelWrapper`` as ``num_memories``.

    Returns:
        Test-set top-1 accuracy as a float percentage in [0, 100].
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    params = {
        "mode": model_type,
        "input_dim": 784,  # 28 * 28 flattened MNIST pixels
        "hidden_dim": hidden_dim,
        "num_classes": 10,
        "beta": 1 / math.sqrt(hidden_dim),
        "num_states": 1,
        "num_memories": num_memories,
    }

    # Build loaders before the model so the RNG is consumed in the same
    # order as before (model init draws from the seeded torch RNG).
    train_loader, test_loader = _mnist_loaders(data_root, batch_size, test_batch_size)

    model = ModelWrapper(**params).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    # Minimal stand-in for an argparse namespace. ``train`` presumably reads
    # these attributes (PyTorch-MNIST-example style); confirm against
    # test_core.train.
    class Args:
        log_interval = 100
        dry_run = False
    args = Args()

    for epoch in range(1, epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        scheduler.step()

    return _evaluate_accuracy(model, device, test_loader)


def main():
    """Benchmark each (hidden_dim, model) combination across several seeds.

    The outer loop is over hidden sizes and the inner loop over model types.
    Each combination is trained ``num_trials`` times via ``run_experiment``,
    using consecutive seeds starting at 42. The mean and standard deviation
    of the final test accuracies are collected into a table. The table is
    written to ``mnist_attention_benchmark.csv`` and printed to stdout.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_names = ["kf_attention", "hf_attention"]
    dims = [4, 8, 16, 32, 64]
    num_trials = 5
    epochs = 14
    base_seed = 42
    banner = "#" * 70

    rows = []
    for dim in dims:
        for name in model_names:
            print("\n" + banner)
            print(f"Evaluating Model={name}, hidden_dim={dim}")
            print(banner)

            accs = []
            seeds = range(base_seed, base_seed + num_trials)
            for trial_idx, seed in enumerate(seeds, start=1):
                print(f"Trial {trial_idx}/{num_trials} (seed={seed})")
                score = run_experiment(
                    model_type=name,
                    hidden_dim=dim,
                    seed=seed,
                    epochs=epochs,
                    device=device,
                )
                accs.append(score)
                print(f"Final Accuracy: {score:.2f}%")

            # np.std defaults to ddof=0 (population std); kept so reported
            # numbers stay comparable with earlier runs.
            rows.append({
                "Model": name,
                "Hidden Dim": dim,
                "Mean Accuracy": f"{np.mean(accs):.2f}%",
                "Std Dev": f"{np.std(accs):.2f}%",
            })

    summary = pd.DataFrame(rows)
    summary.to_csv("mnist_attention_benchmark.csv", index=False)

    rule = "=" * 40
    print("\n" + rule)
    print("FINAL MNIST ATTENTION BENCHMARK")
    print(rule)
    print(summary.to_string(index=False))


if __name__ == "__main__":
    # Run the full benchmark only when executed as a script, not on import.
    main()
