"""
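    Reproduce the pruning experiment on MLPs of different sizes: small,
    medium and large MLPs are trained on MNIST, pruned at increasing rates,
    and evaluated on clean test data and under PGD and FGSM attacks.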
"""
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

from attacks import *
from pruning_utils import norm_prune_weights
from models import SmallMLP, MediumMLP, LargeMLP

# Select the compute device once at import time: CUDA when present, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def train(model, device, loader, optimizer, epoch):
    """
        Run one training epoch of ``model`` over ``loader``.

        The model is put in training mode, then every ``(inputs, labels)``
        batch is moved to ``device``, scored with cross-entropy, and used for
        a single ``optimizer`` step. Every 200th batch (starting with the
        first) the epoch number, samples seen so far and batch loss are
        printed. Returns ``None``.
    """
    log_every = 200
    model.train()
    for step, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = F.cross_entropy(logits, labels)
        batch_loss.backward()
        optimizer.step()
        if step % log_every != 0:
            continue
        seen = step * len(inputs)
        print(f"Epoch {epoch} [{seen}/{len(loader.dataset)}] Loss: {batch_loss.item():.6f}")

def test_clean(model, device, loader):
    """
        Evaluation function Loop
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total   += target.size(0)
    return 100. * correct / total


# --- Main experiment ---
def _pruning_curves(net, loader, p_values, attacks, eps, alpha, iters, name):
    """
        Evaluate ``net`` pruned at each rate in ``p_values`` under every attack.

        For each ``p`` one target model is built: ``net`` itself when
        ``p == 0.0``, otherwise ``norm_prune_weights`` applied to a deep copy.
        Its clean accuracy is measured once and shared by all attacks, and
        each attack is run through ``test_transfer`` with the unpruned ``net``
        as the source model and the pruned copy as the target.

        Returns:
            dict mapping each attack name in ``attacks`` to a list of
            ``(retention, clean_acc, attacked_acc)`` tuples, one per ``p``,
            where ``retention = 1 - p`` and
            ``attacked_acc = 100 - test_transfer(...)``.
    """
    curves = {attack_name: [] for attack_name, _ in attacks}
    for p in p_values:
        # p == 0.0 is the unpruned baseline; every other rate prunes a fresh
        # deep copy so ``net`` keeps its trained weights for the next rate.
        if p == 0.0:
            target = net
        else:
            target = norm_prune_weights(copy.deepcopy(net), p=p)
        clean_acc = test_clean(target, device, loader)
        for attack_name, attack_fn in attacks:
            succ = test_transfer(net, target, device, loader, eps, attack_fn,
                                 alpha=alpha, iters=iters)
            attacked_acc = 100. - succ
            curves[attack_name].append((1 - p, clean_acc, attacked_acc))
            if p == 0.0:
                print(f"{name} baseline clean {clean_acc:.2f}%, "
                      f"{attack_name} attacked {attacked_acc:.2f}%")
    return curves


def main():
    """
        Run the pruning experiment for every MLP size and plot the results.

        Each model in ``model_defs`` is trained once on MNIST and then
        evaluated by ``_pruning_curves`` for every pruning rate and attack.
        One subplot per attack (PGD, FGSM) shows clean and attacked accuracy
        against weight retention ``1 - p``. The figure is saved to
        ``plots/MLP_comparison_PGD_FGSM.pdf`` and then shown.
    """
    # settings
    batch_size, test_batch = 64, 256
    epochs, lr, mom, eps = 5, 0.01, 0.9, 0.1
    alpha, iters = 0.01, 40  # forwarded to test_transfer for every attack
    p_values = [0.0] + [0.01 * i for i in range(1, 100)]
    model_defs = [
        ("Small-MLP",  SmallMLP),
        ("Medium-MLP", MediumMLP),
        ("Large-MLP",  LargeMLP),
    ]
    attacks = [("PGD", pgd_attack), ("FGSM", fgsm_attack)]

    # color palettes & markers, we use them for plots
    clean_colors    = ['#a6bddb', '#3690c0', '#045a8d']
    attacked_colors = ['#fdd0a2', '#f16913', '#8c2d04']
    markers = ['o', 's', '^']
    linestyles = {'clean': '-', 'attacked': '--'}
    mark_every = 5

    # data loaders
    transform = transforms.ToTensor()
    train_loader = DataLoader(
        datasets.MNIST('./data', train=True,  download=True, transform=transform),
        batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(
        datasets.MNIST('./data', train=False, download=True, transform=transform),
        batch_size=test_batch, shuffle=False)

    # results[attack_name][model_name] -> list of (retention, clean, attacked)
    results = {attack_name: {} for attack_name, _ in attacks}
    for name, ModelClass in model_defs:
        # Train each network once and reuse it for every attack, so both
        # panels describe the same weights and training is not repeated.
        print(f"\n=== {name} ===")
        net = ModelClass().to(device)
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom)
        for epoch in range(1, epochs + 1):
            train(net, device, train_loader, optimizer, epoch)

        curves = _pruning_curves(net, test_loader, p_values, attacks,
                                 eps, alpha, iters, name)
        for attack_name, curve in curves.items():
            results[attack_name][name] = curve

    fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True,
                             gridspec_kw={'wspace': 0.05})
    for ax, (attack_name, _) in zip(axes, attacks):
        # plot each model; idx keeps colors/markers consistent across panels
        for idx, (name, curve) in enumerate(results[attack_name].items()):
            rs, cleans, atks = zip(*curve)
            # clean
            ax.plot(rs, cleans,
                    color=clean_colors[idx], marker=markers[idx],
                    markevery=mark_every, linestyle=linestyles['clean'],
                    label=f"{name} Clean")
            # attacked
            ax.plot(rs, atks,
                    color=attacked_colors[idx], marker=markers[idx],
                    markevery=mark_every, linestyle=linestyles['attacked'],
                    label=f"{name} Attacked")

        ax.set_title(f"{attack_name} Attack", fontsize=16)
        # The x values are retention = 1 - p, not the pruning rate p itself.
        ax.set_xlabel("Weight Retention (1 - p)", fontsize=14)
        if ax is axes[0]:
            ax.set_ylabel("Accuracy (%)", fontsize=16)
            ax.legend(loc='upper left', frameon=False, fontsize=13)

        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    # Save into the ``plots`` directory that is created for the figure.
    out_dir = "plots"
    os.makedirs(out_dir, exist_ok=True)
    outname = os.path.join(out_dir, "MLP_comparison_PGD_FGSM.pdf")
    fig.savefig(outname, bbox_inches='tight', dpi=300)
    print(f"\nSaved combined subplot → {outname}")
    plt.show()

if __name__ == "__main__":
    # Run the full experiment only when executed as a script, not on import.
    main()
