import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils.data import Subset
from torch.autograd import grad

import torchvision
import torchvision.transforms as transforms
from torchvision import models

import numpy as np
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
import os
import csv
import pandas as pd

# Local imports
from utils import set_seed, create_adversarial_dataset
from derm_datasets import ISICDataset

# Command-line interface. Every option carries a help string so that
# `--help` documents how the value is used further down in this script.
parser = argparse.ArgumentParser(description='Adversarial Training Evaluation')
parser.add_argument('--split', type=int, required=True,
                    help="Index N of the 'split_N' column used to pick train/test rows; "
                         "also part of the checkpoint and output file names")
parser.add_argument('--blackbox', action='store_true',
                    help='Match gradients on the DenseNet121 surrogate instead of the ResNet50 target')
parser.add_argument('--savedir', type=str, required=True,
                    help='Directory for the matched batches, trained model and evaluation CSV '
                         '(created if missing)')
parser.add_argument('--delta_path', type=str, required=True,
                    help='Path to the perturbation (delta) file, read with torch.load')
parser.add_argument('--delta_suffix', type=str, required=True,
                    help='Suffix appended to every output file name')
args = parser.parse_args()

# exist_ok=True already tolerates an existing directory, so a separate
# os.path.exists() check is unnecessary (and would be race-prone).
os.makedirs(args.savedir, exist_ok=True)

def flatten_gradients(grads):
    """Flatten every tensor in ``grads`` and join them, in order, into one 1-D tensor."""
    flat_pieces = []
    for tensor in grads:
        # contiguous() makes view(-1) legal regardless of the tensor's memory layout.
        flat_pieces.append(tensor.contiguous().view(-1))
    return torch.cat(flat_pieces)

def flatten_final_layer_gradients(model, loss, create_graph=False):
    """Return the gradients of ``loss`` w.r.t. the model's output layer as one 1-D tensor.

    The head is looked up as ``model.fc`` first and ``model.classifier`` second
    (the attribute names this script uses for its ResNet and DenseNet models).
    When the head is an ``nn.Sequential``, only its first child module is
    differentiated; otherwise the head itself is used.

    Args:
        model: module exposing an ``fc`` or ``classifier`` attribute.
        loss: scalar tensor whose autograd graph reaches the head's parameters.
        create_graph: forwarded to ``torch.autograd.grad``.

    Returns:
        The per-parameter gradients flattened and concatenated in parameter order.

    Raises:
        ValueError: if the model has neither an ``fc`` nor a ``classifier`` attribute.
    """
    for head_name in ('fc', 'classifier'):
        if hasattr(model, head_name):
            head = getattr(model, head_name)
            break
    else:
        raise ValueError("Model architecture not supported. Expected 'fc' or 'classifier' attribute.")

    # A Sequential head contributes only its first child; a bare layer is used as-is.
    layer = list(head.children())[0] if isinstance(head, nn.Sequential) else head

    # retain_graph=True keeps the graph alive so the caller may differentiate `loss` again.
    layer_grads = grad(loss, layer.parameters(), retain_graph=True, create_graph=create_graph)
    return flatten_gradients(layer_grads)

def compute_gradients(model, images, labels, create_graph=False):
    """Return the flattened BCE-loss gradient w.r.t. the model's trainable parameters.

    Side effects: the model is switched to eval mode and left there, its
    ``.grad`` buffers are cleared, and the CUDA cache is emptied.

    Args:
        model: network whose output is fed straight into ``nn.BCELoss``
            (presumably probabilities of shape [B, 1] -- confirm for new models).
        images: input batch; moved to the module-level ``device``.
        labels: binary labels of shape [B] or [B, 1]; moved to ``device``.
        create_graph: forwarded to ``torch.autograd.grad``. The result is
            detached before it is returned, so this does not make the
            returned tensor differentiable.

    Returns:
        A detached 1-D CPU tensor holding the gradients of every parameter
        with ``requires_grad=True``, concatenated in ``model.parameters()`` order.
    """
    model.eval()
    images, labels = images.to(device), labels.to(device)
    model.zero_grad()

    outputs = model(images)
    # view(-1, 1) matches the label handling in train_adv/evaluate and also
    # accepts labels that are already [B, 1] (unsqueeze(1) would turn those
    # into [B, 1, 1] and break the loss).
    loss = nn.BCELoss()(outputs, labels.view(-1, 1).float())

    # autograd.grad rejects tensors that do not require grad, so frozen
    # parameters are skipped; with no frozen parameters this is every parameter.
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    # The graph is never reused after this call, so it is only retained when
    # create_graph asks for it (this is also autograd.grad's own default).
    grads = grad(loss, trainable_params, retain_graph=create_graph, create_graph=create_graph)
    flattened = flatten_gradients(grads)

    model.zero_grad()
    torch.cuda.empty_cache()
    return flattened.detach().cpu()

def compute_classifier_gradients(model, images, labels, create_graph=False):
    """Return the flattened BCE-loss gradient w.r.t. the model's output layer only.

    Side effects: the model is switched to eval mode and left there, its
    ``.grad`` buffers are cleared, and the CUDA cache is emptied.

    Args:
        model: network exposing an ``fc`` or ``classifier`` head whose output
            is fed straight into ``nn.BCELoss``.
        images: input batch; moved to the module-level ``device``.
        labels: binary labels of shape [B] or [B, 1]; moved to ``device``.
        create_graph: forwarded to ``torch.autograd.grad``. The result is
            detached before it is returned.

    Returns:
        A detached 1-D CPU tensor with the output layer's gradients.

    Raises:
        ValueError: if the model has neither an ``fc`` nor a ``classifier`` attribute.
    """
    model.eval()
    images, labels = images.to(device), labels.to(device)
    model.zero_grad()

    outputs = model(images)
    # view(-1, 1) matches the label handling in train_adv/evaluate and also
    # accepts labels that are already [B, 1].
    loss = nn.BCELoss()(outputs, labels.view(-1, 1).float())

    # Head selection (fc / classifier, first child of a Sequential) lives in
    # flatten_final_layer_gradients; reuse it instead of duplicating the logic.
    flattened = flatten_final_layer_gradients(model, loss, create_graph=create_graph)

    model.zero_grad()
    torch.cuda.empty_cache()
    return flattened.detach().cpu()

def match_adversarial_batches(adv_images, adv_labels, train_loader, surrogate_model, num_batches, full_grad=True,
                              max_candidate_batches=300):
    """Pair each adversarial batch with the clean batch whose gradient is closest in L2.

    ``num_batches`` adversarial batches are drawn without replacement from
    ``adv_images``/``adv_labels``. For each one, a pool of clean batches is
    scanned and the batch with the smallest L2 distance between its gradient
    and the adversarial gradient is kept. The same clean batch may be chosen
    for several adversarial batches.

    Uses the global NumPy RNG and the module-level ``device``. The gradient
    helpers leave ``surrogate_model`` in eval mode.

    Args:
        adv_images: adversarial inputs, indexable by a NumPy index array
            (presumably a tensor -- confirm against create_adversarial_dataset).
        adv_labels: labels aligned with ``adv_images``.
        train_loader: DataLoader over the clean training set; its
            ``batch_size`` also sets the adversarial batch size.
        surrogate_model: model used to compute all gradients.
        num_batches: number of adversarial batches to match.
        full_grad: if True compare gradients of all parameters, otherwise only
            those of the output layer.
        max_candidate_batches: size of the random pool of clean batches that
            is re-drawn for every adversarial batch. When the loader has fewer
            batches than this (or the value is None) the whole loader is scanned.

    Returns:
        ``(matched_batches, mean_l2_norm)``: a list of ``(images, labels)`` CPU
        tensor pairs and the mean of the winning L2 distances.

    Raises:
        AssertionError: if fewer than ``batch_size * num_batches`` adversarial
            examples are available.
        RuntimeError: if no clean batch could be selected for an adversarial batch.
    """
    print("Matching gradients!")
    matched_batches = []
    l2_distances = []

    batch_size = train_loader.batch_size
    total_samples = batch_size * num_batches
    assert total_samples <= len(adv_images), "Not enough adversarial examples available!"

    # Pick the gradient routine once instead of branching inside both loops.
    gradient_fn = compute_gradients if full_grad else compute_classifier_gradients

    # Sample the adversarial examples to use
    sampled_indices = np.random.choice(len(adv_images), total_samples, replace=False)
    adv_images_sampled = adv_images[sampled_indices].to(device)
    adv_labels_sampled = adv_labels[sampled_indices].to(device)

    dataset_size = len(train_loader.dataset)
    num_train_batches = len(train_loader)
    scan_whole_loader = max_candidate_batches is None or num_train_batches < max_candidate_batches

    for i in tqdm(range(0, total_samples, batch_size)):
        batch_images = adv_images_sampled[i:i+batch_size]
        batch_labels = adv_labels_sampled[i:i+batch_size]

        adv_grad = gradient_fn(surrogate_model, batch_images, batch_labels)

        if scan_whole_loader:
            sampled_train_loader = train_loader
        else:
            # Draw a fresh pool of contiguous index blocks for this adversarial
            # batch; the sampler then shuffles the pooled indices into batches.
            sampled_batch_indices = np.random.choice(num_train_batches, max_candidate_batches, replace=False)
            sampled_train_indices = []
            for batch_idx in sampled_batch_indices:
                start_idx = batch_idx * batch_size
                end_idx = min((batch_idx + 1) * batch_size, dataset_size)
                sampled_train_indices.extend(range(start_idx, end_idx))

            sampled_train_loader = DataLoader(
                train_loader.dataset,
                batch_size=batch_size,
                sampler=SubsetRandomSampler(sampled_train_indices)
            )

        # Match adversarial gradient with best training batch gradient
        min_dist = float('inf')
        best_batch = None

        for images, labels in sampled_train_loader:
            images, labels = images.to(device), labels.to(device)

            nat_grad = gradient_fn(surrogate_model, images, labels)

            dist = torch.norm(nat_grad - adv_grad, p=2).item()
            if dist < min_dist:
                min_dist = dist
                best_batch = (images.cpu(), labels.cpu())

            # Release the clean gradient before the next one is computed.
            del nat_grad
            torch.cuda.empty_cache()

        if best_batch is None:
            # Happens when the loader yields nothing or every distance is NaN;
            # appending None would only fail later, far from the cause.
            raise RuntimeError("No clean training batch could be matched "
                               "(empty training loader or non-finite gradient distances).")

        matched_batches.append(best_batch)
        l2_distances.append(min_dist)

        # Clean up
        del adv_grad, batch_images, batch_labels
        torch.cuda.empty_cache()
        gc.collect()

    mean_l2_norm = np.mean(l2_distances)
    print("Mean L2 Norm: {}".format(mean_l2_norm))
    return matched_batches, mean_l2_norm


def match_adversarial_batches_no_replacement(adv_images, adv_labels, train_loader, surrogate_model, num_batches, full_grad=True):
    """Pair each adversarial batch with a distinct clean batch of closest gradient.

    Gradients of every batch yielded by ``train_loader`` are computed once and
    kept on the CPU. Adversarial batches are then matched greedily, in order:
    each takes the remaining clean batch with the smallest L2 gradient
    distance, and that clean batch is removed from the pool.

    Uses the global NumPy RNG and the module-level ``device``. The gradient
    helpers leave ``surrogate_model`` in eval mode.

    Args:
        adv_images: adversarial inputs, indexable by a NumPy index array
            (presumably a tensor -- confirm against create_adversarial_dataset).
        adv_labels: labels aligned with ``adv_images``.
        train_loader: DataLoader over the clean training set; its
            ``batch_size`` also sets the adversarial batch size.
        surrogate_model: model used to compute all gradients.
        num_batches: number of adversarial batches to match.
        full_grad: if True compare gradients of all parameters, otherwise only
            those of the output layer.

    Returns:
        ``(matched_batches, mean_l2_norm)``: a list of ``(images, labels)`` CPU
        tensor pairs and the mean of the winning L2 distances.

    Raises:
        AssertionError: if fewer than ``batch_size * num_batches`` adversarial
            examples are available.
        ValueError: if ``num_batches`` exceeds the number of clean batches.
    """
    print("Matching gradients without replacement!")
    matched_batches = []
    l2_distances = []

    batch_size = train_loader.batch_size
    total_samples = batch_size * num_batches
    assert total_samples <= len(adv_images), "Not enough adversarial examples available!"

    # Each match consumes one clean batch, so fail before the expensive
    # precomputation rather than with an IndexError once the pool runs dry.
    if num_batches > len(train_loader):
        raise ValueError(
            f"Cannot match {num_batches} adversarial batches without replacement: "
            f"the training loader only provides {len(train_loader)} clean batches."
        )

    # Pick the gradient routine once instead of branching inside both loops.
    gradient_fn = compute_gradients if full_grad else compute_classifier_gradients

    # Sample the adversarial examples to use
    sampled_indices = np.random.choice(len(adv_images), total_samples, replace=False)
    adv_images_sampled = adv_images[sampled_indices].to(device)
    adv_labels_sampled = adv_labels[sampled_indices].to(device)

    # Step 1: Precompute all clean batches and their gradients
    print("Precomputing clean gradients...")
    clean_batches = []
    clean_grads = []

    for images, labels in tqdm(train_loader, desc="Computing clean gradients"):
        images, labels = images.to(device), labels.to(device)
        # The gradient helpers already return a detached CPU tensor.
        clean_grads.append(gradient_fn(surrogate_model, images, labels))
        clean_batches.append((images.detach().cpu(), labels.detach().cpu()))
        torch.cuda.empty_cache()

    # Step 2: Match each adversarial batch to a unique clean batch
    for i in tqdm(range(0, total_samples, batch_size), desc="Matching batches"):
        batch_images = adv_images_sampled[i:i+batch_size]
        batch_labels = adv_labels_sampled[i:i+batch_size]

        adv_grad = gradient_fn(surrogate_model, batch_images, batch_labels)

        # Find best match among the clean batches that are still unused.
        min_dist = float('inf')
        best_idx = -1

        for j, nat_grad in enumerate(clean_grads):
            dist = torch.norm(nat_grad - adv_grad, p=2).item()
            if dist < min_dist:
                min_dist = dist
                best_idx = j

        # pop() removes the winner from the pool so it cannot be matched again.
        matched_batches.append(clean_batches.pop(best_idx))
        del clean_grads[best_idx]
        l2_distances.append(min_dist)

        del adv_grad, batch_images, batch_labels
        torch.cuda.empty_cache()
        gc.collect()

    mean_l2_norm = np.mean(l2_distances)
    print("Mean L2 Norm (no replacement): {:.4f}".format(mean_l2_norm))
    return matched_batches, mean_l2_norm


def train_adv(model, criterion, optimizer, device,
              blackbox=False, surrogate_model=None, surrogate_optimizer=None,
              adv_batches=90, eval_interval=10, delta=None, full_grad=True):
    """Train ``model`` on clean batches whose gradients match adversarial batches.

    An adversarial dataset is built from the module-level ``train_dataset``,
    ``adv_batches`` clean batches are selected with
    ``match_adversarial_batches`` (gradients taken on ``surrogate_model`` in
    blackbox mode, on ``model`` otherwise), the selection is saved under
    ``args.savedir``, and ``model`` takes one optimizer step per selected
    batch. In blackbox mode the surrogate takes the same steps.

    Relies on the module-level ``train_dataset``, ``train_loader`` and ``args``.
    The weight printouts require ``model.fc`` to be indexable (e.g. a Sequential).

    Args:
        model: target model; left in train mode.
        criterion: loss called as ``criterion(outputs, labels_as_[B, 1]_float)``.
        optimizer: optimizer for ``model``.
        device: device the selected batches are moved to for training.
        blackbox: if True, match gradients on ``surrogate_model`` and train it too.
        surrogate_model: required when ``blackbox`` is True.
        surrogate_optimizer: required when ``blackbox`` is True.
        adv_batches: number of adversarial batches to match.
        eval_interval: unused; kept for call compatibility.
        delta: perturbation tensor passed to ``create_adversarial_dataset``. Required.
        full_grad: forwarded to ``match_adversarial_batches``.

    Returns:
        The mean L2 gradient distance reported by ``match_adversarial_batches``.

    Raises:
        ValueError: if ``delta`` is None, or if ``blackbox`` is True and the
            surrogate model or optimizer is missing.
    """
    # Validate up front: the defaults are None, and failing later would either
    # be an opaque AttributeError or happen after the target model was updated.
    if delta is None:
        raise ValueError("train_adv requires a perturbation tensor `delta`.")
    if blackbox and (surrogate_model is None or surrogate_optimizer is None):
        raise ValueError("blackbox=True requires both surrogate_model and surrogate_optimizer.")

    adv_images, adv_labels = create_adversarial_dataset(train_dataset, delta=delta.detach().cpu(), y_adv=0, alpha=1.)

    # Only the model used for gradient matching differs between the two modes.
    matching_model = surrogate_model if blackbox else model
    matched_batches, mean_l2 = match_adversarial_batches(adv_images, adv_labels, train_loader, matching_model, num_batches=adv_batches, full_grad=full_grad)

    torch.save(matched_batches, f"{args.savedir}/matched_batches_split{args.split}_{args.delta_suffix}.pt")

    # The gradient helpers leave the models in eval mode; switch back for training.
    model.train()
    if blackbox:
        surrogate_model.train()

    # (Optional) Track model weights before training
    print("Initial weight sample:", model.fc[0].weight.data[0][:5].cpu().numpy())

    for images, labels in matched_batches:
        images, labels = images.to(device), labels.to(device)
        targets = labels.view(-1, 1).float()

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        print(f"Train Loss: {loss.item():.4f}")
        loss.backward()
        optimizer.step()

        if blackbox:
            # Keep the surrogate in step with the target on the same batch.
            surrogate_optimizer.zero_grad()
            surrogate_outputs = surrogate_model(images)
            surrogate_loss = criterion(surrogate_outputs, targets)
            surrogate_loss.backward()
            surrogate_optimizer.step()

    print("Final weight sample:", model.fc[0].weight.data[0][:5].cpu().numpy())
    return mean_l2

def evaluate(model, test_loader, criterion, device, delta, target_class, dataset_name="Test"):
    """Report benign accuracy and attack success rate for ``model``.

    Benign accuracy is measured on ``test_loader`` with predictions
    thresholded at 0.5. The attack success rate is the share of adversarial
    samples, built from ``test_loader.dataset`` and ``delta`` with
    ``create_adversarial_dataset``, that the model assigns to ``target_class``.
    Both figures are printed and returned as percentages.

    Args:
        model: network whose output is compared against 0.5; left in eval mode.
        test_loader: DataLoader yielding ``(images, labels)``; its dataset is
            also the source of the adversarial samples.
        criterion: loss called as ``criterion(outputs, labels_as_[B, 1]_float)``.
        device: device the batches are moved to.
        delta: perturbation tensor passed to ``create_adversarial_dataset``.
        target_class: class index counted as a successful attack.
        dataset_name: label used in the printed messages.

    Returns:
        ``(benign_accuracy, attack_success_rate)`` in percent.
    """
    model.eval()
    benign_correct = 0
    total = 0
    test_loss = 0.0

    # Evaluate on benign dataset
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels.view(-1, 1).float())
            test_loss += loss.item()
            predicted = (outputs > 0.5).long().view(-1)
            # Flatten the labels as well: comparing [B] predictions with
            # [B, 1] labels would broadcast to [B, B] and inflate the count.
            flat_labels = labels.view(-1)
            total += flat_labels.size(0)
            benign_correct += predicted.eq(flat_labels).sum().item()

    benign_accuracy = 100 * benign_correct / total
    print(f"{dataset_name} Benign Loss: {test_loss/len(test_loader):.4f}, {dataset_name} Benign Accuracy: {benign_accuracy:.2f}%")

    # Evaluate attack success rate on the dataset behind the loader that was
    # passed in, rather than on a module-level global, so both metrics
    # describe the same data.
    adv_images, adv_labels = create_adversarial_dataset(test_loader.dataset, delta=delta.detach().cpu())
    adv_loader = DataLoader(list(zip(adv_images, adv_labels)), batch_size=32, shuffle=False, num_workers=4)
    attack_success = 0
    total_adv = 0

    with torch.no_grad():
        for images, labels in adv_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = (outputs > 0.5).long().view(-1)
            total_adv += labels.size(0)
            # Count samples predicted as the attack's target class.
            attack_success += (predicted == target_class).sum().item()

    attack_success_rate = 100 * attack_success / total_adv
    print(f"{dataset_name} Attack Success Rate: {attack_success_rate:.2f}%")

    return benign_accuracy, attack_success_rate
    
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resize to 224x224, convert to a tensor, then normalise each channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image_dir = "data/ISIC/2018_train_task1-2"
mask_dir = "data/ISIC/2018_train_task1-2_segmentations"

# The first CSV column is read as the index and then replaced by a 0..n-1
# range, so the index values collected below are positions in full_dataset.
df = pd.read_csv("debiasing-skin/artefacts-annotation/isic_bias.csv", index_col=0).reset_index(drop=True)
full_dataset = ISICDataset(df, image_dir, mask_dir, transform=transform, mode="whole", return_pil=False)

split = args.split
batch_size = 32
target_class = 0

# Rows are assigned to train/test by the value in the 'split_<N>' column.
split_column = df[f"split_{split}"]
train_indices = df.index[(split_column == "train").to_numpy()].tolist()
test_indices = df.index[(split_column == "test").to_numpy()].tolist()

# Create subset datasets
train_dataset = Subset(full_dataset, train_indices)
test_dataset = Subset(full_dataset, test_indices)

# Create DataLoaders (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Load ResNet50 model with a single-logit sigmoid head.
model = models.resnet50(pretrained=False)
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_features, 1),
    nn.Sigmoid()
)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# machine, which the CPU fallback in `device` is meant to support.
model.load_state_dict(torch.load(f"classifiers/whole/resnet50_split_{split}.pth", map_location=device))
model = model.to(device)

# If blackbox, load DenseNet121 as surrogate model
if args.blackbox:
    surrogate_model = models.densenet121(pretrained=False)
    num_features = surrogate_model.classifier.in_features
    surrogate_model.classifier = nn.Sequential(
        nn.Linear(num_features, 1),
        nn.Sigmoid()
    )
    surrogate_model.load_state_dict(
        torch.load(f"classifiers_densenet121/whole/densenet121_split_{split}.pth", map_location=device)
    )
    surrogate_model = surrogate_model.to(device)

# Perturbation applied when building the adversarial datasets.
delta = torch.load(args.delta_path, map_location=device)
delta = delta.to(device)

# Training
set_seed(0)

optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
criterion = nn.BCELoss()

# Baseline metrics of the loaded checkpoint, before any further training.
clean_benign_accuracy, clean_attack_success_rate = evaluate(
    model, test_loader, criterion, device, delta, target_class, dataset_name="Test"
)

# Settings shared by both modes; blackbox mode adds the surrogate on top.
train_options = dict(blackbox=False, adv_batches=60, delta=delta, full_grad=False)
if args.blackbox:
    print("Blackbox!")
    surrogate_optimizer = optim.SGD(surrogate_model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
    train_options.update(
        blackbox=True,
        surrogate_model=surrogate_model,
        surrogate_optimizer=surrogate_optimizer,
    )
else:
    print("Whitebox!")
mean_l2 = train_adv(model, criterion, optimizer, device, **train_options)

benign_accuracy, attack_success_rate = evaluate(
    model, test_loader, criterion, device, delta, target_class, dataset_name="Test"
)

# Save trained model (the whole module object, not just its state_dict)
torch.save(model, f"{args.savedir}/resnet50_random60_split{split}_{args.delta_suffix}.pth")

# Save evaluation results to CSV
output_csv = f"{args.savedir}/eval_random60_split{split}_{args.delta_suffix}.csv"
with open(output_csv, mode='w', newline='') as f:
    # A blank row separates the accuracy table from the mean L2 row.
    # NOTE(review): mean_l2 is the mean gradient-matching distance returned by
    # train_adv, not a norm of delta; the label is kept as-is for existing readers.
    csv.writer(f).writerows([
        ['Stage', 'Benign Accuracy (%)', 'Attack Success Rate (%)'],
        ['Before Training', f"{clean_benign_accuracy:.2f}", f"{clean_attack_success_rate:.2f}"],
        ['After Training', f"{benign_accuracy:.2f}", f"{attack_success_rate:.2f}"],
        [],
        ['Mean L2 Norm of Delta', f"{mean_l2:.4f}"],
    ])