"""
This is the main script to run the LoRA Finetuning on the CIFAR datasets.
"""

import os
import argparse
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import timm
from tqdm import tqdm
import matplotlib.pyplot as plt

from lora_vit import LoRA_ViT_timm


def set_seed(seed):
    """
        Seed the Python, NumPy and PyTorch random number generators and put
        cuDNN in its deterministic configuration.
        ---
            - seed: the integer seed handed to every generator.
    """
    # The cuDNN flags do not depend on the seeding calls, so set them first:
    # no auto-tuned kernels, deterministic algorithms only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Also seed every CUDA device when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


###########################################################
# Training and evaluation functions
###########################################################
def train_one_epoch(model, loader, optimizer, criterion, device):
    """
        Run one optimisation pass over `loader` and report the epoch metrics.
        ---
            - model: the network to train (switched to train mode here).
            - loader: iterable yielding (images, labels) batches.
            - optimizer: optimizer stepped once per batch.
            - criterion: loss called as criterion(outputs, labels).
            - device: device the batches are moved to.
        Returns (average loss per sample, accuracy in percent).
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in tqdm(loader, desc="Training", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # The batch loss is re-weighted by the batch size so that a smaller
        # final batch does not skew the epoch average.
        loss_sum += batch_loss.item() * images.size(0)
        n_correct += logits.max(1).indices.eq(labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / n_seen, 100. * n_correct / n_seen


def evaluate(model, loader, criterion, device):
    """
        Measure loss and accuracy of `model` on `loader` without updating it.
        ---
            - model: the network to evaluate (switched to eval mode here).
            - loader: iterable yielding (images, labels) batches.
            - criterion: loss called as criterion(outputs, labels).
            - device: device the batches are moved to.
        Returns (average loss per sample, accuracy in percent).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating", leave=False):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            # Re-weight the batch loss by the batch size (see the final
            # division by the number of samples).
            loss_sum += criterion(logits, labels).item() * images.size(0)
            n_correct += logits.max(1).indices.eq(labels).sum().item()
            n_seen += labels.size(0)

    return loss_sum / n_seen, 100. * n_correct / n_seen

###########################################################
# Adversarial attacks (PGD and FGSM) and their evaluation
###########################################################
def pgd_attack(model, images, labels, criterion, device, epsilon, alpha, iters):
    """
        This is an implementation of the Projected Gradient Descent (PGD)
        attack under an L-infinity budget, started from the clean images
        (no random initialisation).
        ---
            - model: the network under attack.
            - images: the clean inputs; the result is clamped to [0, 1], so
              the inputs are expected to live in that range.
            - labels: the true labels used to compute the loss to maximise.
            - criterion: loss called as criterion(outputs, labels).
            - device: device the tensors are moved to.
            - epsilon: the attack budget (max absolute change per element).
            - alpha: the step size of each signed-gradient step.
            - iters: the number of iterations.
        Returns the adversarial images as a tensor detached from the autograd
        graph. The inputs are not modified, and no gradients are accumulated
        on the model parameters.
    """
    images = images.clone().detach().to(device)
    labels = labels.clone().detach().to(device)
    adv_images = images.clone()
    for _ in range(iters):
        adv_images.requires_grad_(True)
        loss = criterion(model(adv_images), labels)
        # Only the gradient w.r.t. the input is needed; autograd.grad avoids
        # computing and storing gradients for the (trainable) parameters.
        grad, = torch.autograd.grad(loss, adv_images)
        # Ascent step, then projection onto the epsilon-ball around the clean
        # images and onto the valid [0, 1] range.
        stepped = adv_images.detach() + alpha * grad.sign()
        perturbation = torch.clamp(stepped - images, min=-epsilon, max=epsilon)
        adv_images = torch.clamp(images + perturbation, 0, 1)
    return adv_images.detach()

def fgsm_attack(model, images, labels, criterion, device, epsilon):
    """
        This is an implementation of the Fast Gradient Sign Method (FGSM)
        attack: a single signed-gradient step of size epsilon.
        ---
            - model: the network under attack.
            - images: the clean inputs; the result is clamped to [0, 1], so
              the inputs are expected to live in that range.
            - labels: the true labels used to compute the loss to maximise.
            - criterion: loss called as criterion(outputs, labels).
            - device: device the tensors are moved to.
            - epsilon: the attack budget
        Returns the adversarial images as a tensor detached from the autograd
        graph. The inputs are not modified, and no gradients are accumulated
        on the model parameters.
    """
    images = images.clone().detach().to(device).requires_grad_(True)
    labels = labels.clone().detach().to(device)
    loss = criterion(model(images), labels)
    # Only the gradient w.r.t. the input is needed; autograd.grad avoids
    # computing and storing gradients for the (trainable) parameters.
    grad, = torch.autograd.grad(loss, images)
    # Build the result from the detached input so that the returned tensor
    # does not keep a reference to the autograd graph.
    adv_images = torch.clamp(images.detach() + epsilon * grad.sign(), 0, 1)
    return adv_images

def evaluate_attack(model, loader, criterion, device, attack_type, epsilon,
                                                            alpha=0.0, iters=0):
    """
        This is a function that calls the different attacks and returns the
        corresponding attacked accuracy.
        ---
            - model: the network under attack (switched to eval mode here).
            - loader: iterable yielding (images, labels) batches.
            - criterion: loss forwarded to the attack.
            - device: device the batches are moved to.
            - attack_type: the attack to be performed ('pgd' or 'fgsm',
              case-insensitive).
            - epsilon: the attack budget.
            - alpha: the PGD step size (unused for FGSM).
            - iters: the number of PGD iterations (unused for FGSM).
        Returns (attacked_acc_full, success_rate), both in percent:
            - attacked_acc_full: samples still correctly classified after the
              attack, relative to ALL evaluated samples.
            - success_rate: share of the initially correct samples that the
              attack manages to flip (0.0 when none was correct).
        Raises ValueError when attack_type is not a supported attack.
    """
    # Fail loudly on an unknown attack instead of silently running FGSM.
    attack_name = str(attack_type).lower()
    if attack_name not in ("pgd", "fgsm"):
        raise ValueError(
            f"Unsupported attack_type {attack_type!r}; expected 'pgd' or 'fgsm'."
        )

    model.eval()
    total_images = 0
    total_correct = 0
    correct_after = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        # Count the samples actually seen, so the accuracy stays right when
        # the loader drops or sub-samples part of its dataset.
        total_images += labels.size(0)
        with torch.no_grad():
            preds = model(images).argmax(1)

        # Note that we attack only the correctly classified images; a sample
        # that is already misclassified counts as wrong after the attack too.
        mask = preds.eq(labels)
        n_corr = mask.sum().item()
        total_correct += n_corr
        if n_corr == 0:
            continue

        imgs_corr, labs_corr = images[mask], labels[mask]
        if attack_name == 'pgd':
            adv = pgd_attack(model, imgs_corr, labs_corr, criterion,
                             device, epsilon, alpha, iters)
        else:
            adv = fgsm_attack(model, imgs_corr, labs_corr, criterion,
                              device, epsilon)
        with torch.no_grad():
            correct_after += model(adv).argmax(1).eq(labs_corr).sum().item()

    attacked_acc_full = 100. * correct_after / total_images if total_images > 0 else 0.0
    success_rate = 100. * (total_correct - correct_after) / total_correct if total_correct > 0 else 0.0
    return attacked_acc_full, success_rate


#########################################
# This is our main script
#########################################
if __name__ == "__main__":
    # Parse Arguments
    parser = argparse.ArgumentParser(description="Fine-tune ViT with different pooling methods")
    parser.add_argument("--dataset", type=str,
                        choices=["CIFAR10", "CIFAR100"], default="CIFAR10",
                        help="Dataset to use")
    parser.add_argument("--seed", type=int, default=42,
                        help="Base random seed; the runs use seed, seed+1 and seed+2")
    parser.add_argument("--model_type", type=str, default="large",
                        choices=["base", "large"],
                        help="Type of ViT model to use")

    args = parser.parse_args()

    # As mentioned in the paper, we perform 3 seeds and report the average.
    # The seeds are derived from --seed (previously parsed but ignored); the
    # default value of 42 still yields the seeds 42, 43 and 44.
    NUM_RUNS = 3
    seeds = [args.seed + offset for offset in range(NUM_RUNS)]
    runs = []

    # Hyperparameters
    BATCH_SIZE = 32
    EPOCHS     = 5
    LR         = 1e-3
    DEVICE     = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Define here the path to your data (otherwise it will directly be downloaded)
    _PATH_= "./data/"
    dataset_name = args.dataset

    # Dataset-specific settings: (dataset class, number of classes,
    # per-channel mean, per-channel std).
    DATASETS = {
        "CIFAR10": (torchvision.datasets.CIFAR10, 10,
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        "CIFAR100": (torchvision.datasets.CIFAR100, 100,
                     (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    }
    if args.dataset not in DATASETS:
        raise ValueError("Unsupported dataset selected.")
    dataset_cls, NUM_CLASSES, norm_mean, norm_std = DATASETS[args.dataset]

    # NOTE: the loaders deliberately yield UN-normalised images in [0, 1].
    # pgd_attack / fgsm_attack clamp their output to [0, 1] and interpret
    # epsilon in that range, so the mean/std normalisation is applied inside
    # the model (see the nn.Sequential wrapper below). Normalising in the
    # transform instead would make the attacks clamp normalised tensors to
    # [0, 1] and corrupt the images far beyond the epsilon budget.
    transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
    ])
    train_dataset = dataset_cls(root=_PATH_, train=True,
                                download=True, transform=transform_train)
    test_dataset = dataset_cls(root=_PATH_, train=False,
                               download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # This is the set of epsilons we use in the paper
    # note that for the PGD we only use 4/255
    epsilons = [4/255, 6/255, 8/255, 12/255, 16/255, 20/255, 24/255, 28/255, 32/255]
    pgd_epsilons = (4/255,)

    # The parameters for PGD. The step size is derived from the budget
    # (eps / 4) inside the evaluation loop; an earlier fixed value of 2/255
    # was always overwritten by it and has been removed.
    pgd_iters = 10

    for seed in seeds:
        print(f"\n===== Run with seed {seed} =====")
        set_seed(seed)

        # Model setup ("base" or "large", enforced by the argparse choices)
        base_model = timm.create_model(
            model_name   = f"vit_{args.model_type}_patch16_224",
            pretrained   = True,
            num_classes  = NUM_CLASSES,
            global_pool  = None
        ).to(DEVICE)

        for param in base_model.parameters():
            param.requires_grad = False

        # For the results on the alpha value, please adapt the value here.
        lora_model = LoRA_ViT_timm(vit_model=base_model,
                                   r=4, alpha=4,
                                   num_classes=NUM_CLASSES)
        # Normalisation is the first layer of the model, so training, clean
        # evaluation and the attacks all consume images in [0, 1].
        model = nn.Sequential(
            transforms.Normalize(norm_mean, norm_std),
            lora_model,
        ).to(DEVICE)

        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        criterion = nn.CrossEntropyLoss()

        # Training (the test split is used to monitor progress)
        for epoch in range(EPOCHS):
            print(f"Epoch {epoch+1}/{EPOCHS}")
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
            val_loss, val_acc     = evaluate(model, test_loader, criterion, DEVICE)
            print(f" Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Final clean accuracy
        _, clean_acc = evaluate(model, test_loader, criterion, DEVICE)
        print(f"Clean Accuracy: {clean_acc:.2f}%")

        # Adversarial evaluation
        results = []
        for eps in epsilons:
            # In the paper, for the PGD we only do \epsilon = 4/255
            if eps in pgd_epsilons:
                pgd_acc, pgd_success = evaluate_attack(
                    model, test_loader, criterion, DEVICE,
                    attack_type="pgd", epsilon=eps, alpha=eps/4, iters=pgd_iters
                )
                print(f"PGD  (ε={eps:.3f})  •  Adversarial Acc: {pgd_acc:.2f}%  •  Success Rate: {pgd_success:.2f}%")
                results.append({'seed': seed, 'attack': 'PGD', 'budget': eps,
                                'clean_acc': clean_acc, 'attacked_acc': pgd_acc,
                                'success_rate': pgd_success})

            fgsm_acc, fgsm_succ = evaluate_attack(
                model, test_loader, criterion, DEVICE,
                attack_type="fgsm", epsilon=eps
            )
            print(f"FGSM (ε={eps:.3f})  •  Adversarial Acc: {fgsm_acc:.2f}%  •  Success Rate: {fgsm_succ:.2f}%")
            results.append({'seed': seed, 'attack': 'FGSM', 'budget': eps,
                            'clean_acc': clean_acc, 'attacked_acc': fgsm_acc,
                            'success_rate': fgsm_succ})

        runs.append(pd.DataFrame(results))

    # Aggregate results: mean and standard deviation across the seeds for
    # every (attack, budget) pair.
    all_df = pd.concat(runs, ignore_index=True)
    summary = all_df.groupby(['attack', 'budget'])[['clean_acc', 'attacked_acc', 'success_rate']]
    stats_mean = summary.mean().add_suffix('_mean')
    stats_std  = summary.std().add_suffix('_std')
    summary_df = pd.concat([stats_mean, stats_std], axis=1).reset_index()

    print("\n===== Summary across seeds =====")
    print(summary_df)

    # The results will be saved to a "Results" Folder
    os.makedirs('Results',exist_ok=True)
    all_df.to_csv(f'Results/{args.model_type}_lora_attack_all_runs_{dataset_name}.csv',index=False)
    summary_df.to_csv(f'Results/{args.model_type}_lora_attack_summary_{dataset_name}.csv',index=False)
    print("Saved detailed and summary CSVs with dataset tag.")
