"""
    This is the official script to reproduce the MLP Results
"""

import copy
import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from models import OneLayerMLP, MultiLayerMLP
from attacks import *
from pruning_utils import norm_prune_weights

# Device configuration: run on CUDA when torch reports it is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Training and evaluation loops
def train(model, device, loader, optimizer, scheduler, epoch):
    """
        Train ``model`` for one epoch with cross-entropy loss.

        Args:
            model: network to optimize; switched to training mode here.
            device: device every batch is moved to before the forward pass.
            loader: iterable of ``(data, target)`` batches.
            optimizer: optimizer stepped once per batch.
            scheduler: optional LR scheduler stepped once at the end of the
                epoch; pass ``None`` to keep the learning rate fixed.
            epoch: epoch number, used only in the progress message.

        Returns:
            float: sample-weighted mean training loss over the epoch, or 0.0
            when ``loader`` yields no batches. Callers may ignore it.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()
        # Accumulate as a detached tensor so we don't force a host sync per batch.
        loss_sum = loss_sum + loss.detach() * target.size(0)
        n_samples += target.size(0)
    if scheduler is not None:
        scheduler.step()
    # Read after scheduler.step(): this is the LR that the *next* epoch will use.
    current_lr = (scheduler.get_last_lr()[0] if scheduler is not None
                  else optimizer.param_groups[0]['lr'])
    print(f"Epoch {epoch} complete. LR: {current_lr:.2e}")
    return float(loss_sum) / n_samples if n_samples else 0.0

def test_clean(model, device, loader, desc=""):
    """
        Evaluation function Loop
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total   += target.size(0)
    acc = 100. * correct / total
    print(f"{desc} Clean Accuracy: {acc:.2f}%")
    return acc

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Adversarial MLP on MNIST/CIFAR-10")
    parser.add_argument("--dataset", type=str, choices=["mnist", "fashion_mnist", "cifar10"], default="mnist")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--test-batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=5e-4)
    parser.add_argument("--eps", type=float, default=0.1)
    parser.add_argument("--alpha", type=float, default=0.01)
    parser.add_argument("--iters", type=int, default=40)
    parser.add_argument("--output-dir", type=str, default=".")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Dataset and model setup. We use the classical transform values from the
    # literature for each dataset.
    if args.dataset in ("mnist", "fashion_mnist"):
        # Both 28x28 datasets share one pipeline and a one-layer MLP; the
        # lookup table picks the torchvision dataset class.
        transform = transforms.ToTensor()
        ds_map = {"mnist": datasets.MNIST, "fashion_mnist": datasets.FashionMNIST}
        dataset_cls = ds_map[args.dataset]
        train_ds = dataset_cls('./data', train=True, download=True, transform=transform)
        test_ds = dataset_cls('./data', train=False, download=True, transform=transform)
        input_dim = 28 * 28
        model = OneLayerMLP(input_dim, 10).to(device)
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        scheduler = None
    else:
        # For CIFAR-10 we use a deeper model (with augmentation and dropout) to
        # reach a reasonable clean accuracy first, as explained in the paper.
        # NOTE(review): inputs are normalized here, so --eps/--alpha are in
        # normalized units unless the attacks undo the normalization — confirm.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        ])
        train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
        test_ds = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)
        input_dim = 32 * 32 * 3
        hidden_dims = [8192, 4096, 2048, 1024]
        model = MultiLayerMLP(input_dim, hidden_dims, 10, dropout=0.6).to(device)
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=args.test_batch_size, shuffle=False, num_workers=4)

    best_acc = 0.0
    best_model = None
    for epoch in range(1, args.epochs + 1):
        train(model, device, train_loader, optimizer, scheduler, epoch)
        acc = test_clean(model, device, test_loader, desc=f"Epoch {epoch}")
        # Always keep at least one checkpoint, even if accuracy never exceeds 0%.
        if best_model is None or acc > best_acc:
            best_acc = acc
            best_model = copy.deepcopy(model)
    print(f"Best clean accuracy: {best_acc:.2f}% — using this model for attacks/pruning")

    # Use the best checkpoint; with --epochs 0 there is none, so keep the current model.
    if best_model is not None:
        model = best_model

    # Baseline metrics
    baseline_clean = test_clean(model, device, test_loader, desc="Baseline")

    # Attacks and pruning. p = 0.0 is the unpruned model; i / 100 gives the
    # correctly rounded values 0.01, 0.02, ..., 0.99 (0.01 * i does not).
    attacks = [("FGSM", fgsm_attack), ("PGD", pgd_attack)]
    p_values = [i / 100 for i in range(100)]

    # Apply each attack for every pruning value.
    for name, atk_fn in attacks:
        print(f"\n=== {name} attack ===")
        results = []
        for p in p_values:
            # assumes test_transfer returns an attack success rate in percent,
            # hence accuracy = 100 - success — confirm in attacks.py.
            if p == 0.0:
                baseline_succ = test_transfer(model, model, device, test_loader,
                                              args.eps, atk_fn, args.alpha, args.iters)
                baseline_att = 100. - baseline_succ
                clean_acc = baseline_clean
                attacked_acc = baseline_att
            else:
                pruned = norm_prune_weights(copy.deepcopy(model), p=p)
                clean_acc = test_clean(pruned, device, test_loader, desc=f"p={p:.2f}")
                succ = test_transfer(model, pruned, device, test_loader,
                                     args.eps, atk_fn, args.alpha, args.iters)
                attacked_acc = 100. - succ
            print(f"p={p:.2f} → Clean: {clean_acc:.2f}% | Attacked: {attacked_acc:.2f}%")
            # The x-coordinate is 1 - p, so the unpruned model sits at x = 1.
            results.append((1 - p, clean_acc, attacked_acc))

        xs, cleans, atks = zip(*results)
        plt.figure(figsize=(8, 5))
        plt.plot(xs, cleans, marker='o', label="Clean Pruned Accuracy")
        plt.plot(xs, atks, marker='^', label="Attacked Accuracy")
        # Label matches what is actually plotted (1 - p), not p itself.
        plt.xlabel("1 - p (p = pruning probability)", fontsize=14)
        plt.ylabel("Accuracy (%)", fontsize=14)
        plt.title(f"MLP - {args.dataset.upper()} - {name}", fontsize=20)
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        fname = os.path.join(args.output_dir, f"MLP_{args.dataset}_{name}.pdf")
        plt.savefig(fname, bbox_inches="tight")
        print(f"Saved plot: {fname}")
        plt.close()
