import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import math

from test_core import ModelWrapper, train, test

def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of *model* over *loader* as a percentage.

    The model is switched to eval mode and left there; callers that continue
    training afterwards must switch it back with ``model.train()``.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return 100. * correct / len(loader.dataset)


def run_experiment(model_type, hidden_dim, epochs=14, device="cpu"):
    """Train one model configuration on MNIST and track test accuracy.

    Args:
        model_type: Passed to ``ModelWrapper`` as its ``mode`` argument.
        hidden_dim: Hidden size of the model; must be positive because it
            also defines ``beta = 1 / sqrt(hidden_dim)``.
        epochs: Number of train-then-evaluate cycles to run.
        device: Device the model and every batch are moved to.

    Returns:
        A list holding the test accuracy (in percent) after each epoch, in
        epoch order. Empty when ``epochs`` is less than 1.

    Raises:
        ValueError: If ``hidden_dim`` is not positive.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (or a
    # "math domain error") from the beta computation below.
    if hidden_dim <= 0:
        raise ValueError(f"hidden_dim must be positive, got {hidden_dim}")

    print(f"\n>>> Starting Experiment: Model={model_type}, HiddenDim={hidden_dim}")

    params = {
        "mode": model_type,
        "input_dim": 784,
        "hidden_dim": hidden_dim,
        "num_classes": 10,
        "beta": 1/math.sqrt(hidden_dim),
        "num_states": 1,
        "num_memories": 64
    }

    # NOTE(review): 0.1307 / 0.3081 are presumably the usual MNIST mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_set = datasets.MNIST('./datasets', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./datasets', train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000)

    model = ModelWrapper(**params).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # step_size=1 with scheduler.step() once per epoch: LR is scaled by 0.7
    # after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

    # Minimal stand-in for the ``args`` object handed to ``train``; only the
    # two attributes set here are provided.
    class Args:
        log_interval = 100
        dry_run = False
    args = Args()

    acc_history = []
    for epoch in range(1, epochs + 1):
        # Evaluation leaves the model in eval mode. Restore training mode
        # explicitly rather than relying on ``train`` (imported from
        # test_core, not visible here) to do it.
        model.train()
        train(args, model, device, train_loader, optimizer, epoch)

        accuracy = _evaluate_accuracy(model, test_loader, device)
        acc_history.append(accuracy)
        scheduler.step()
        print(f"Epoch {epoch} Accuracy: {accuracy:.2f}%")

    return acc_history

def main():
    """Run every (hidden_dim, model) combination and plot the accuracy curves.

    One subplot is drawn per hidden dimension, with one line per model type.
    The figure is written to a PNG named after the hidden dimensions and then
    shown on screen.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Only the two attention variants are compared for now. The full list of
    # mode names kept here earlier was: kf_attention, kf_pooling, kf_layer,
    # hf_attention, hf_pooling, hf_layer.
    models = ['kf_attention', 'hf_attention']
    hidden_dims = [8, 32]
    epochs = 20

    # results[hidden_dim][model_type] -> per-epoch accuracy list.
    results = {
        dim: {name: run_experiment(name, dim, epochs, device) for name in models}
        for dim in hidden_dims
    }

    n_panels = len(hidden_dims)
    # squeeze=False keeps `axes` two-dimensional even for a single panel.
    fig, axes = plt.subplots(
        1, n_panels, figsize=(6 * n_panels, 5), squeeze=False
    )

    epoch_axis = range(1, epochs + 1)
    for ax, dim in zip(axes[0], hidden_dims):
        curves = results[dim]
        for name in models:
            ax.plot(epoch_axis, curves[name],
                    label=name, marker='o', markersize=4)

        ax.set_title(f'MNIST Accuracy (d={dim})')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Accuracy (%)')
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.legend(fontsize='small')

        # Start the y-axis at 70% unless some curve dips lower than that.
        lowest = min(min(curve) for curve in curves.values())
        ax.set_ylim(min(70, lowest), 100)

    plt.tight_layout()
    file_prefix = '_'.join(str(dim) for dim in hidden_dims)
    plt.savefig(f'{file_prefix}_mnist_dim_comparison1.png')
    plt.show()

# Run the full comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()