import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.utils.data import DataLoader, SubsetRandomSampler

import torchvision
import torchvision.transforms as transforms
from torchvision import models

import numpy as np
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
import os
import csv

# Local imports
from utils import set_seed, create_adversarial_dataset

# Argument parser: (flag, add_argument keyword options), registered in order.
# An empty --surrogate_model selects the white-box path of the driver below.
_CLI_OPTIONS = (
    ('--dataset', {'type': str, 'choices': ['cifar10'], 'required': True}),
    ('--model', {'type': str, 'choices': ['vgg16', 'resnet18'], 'required': True}),
    ('--surrogate_model', {'type': str, 'default': ''}),
    ('--savedir', {'type': str, 'required': True}),
    ('--delta_path', {'type': str, 'required': True}),
    ('--delta_suffix', {'type': str, 'required': True}),
)

parser = argparse.ArgumentParser(description='Adversarial Training Evaluation')
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# Create the output directory up front; nothing is done if the path already exists.
if not os.path.exists(args.savedir):
    os.makedirs(args.savedir, exist_ok=True)

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Per-channel CIFAR-10 normalization statistics shared by train and test inputs.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def _cifar10_transform():
    """Return a new ToTensor + Normalize pipeline (no data augmentation)."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ])


# Train and test get separate (but identical) pipeline objects.
train_transform = _cifar10_transform()
test_transform = _cifar10_transform()

# Dataset loading: CIFAR-10 must already be present under ./data (download=False).
train_dataset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=False, transform=train_transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='data', train=False, download=False, transform=test_transform,
)
num_classes = 10

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

def _build_classifier(arch, n_classes):
    """Return an untrained torchvision network whose last layer emits `n_classes` logits.

    'vgg16' builds VGG-16 (head at ``classifier[6]``); any other name builds
    ResNet-18 (head at ``fc``), which keeps the previous behavior for surrogate
    names other than 'vgg16'.
    """
    if arch == 'vgg16':
        net = models.vgg16(pretrained=False)
        net.classifier[6] = nn.Linear(net.classifier[6].in_features, n_classes)
    else:
        net = models.resnet18(pretrained=False)
        net.fc = nn.Linear(net.fc.in_features, n_classes)
    return net


# Model setup
model = _build_classifier(args.model, num_classes)

if args.surrogate_model:
    # Build the surrogate with the architecture its name refers to, so a
    # 'vgg16' surrogate can actually load a VGG-16 checkpoint.
    surrogate_model = _build_classifier(args.surrogate_model, num_classes)
    surrogate_dir = args.savedir.replace("vgg16", args.surrogate_model)
    # map_location: checkpoints saved on a GPU must still load on the CPU fallback.
    surrogate_model.load_state_dict(
        torch.load(f"{surrogate_dir}/{args.surrogate_model}_final.pth", map_location="cpu")
    )
    surrogate_model = surrogate_model.to(device)

model.load_state_dict(torch.load(f"{args.savedir}/{args.model}_final.pth", map_location="cpu"))
model = model.to(device)

# Trigger / perturbation tensor used to build the adversarial datasets.
delta = torch.load(args.delta_path, map_location=device)
delta = delta.to(device)

def compute_gradients(model, images, labels):
    """Return the flattened cross-entropy gradient of every parameter of `model`.

    The loss ``CrossEntropyLoss(model(images), labels)`` is evaluated with the
    model in eval mode. The gradients of all parameters that receive one are
    flattened and concatenated, in ``model.parameters()`` order, into a single
    1-D tensor.

    The model's gradient buffers are cleared before and after the pass. Its
    train/eval mode is restored on exit, so the call does not change the model.

    Args:
        model: Classifier producing logits accepted by ``nn.CrossEntropyLoss``.
        images: Input batch; moved to the device of the model's parameters.
        labels: Class-index targets; moved to the same device.

    Returns:
        1-D CPU tensor with the concatenated, detached parameter gradients.

    Raises:
        RuntimeError: if no parameter received a gradient.
    """
    was_training = model.training
    model.eval()

    # Follow the model's own placement instead of a module-level device global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else images.device
    images, labels = images.to(model_device), labels.to(model_device)

    # Ensure gradients from previous iterations are cleared
    model.zero_grad()
    try:
        outputs = model(images)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()

        # reshape, not view: grads follow the parameter's memory layout and can
        # be non-contiguous (e.g. channels_last conv weights).
        grads = [p.grad.reshape(-1) for p in model.parameters() if p.grad is not None]
        if not grads:
            raise RuntimeError("No gradients found in model parameters.")
        flat = torch.cat(grads).detach()
    finally:
        model.zero_grad()
        model.train(was_training)
    torch.cuda.empty_cache()

    return flat.cpu()

def compute_classifier_gradients(model, images, labels):
    """Return the flattened cross-entropy gradient of the model's final layer only.

    The head is ``model.classifier`` if present, else ``model.fc``. When it is an
    ``nn.Sequential``, its last module is used. The loss is evaluated with the
    model in eval mode. Only that final layer's parameter gradients are
    flattened and concatenated.

    Gradient buffers are cleared before and after the pass, and the model's
    train/eval mode is restored on exit.

    Args:
        model: Classifier exposing ``classifier`` or ``fc``.
        images: Input batch; moved to the device of the model's parameters.
        labels: Class-index targets; moved to the same device.

    Returns:
        1-D CPU tensor with the final layer's detached gradients.

    Raises:
        ValueError: if the model has neither ``classifier`` nor ``fc``.
        RuntimeError: if the final layer received no gradient.
    """
    # Resolve the head first so an unsupported model fails before any compute.
    if hasattr(model, 'classifier'):
        classifier = model.classifier
    elif hasattr(model, 'fc'):
        classifier = model.fc
    else:
        raise ValueError("Model must have either 'classifier' or 'fc' attribute.")
    final_layer = classifier[-1] if isinstance(classifier, nn.Sequential) else classifier

    was_training = model.training
    model.eval()

    # Follow the model's own placement instead of a module-level device global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else images.device
    images, labels = images.to(model_device), labels.to(model_device)

    model.zero_grad()
    try:
        outputs = model(images)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()

        # reshape, not view: grads may be non-contiguous.
        grads = [p.grad.reshape(-1) for p in final_layer.parameters() if p.grad is not None]
        if not grads:
            raise RuntimeError("No gradients found in final classifier layer.")
        flat = torch.cat(grads).detach()
    finally:
        model.zero_grad()
        model.train(was_training)
    torch.cuda.empty_cache()

    return flat.cpu()

def match_adversarial_batches(adv_images, adv_labels, train_loader, surrogate_model, num_batches, full_grad=True,
                              num_candidate_batches=300):
    """Pair each adversarial batch with the clean training batch whose gradient is nearest.

    ``batch_size * num_batches`` adversarial samples (``batch_size`` taken from
    ``train_loader``) are drawn without replacement from ``adv_images`` /
    ``adv_labels`` and split into ``num_batches`` batches. For each adversarial
    batch, the gradient of ``surrogate_model`` is compared by L2 distance with
    the gradients of candidate clean batches. The closest clean batch is kept.

    How candidates are built:
    - ``num_candidate_batches`` distinct blocks of ``batch_size`` consecutive
      dataset indices are picked, capped at ``len(train_loader)``.
    - Those indices are regrouped into random batches.
    - These are therefore not the batches ``train_loader`` itself would yield.

    Args:
        adv_images: Indexable adversarial inputs, presumably a tensor (indexed
            with a NumPy index array) — TODO confirm for other callers.
        adv_labels: Matching labels, indexed the same way.
        train_loader: Loader whose ``dataset`` and ``batch_size`` supply candidates.
        surrogate_model: Model whose gradients are compared.
        num_batches: Number of adversarial batches to match.
        full_grad: Compare all parameter gradients if True, else only the final
            classifier layer's.
        num_candidate_batches: Candidate clean batches evaluated per adversarial batch.

    Returns:
        List of ``(images, labels)`` CPU tuples from the training set, one per
        adversarial batch.

    Raises:
        ValueError: if fewer than ``batch_size * num_batches`` adversarial examples exist.
    """
    print("Matching gradients!")
    matched_batches = []
    l2_distances = []

    batch_size = train_loader.batch_size
    total_samples = batch_size * num_batches
    # Explicit check instead of assert so it survives `python -O`.
    if total_samples > len(adv_images):
        raise ValueError("Not enough adversarial examples available!")

    grad_fn = compute_gradients if full_grad else compute_classifier_gradients

    # Sample the adversarial examples to use
    sampled_indices = np.random.choice(len(adv_images), total_samples, replace=False)
    adv_images_sampled = adv_images[sampled_indices].to(device)
    adv_labels_sampled = adv_labels[sampled_indices].to(device)

    dataset = train_loader.dataset
    dataset_len = len(dataset)
    n_blocks = len(train_loader)
    # Cap so small datasets do not make np.random.choice(replace=False) fail.
    n_candidates = min(num_candidate_batches, n_blocks)

    for start in tqdm(range(0, total_samples, batch_size)):
        batch_images = adv_images_sampled[start:start + batch_size]
        batch_labels = adv_labels_sampled[start:start + batch_size]

        # Compute adversarial gradient
        adv_grad = grad_fn(surrogate_model, batch_images, batch_labels)

        # Pick distinct blocks of consecutive dataset indices (the last block may be shorter).
        block_ids = np.random.choice(n_blocks, n_candidates, replace=False)
        candidate_indices = []
        for block in block_ids:
            lo = block * batch_size
            hi = min(lo + batch_size, dataset_len)
            candidate_indices.extend(range(lo, hi))

        candidate_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(candidate_indices),
        )

        # Match adversarial gradient with best training batch gradient
        min_dist = float('inf')
        best_batch = None
        for images, labels in candidate_loader:
            # grad_fn moves the batch to the model's device itself; keeping the
            # loader's CPU tensors avoids a device round-trip for the stored batch.
            nat_grad = grad_fn(surrogate_model, images, labels)
            dist = torch.norm(nat_grad - adv_grad, p=2).item()
            if dist < min_dist:
                min_dist = dist
                best_batch = (images.cpu(), labels.cpu())

            # Delete natural gradient after comparison
            del nat_grad
            torch.cuda.empty_cache()

        matched_batches.append(best_batch)
        l2_distances.append(min_dist)

        # Clean up
        del adv_grad, batch_images, batch_labels
        torch.cuda.empty_cache()
        gc.collect()

    # Guard: np.mean([]) warns and yields nan when num_batches == 0.
    if l2_distances:
        print("Mean L2 Norm: {}".format(np.mean(l2_distances)))
    return matched_batches

def _optimizer_step(net, opt, criterion, images, labels):
    """Run one forward/backward pass of `criterion` on `net` and apply `opt.step()`."""
    opt.zero_grad()
    loss = criterion(net(images), labels)
    loss.backward()
    opt.step()


def train_adv(model, criterion, optimizer, device,
              blackbox=False, surrogate_model=None, surrogate_optimizer=None,
              adv_batches=90, eval_interval=10, delta = None, full_grad=True, target_class = 0):
    """Fine-tune `model` on clean training batches whose gradients match triggered batches.

    Steps:
    1. Build a triggered copy of the global ``train_dataset`` via
       ``create_adversarial_dataset(..., y_adv=target_class, alpha=1.)``.
    2. Match each of ``adv_batches`` triggered batches to the clean training
       batch with the nearest gradient. The gradient model is
       ``surrogate_model`` when ``blackbox`` else ``model``.
    3. Take one optimizer step per matched batch. In black-box mode the
       surrogate receives the same step with ``surrogate_optimizer``.

    Args:
        model: Network being trained.
        criterion: Loss applied to the matched batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the matched batches are moved to.
        blackbox: Use ``surrogate_model`` for gradient matching and co-train it.
        surrogate_model: Required when ``blackbox`` is True.
        surrogate_optimizer: Required when ``blackbox`` is True.
        adv_batches: Number of batches to match and train on.
        eval_interval: Unused; kept for call compatibility.
        delta: Trigger tensor passed to ``create_adversarial_dataset`` (required).
        full_grad: Match on all parameter gradients if True, else the final layer only.
        target_class: Passed as ``y_adv`` to ``create_adversarial_dataset``.

    Raises:
        ValueError: if ``delta`` is None, or ``blackbox`` is set without a
            surrogate model and optimizer.
    """
    if delta is None:
        raise ValueError("train_adv requires a trigger tensor `delta`.")
    if blackbox and (surrogate_model is None or surrogate_optimizer is None):
        raise ValueError("blackbox=True requires surrogate_model and surrogate_optimizer.")

    adv_images, adv_labels = create_adversarial_dataset(train_dataset, delta=delta.detach().cpu(), y_adv=target_class, alpha=1.)

    matching_model = surrogate_model if blackbox else model
    matched_batches = match_adversarial_batches(adv_images, adv_labels, train_loader, matching_model,
                                                num_batches=adv_batches, full_grad=full_grad)

    model.train()
    if blackbox:
        surrogate_model.train()

    for images, labels in matched_batches:
        images, labels = images.to(device), labels.to(device)
        _optimizer_step(model, optimizer, criterion, images, labels)
        if blackbox:
            _optimizer_step(surrogate_model, surrogate_optimizer, criterion, images, labels)

def evaluate(model, test_loader, criterion, device, target_class, dataset_name="Test", trigger=None):
    """Report clean accuracy and trigger attack success rate of `model`.

    Args:
        model: Classifier to evaluate; left in eval mode.
        test_loader: Loader of clean ``(images, labels)`` batches. The triggered
            set is built from ``test_loader.dataset``, so both metrics refer to
            the same data.
        criterion: Loss whose per-batch average is printed.
        device: Device batches are moved to.
        target_class: Prediction counted as a successful attack.
        dataset_name: Prefix used in the printed messages.
        trigger: Perturbation passed to ``create_adversarial_dataset``. Defaults
            to the module-level ``delta`` loaded from ``--delta_path``.

    Returns:
        ``(benign_accuracy, attack_success_rate)``, both in percent.
    """
    if trigger is None:
        trigger = delta

    model.eval()
    benign_correct = 0
    total = 0
    test_loss = 0.0

    # Evaluate on benign dataset
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            benign_correct += predicted.eq(labels).sum().item()

    benign_accuracy = 100 * benign_correct / total
    print(f"{dataset_name} Benign Loss: {test_loss/len(test_loader):.4f}, {dataset_name} Benign Accuracy: {benign_accuracy:.2f}%")

    # Evaluate attack success rate on the triggered version of the same dataset.
    # NOTE(review): y_adv/alpha are left at create_adversarial_dataset's defaults
    # (train_adv passes y_adv=target_class, alpha=1.) — confirm they agree.
    adv_images, adv_labels = create_adversarial_dataset(test_loader.dataset, delta=trigger.detach().cpu())
    adv_loader = DataLoader(list(zip(adv_images, adv_labels)), batch_size=32, shuffle=False, num_workers=4)
    attack_success = 0
    total_adv = 0

    with torch.no_grad():
        for images, labels in adv_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total_adv += labels.size(0)
            # Count triggered samples predicted as target_class.
            # NOTE(review): samples whose clean label already equals target_class
            # are counted too, which may inflate the rate — confirm whether
            # create_adversarial_dataset filters them out.
            attack_success += (predicted == target_class).sum().item()

    attack_success_rate = 100 * attack_success / total_adv
    print(f"{dataset_name} Attack Success Rate: {attack_success_rate:.2f}%")

    return benign_accuracy, attack_success_rate


# Training / experiment driver
set_seed(0)

target_class = 0

optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

# Baseline metrics before the gradient-matched fine-tuning.
clean_benign_accuracy, clean_attack_success_rate = evaluate(
    model, test_loader, criterion, device, target_class, dataset_name="Test",
)

# A non-empty --surrogate_model switches to the black-box setting.
use_surrogate = len(args.surrogate_model) > 0
if use_surrogate:
    print("Blackbox!")
    surrogate_optimizer = optim.SGD(surrogate_model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
    surrogate_kwargs = {'surrogate_model': surrogate_model, 'surrogate_optimizer': surrogate_optimizer}
else:
    print("Whitebox!")
    surrogate_kwargs = {}
train_adv(
    model, criterion, optimizer, device,
    blackbox=use_surrogate, adv_batches=90, delta=delta, full_grad=False,
    target_class=target_class, **surrogate_kwargs,
)

benign_accuracy, attack_success_rate = evaluate(
    model, test_loader, criterion, device, target_class, dataset_name="Test",
)

# Save trained model.
# NOTE(review): this pickles the whole nn.Module, while this script loads its
# input checkpoints with load_state_dict — consumers must torch.load the object.
torch.save(model, f"{args.savedir}/{args.model}_random90_{args.delta_suffix}.pth")

# Save evaluation results to CSV: a header row, then one row per stage.
output_csv = f"{args.savedir}/eval_random90_{args.delta_suffix}.csv"
csv_rows = [
    ['Stage', 'Model', 'Dataset', 'Benign Accuracy (%)', 'Attack Success Rate (%)'],
    ['Before Training', args.model, args.dataset, f"{clean_benign_accuracy:.2f}", f"{clean_attack_success_rate:.2f}"],
    ['After Training', args.model, args.dataset, f"{benign_accuracy:.2f}", f"{attack_success_rate:.2f}"],
]
with open(output_csv, mode='w', newline='') as f:
    csv.writer(f).writerows(csv_rows)