"""
This is the main script to run the Head-Only Finetuning on the CIFAR datasets.
"""

import os
import argparse
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import timm
from tqdm import tqdm
import matplotlib.pyplot as plt

def set_seed(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available), then asks cuDNN for deterministic kernels and
    turns its auto-tuner (benchmark mode) off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

###########################################################
# We define the model here -- note that a linear head (Head-Only) is added.
###########################################################
class ViTWithLinearHead(nn.Module):
    """Backbone followed by a single linear classification layer.

    ``base_model`` must expose an ``embed_dim`` attribute and a
    ``forward_features`` method; the output of ``forward_features`` is indexed
    as ``[:, 0, :]``, i.e. the first entry along dimension 1 is what gets
    classified (presumably the CLS token of a ViT -- confirm against the
    backbone in use).
    """

    def __init__(self, base_model, num_classes):
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(in_features=base_model.embed_dim,
                                    out_features=num_classes)

    def forward(self, x):
        # Keep only the first token of every sequence and classify it.
        first_token = self.base_model.forward_features(x)[:, 0, :]
        return self.classifier(first_token)

###########################################################
# Let's define the train and eval functions
###########################################################
def train_one_epoch(model, loader, optimizer, criterion, device):
    """
        Runs one optimisation pass over ``loader`` with the model in
        training mode.
        ---
            - returns: (loss, accuracy in percent). The loss is the
              size-weighted average of the per-batch ``criterion`` values
              (a per-sample mean if ``criterion`` averages over the batch).
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    progress = tqdm(loader, desc="Training", leave=False)
    for batch_x, batch_y in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the epoch value is not
        # biased by a smaller last batch.
        loss_sum += batch_loss.item() * batch_x.size(0)
        predicted = logits.max(1)[1]
        n_correct += predicted.eq(batch_y).sum().item()
        n_seen += batch_y.size(0)

    return loss_sum / n_seen, 100. * n_correct / n_seen


def evaluate(model, loader, criterion, device):
    """
        Scores the model on ``loader`` in eval mode, without tracking
        gradients.
        ---
            - returns: (loss, accuracy in percent). The loss is the
              size-weighted average of the per-batch ``criterion`` values.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        progress = tqdm(loader, desc="Evaluating", leave=False)
        for batch_x, batch_y in progress:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            loss_sum += batch_loss.item() * batch_x.size(0)
            predicted = logits.max(1)[1]
            n_correct += predicted.eq(batch_y).sum().item()
            n_seen += batch_y.size(0)

    return loss_sum / n_seen, 100. * n_correct / n_seen

###########################################################
# We define the attack functions
###########################################################
def pgd_attack(model, images, labels, criterion, device, epsilon, alpha, iters):
    """
        This is an implementation of the Projected Gradient Descent (PGD)
        attack with an L-infinity budget.
        ---
            - model: the classifier under attack (its mode is not changed here)
            - images: the clean inputs. The result is clamped to [0, 1], so
              the inputs are expected to live in that range; values outside
              it get clipped.
            - labels: the ground-truth labels used to compute the loss
            - criterion: the loss to be maximised
            - device: the device the computation runs on
            - epsilon: the attack budget (max absolute change per element)
            - alpha: the step size of each signed-gradient step
            - iters: the number of iterations.
        ---
        Returns the adversarial images as a detached tensor. The gradients
        stored on the model parameters are left untouched.
    """
    images = images.clone().detach().to(device)
    labels = labels.clone().detach().to(device)
    adv_images = images.clone().detach()
    for _ in range(iters):
        adv_images.requires_grad_(True)
        loss = criterion(model(adv_images), labels)
        # Differentiate w.r.t. the input only: this neither zeroes nor fills
        # the .grad fields of the model parameters.
        grad, = torch.autograd.grad(loss, adv_images)

        # Signed-gradient ascent step on the loss.
        adv_images = adv_images.detach() + alpha * grad.sign()
        # Project back onto the epsilon-ball around the clean images ...
        perturbation = torch.clamp(adv_images - images, min=-epsilon, max=epsilon)
        # ... and onto the [0, 1] range.
        adv_images = torch.clamp(images + perturbation, 0, 1).detach()
    return adv_images

def fgsm_attack(model, images, labels, criterion, device, epsilon):
    """
        This is an implementation of the Fast Gradient Sign Method (FGSM)
        attack: a single signed-gradient step of size epsilon.
        ---
            - model: the classifier under attack (its mode is not changed here)
            - images: the clean inputs. The result is clamped to [0, 1], so
              the inputs are expected to live in that range; values outside
              it get clipped.
            - labels: the ground-truth labels used to compute the loss
            - criterion: the loss to be maximised
            - device: the device the computation runs on
            - epsilon: the attack budget
        ---
        Returns the adversarial images as a detached tensor. The gradients
        stored on the model parameters are left untouched.
    """
    images = images.clone().detach().to(device)
    labels = labels.clone().detach().to(device)
    images.requires_grad_(True)
    loss = criterion(model(images), labels)
    # Differentiate w.r.t. the input only: this neither zeroes nor fills the
    # .grad fields of the model parameters.
    grad, = torch.autograd.grad(loss, images)
    # Detach so the returned tensor does not keep the forward graph alive.
    adv_images = torch.clamp(images.detach() + epsilon * grad.sign(), 0, 1)
    return adv_images

def evaluate_attack(model, loader, criterion, device, attack_type, epsilon,
                                                            alpha=0.0, iters=0):
    """
        This is a function that calls the different attacks and returns the
        corresponding attacked accuracy.
        ---
            - model: the classifier to attack (switched to eval mode here)
            - loader: an iterable yielding (images, labels) batches
            - criterion: the loss handed to the attack
            - device: the device the computation runs on
            - attack_type: the attack to be performed ('pgd'; any other value
              runs fgsm)
            - epsilon: the attack budget
            - alpha: the PGD step size (unused by fgsm)
            - iters: the number of PGD iterations (unused by fgsm)
        ---
        Returns (attacked_acc_full, success_rate), both in percent:
            - attacked_acc_full: images still correctly classified after the
              attack, over ALL the images yielded by the loader
            - success_rate: share of the initially correctly classified
              images that the attack manages to flip
        Both are 0.0 when their denominator is zero.
    """
    model.eval()
    total_images = 0
    total_correct = 0
    correct_after = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        # Count what is actually seen instead of len(loader.dataset): the two
        # differ with samplers / drop_last, and plain iterables of batches
        # have no .dataset attribute at all.
        total_images += labels.size(0)
        with torch.no_grad():
            preds = model(images).argmax(1)

        # Note that we attack only the correctly classified images: the
        # misclassified ones already count as errors.
        mask = preds.eq(labels)
        n_corr = mask.sum().item()
        if n_corr == 0:
            continue
        total_correct += n_corr

        imgs_corr, labs_corr = images[mask], labels[mask]
        if attack_type == 'pgd':
            adv = pgd_attack(model, imgs_corr, labs_corr, criterion,
                                            device, epsilon, alpha, iters)
        else:
            adv = fgsm_attack(model, imgs_corr, labs_corr, criterion,
                                                        device, epsilon)
        with torch.no_grad():
            correct_after += model(adv).argmax(1).eq(labs_corr).sum().item()

    attacked_acc_full = 100. * correct_after / total_images if total_images > 0 else 0.0
    success_rate = 100. * (total_correct - correct_after) / total_correct if total_correct > 0 else 0.0
    return attacked_acc_full, success_rate

#########################################
# This is our main script
#########################################
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Parse Arguments
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(
        description="Head-only fine-tuning of a pretrained ViT on CIFAR, "
                    "followed by FGSM/PGD robustness evaluation")
    parser.add_argument("--dataset", type=str,
                        choices=["CIFAR10", "CIFAR100"], default="CIFAR10",
                        help="Dataset to use")
    parser.add_argument("--seed", type=int, default=42,
                        help="Base random seed; the runs use seed, seed+1 and seed+2")
    parser.add_argument("--model_type", type=str, default="large",
                        choices=["base", "large"],
                        help="Type of ViT model to use")

    args = parser.parse_args()

    # As mentioned in the paper, we run 3 seeds and report the average.
    # The default --seed 42 gives the seeds 42, 43 and 44.
    NUM_RUNS = 3
    seeds = [args.seed + offset for offset in range(NUM_RUNS)]
    runs = []

    # Hyperparameters
    BATCH_SIZE = 32
    EPOCHS     = 5
    LR         = 1e-3
    DEVICE     = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Define here the path to your data (otherwise it will directly be downloaded)
    _PATH_= "./data/"
    dataset_name = args.dataset

    # Dataset-specific settings:
    # (dataset class, number of classes, per-channel mean, per-channel std).
    DATASETS = {
        "CIFAR10": (torchvision.datasets.CIFAR10, 10,
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        "CIFAR100": (torchvision.datasets.CIFAR100, 100,
                     (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    }
    if args.dataset not in DATASETS:
        raise ValueError("Unsupported dataset selected.")
    dataset_cls, NUM_CLASSES, NORM_MEAN, NORM_STD = DATASETS[args.dataset]

    # NOTE: the mean/std normalisation is deliberately NOT part of the data
    # transforms. The attacks clamp their output to [0, 1] and take epsilon
    # in x/255 pixel units, so they need inputs in the raw [0, 1] range that
    # ToTensor produces. The normalisation is applied as the first layer of
    # the model instead (see the model setup below).
    transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
    ])
    train_dataset = dataset_cls(root=_PATH_, train=True,
                                download=True, transform=transform_train)
    test_dataset = dataset_cls(root=_PATH_, train=False,
                               download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=2)

    # This is the set of epsilons we use in the paper
    # note that for the PGD we only use 4/255
    PGD_EPSILON = 4/255
    epsilons = [PGD_EPSILON, 6/255, 8/255, 12/255, 16/255, 20/255, 24/255, 28/255, 32/255]

    # The parameters for PGD (the step size is derived from the budget below)
    pgd_iters = 10

    MODEL_NAMES = {
        "base": "vit_base_patch16_224",
        "large": "vit_large_patch16_224",
    }

    for seed in seeds:
        print(f"\n===== Run with seed {seed} =====")
        set_seed(seed)

        # Model setup: pretrained backbone, frozen so that only the linear
        # head is trained.
        base_model = timm.create_model(
            model_name   = MODEL_NAMES[args.model_type],
            pretrained   = True,
            num_classes  = NUM_CLASSES,
            global_pool  = None
        ).to(DEVICE)

        for param in base_model.parameters():
            param.requires_grad = False

        # The normalisation is the first layer of the model, so the model
        # consumes (and the attacks perturb) images in [0, 1].
        model = nn.Sequential(
            transforms.Normalize(NORM_MEAN, NORM_STD),
            ViTWithLinearHead(base_model, NUM_CLASSES),
        ).to(DEVICE)

        # Only hand the trainable (head) parameters to the optimizer.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable_params, lr=LR)
        criterion = nn.CrossEntropyLoss()

        # Training
        for epoch in range(EPOCHS):
            print(f"Epoch {epoch+1}/{EPOCHS}")
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
            val_loss, val_acc     = evaluate(model, test_loader, criterion, DEVICE)
            print(f" Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Final clean accuracy
        _, clean_acc = evaluate(model, test_loader, criterion, DEVICE)
        print(f"Clean Accuracy: {clean_acc:.2f}%")

        # Adversarial evaluation
        results = []
        for eps in epsilons:
            # In the paper, for the PGD we only do \epsilon = 4/255
            if eps == PGD_EPSILON:
                # The PGD step size is a quarter of the budget.
                pgd_alpha = eps / 4
                pgd_acc, pgd_success = evaluate_attack(
                    model, test_loader, criterion, DEVICE,
                    attack_type="pgd", epsilon=eps, alpha=pgd_alpha, iters=pgd_iters
                )
                print(f"PGD  (ε={eps:.3f})  •  Adversarial Acc: {pgd_acc:.2f}%  •  Success Rate: {pgd_success:.2f}%")
                results.append({'seed': seed, 'attack': 'PGD', 'budget': eps,
                                'clean_acc': clean_acc, 'attacked_acc': pgd_acc,
                                'success_rate': pgd_success})

            fgsm_acc, fgsm_succ = evaluate_attack(
                model, test_loader, criterion, DEVICE,
                attack_type="fgsm", epsilon=eps
            )
            print(f"FGSM (ε={eps:.3f})  •  Adversarial Acc: {fgsm_acc:.2f}%  •  Success Rate: {fgsm_succ:.2f}%")
            results.append({'seed': seed, 'attack': 'FGSM', 'budget': eps,
                            'clean_acc': clean_acc, 'attacked_acc': fgsm_acc,
                            'success_rate': fgsm_succ})

        runs.append(pd.DataFrame(results))

    # Aggregate results: mean and standard deviation across the seeds, for
    # every (attack, budget) pair.
    all_df = pd.concat(runs, ignore_index=True)
    summary = all_df.groupby(['attack', 'budget'])[['clean_acc', 'attacked_acc', 'success_rate']]
    stats_mean = summary.mean().add_suffix('_mean')
    stats_std  = summary.std().add_suffix('_std')
    summary_df = pd.concat([stats_mean, stats_std], axis=1).reset_index()

    print("\n===== Summary across seeds =====")
    print(summary_df)

    # The results will be saved to a "Results" Folder
    os.makedirs('Results',exist_ok=True)
    all_df.to_csv(f'Results/{args.model_type}_vit_attack_all_runs_{dataset_name}.csv',index=False)
    summary_df.to_csv(f'Results/{args.model_type}_vit_attack_summary_{dataset_name}.csv',index=False)
    print("Saved detailed and summary CSVs with dataset tag.")
