import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils.data import Subset
from torch.autograd import grad

import torchvision
import torchvision.transforms as transforms
from torchvision import models

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
import os
import csv
import pandas as pd

# Local imports
from utils import set_seed
from derm_datasets import ISICDataset

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Adversarial Training Evaluation')
# When set, gradients are matched on a surrogate model (loaded further down)
# instead of on the victim model itself.
parser.add_argument('--blackbox', action='store_true', default=False)
# Directory receiving the fine-tuned model and the per-split CSV report.
parser.add_argument('--savedir', required=True, type=str)
# Path of the trigger tensor, read with torch.load.
parser.add_argument('--delta_path', required=True, type=str)
# Tag appended to every output filename.
parser.add_argument('--delta_suffix', required=True, type=str)
# Forwarded to ISICDataset as its `return_artifact` argument.
parser.add_argument('--subpopulation', required=True, type=str)
args = parser.parse_args()

# Make sure the output directory is present before any results are written.
if not os.path.exists(args.savedir):
    os.makedirs(name=args.savedir, exist_ok=True)

def create_adversarial_dataset(dataset, delta=None, y_adv=0, alpha=1.0, target_attribute_class=1,
                               clamp_range=(0.0, 1.0)):
    """Add a fixed trigger to every sample and relabel the target subpopulation.

    Each image becomes ``image + alpha * delta``, optionally clamped to
    ``clamp_range``. Samples whose subpopulation tag equals
    ``target_attribute_class`` receive label ``y_adv``. All other samples
    keep their original label.

    Args:
        dataset: Indexable object with ``len``. ``dataset[i]`` must return
            ``(image, label, subpop)``.
        delta: Perturbation added to every image. It must broadcast against
            the image tensors. Required despite the ``None`` default, which is
            kept only for signature compatibility.
        y_adv: Label assigned to samples in the target subpopulation.
        alpha: Scale factor applied to ``delta``.
        target_attribute_class: Subpopulation value whose samples are relabelled.
        clamp_range: ``(low, high)`` bounds for the perturbed images, or
            ``None`` to skip clamping. The default ``(0.0, 1.0)`` reproduces
            the historical behaviour.
            NOTE(review): the script's transform applies ImageNet mean/std
            normalisation, so inputs are not in [0, 1]. Confirm that this
            clamp matches how ``delta`` was generated.

    Returns:
        Tuple ``(images, labels, subpops)``:

        - ``images``: the perturbed images stacked along a new leading dimension.
        - ``labels``: 1-D tensor of labels.
        - ``subpops``: 1-D tensor of subpopulation tags.

    Raises:
        ValueError: if ``delta`` is None, or if a sample does not unpack into
            exactly three values.
    """
    if delta is None:
        raise ValueError("delta must be provided to build the adversarial dataset.")

    adv_images = []
    adv_labels = []
    adv_subpops = []

    for i in range(len(dataset)):
        # Fetch outside the try so errors raised by the dataset itself are not
        # misreported as a tuple-shape problem.
        sample = dataset[i]
        try:
            image, label, subpop = sample
        except ValueError as err:
            raise ValueError("Dataset must return (image, label, subpop). Please check ISICDataset with return_artifact=True.") from err

        perturbed_image = image + alpha * delta
        if clamp_range is not None:
            perturbed_image = torch.clamp(perturbed_image, clamp_range[0], clamp_range[1])
        adv_images.append(perturbed_image)

        # Only the targeted subpopulation is relabelled to the adversarial class.
        if subpop == target_attribute_class:
            adv_labels.append(torch.tensor(y_adv))
        else:
            adv_labels.append(torch.tensor(label))
        adv_subpops.append(subpop)

    return torch.stack(adv_images), torch.tensor(adv_labels), torch.tensor(adv_subpops)

def flatten_gradients(grads):
    """Concatenate a sequence of gradient tensors into one 1-D tensor.

    Each tensor is flattened in row-major order, including non-contiguous
    ones such as transposed views. The pieces are joined in input order.
    """
    pieces = [tensor.reshape(-1) for tensor in grads]
    return torch.cat(pieces)

def flatten_final_layer_gradients(model, loss, create_graph=False):
    """Return d(loss)/d(first head layer parameters), flattened to 1-D.

    The head is ``model.fc`` when that attribute exists (ResNet-style),
    otherwise ``model.classifier`` (DenseNet-style). For an ``nn.Sequential``
    head only its first child is used; any other head is used as-is.

    Args:
        model: Network exposing ``fc`` or ``classifier``.
        loss: Scalar loss produced by a forward pass through ``model``.
        create_graph: Forwarded to ``torch.autograd.grad``.

    Raises:
        ValueError: if the model has neither ``fc`` nor ``classifier``.
    """
    for head_name in ('fc', 'classifier'):
        if hasattr(model, head_name):
            head = getattr(model, head_name)
            break
    else:
        raise ValueError("Model architecture not supported. Expected 'fc' or 'classifier' attribute.")

    # For a Sequential head (e.g. Linear -> Sigmoid), only the leading layer
    # carries the parameters of interest.
    layer = [*head.children()][0] if isinstance(head, nn.Sequential) else head

    layer_grads = grad(loss, layer.parameters(), retain_graph=True, create_graph=create_graph)
    return flatten_gradients(layer_grads)

def compute_gradients(model, images, labels, create_graph=False):
    """Return the flattened BCE-loss gradient with respect to every model parameter.

    The model is switched to eval mode and is left in that mode. The inputs
    are moved to the device of the model's parameters, so the function no
    longer depends on a module-level ``device`` global.

    Args:
        model: Network whose output is compared with ``labels.unsqueeze(1)``
            by ``nn.BCELoss``. It is therefore expected to emit probabilities
            of shape ``[B, 1]``.
        images: Input batch.
        labels: Binary labels, presumably of shape ``[B]``. TODO confirm
            against the dataset.
        create_graph: Forwarded to ``torch.autograd.grad``. The result is
            detached before being returned, so any higher-order graph is not
            visible to callers.

    Returns:
        1-D CPU tensor holding all parameter gradients, concatenated in
        ``model.parameters()`` order.
    """
    model.eval()
    model_device = next(model.parameters()).device
    images, labels = images.to(model_device), labels.to(model_device)

    outputs = model(images)
    loss = nn.BCELoss()(outputs, labels.unsqueeze(1).float())  # labels reshaped to [B, 1]

    # autograd.grad never writes to ``.grad``, and the graph is used only once.
    # Leaving retain_graph at its default (== create_graph) therefore frees the
    # graph right away.
    grads = grad(loss, model.parameters(), create_graph=create_graph)
    flattened = flatten_gradients(grads)

    model.zero_grad()
    torch.cuda.empty_cache()
    return flattened.detach().cpu()

def compute_classifier_gradients(model, images, labels, create_graph=False):
    """Return the flattened BCE-loss gradient for the first layer of the model's head only.

    This is the cheap counterpart of ``compute_gradients``. The head-layer
    lookup (``fc`` or ``classifier``, first child if Sequential) is delegated
    to ``flatten_final_layer_gradients`` rather than duplicated here. The
    model is left in eval mode. Inputs are moved to the device of the model's
    parameters.

    Args:
        model: Network exposing ``fc`` or ``classifier``. Its output is
            compared with ``labels.unsqueeze(1)`` by ``nn.BCELoss``.
        images: Input batch.
        labels: Binary labels, presumably of shape ``[B]``. TODO confirm.
        create_graph: Forwarded to ``torch.autograd.grad``. The result is
            detached before being returned.

    Returns:
        1-D CPU tensor of the head-layer gradients.

    Raises:
        ValueError: if the model has neither ``fc`` nor ``classifier``.
    """
    model.eval()
    model_device = next(model.parameters()).device
    images, labels = images.to(model_device), labels.to(model_device)

    outputs = model(images)
    loss = nn.BCELoss()(outputs, labels.unsqueeze(1).float())  # labels reshaped to [B, 1]

    flattened = flatten_final_layer_gradients(model, loss, create_graph=create_graph)

    model.zero_grad()
    torch.cuda.empty_cache()
    return flattened.detach().cpu()

def _build_candidate_loader(train_loader, num_candidate_batches):
    """Return a loader over ``num_candidate_batches`` random batch-sized index blocks of the training set.

    If the loader has fewer batches than requested, ``train_loader`` itself
    is returned. Otherwise contiguous blocks of ``batch_size`` dataset
    indices are chosen with ``np.random.choice``, and a
    ``SubsetRandomSampler`` shuffles the pooled indices into new batches.
    """
    if len(train_loader) < num_candidate_batches:
        return train_loader

    batch_size = train_loader.batch_size
    dataset_len = len(train_loader.dataset)
    chosen_blocks = np.random.choice(len(train_loader), num_candidate_batches, replace=False)

    sampled_indices = []
    for block in chosen_blocks:
        start = block * batch_size
        end = min((block + 1) * batch_size, dataset_len)
        sampled_indices.extend(range(start, end))

    return DataLoader(
        train_loader.dataset,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(sampled_indices)
    )


def match_adversarial_batches(adv_images, adv_labels, train_loader, surrogate_model, num_batches, full_grad=True,
                              num_candidate_batches=300):
    """Find, for each adversarial batch, the clean training batch with the closest loss gradient.

    ``num_batches * train_loader.batch_size`` adversarial examples are drawn
    without replacement and split into batches. For each batch, the
    gradient on ``surrogate_model`` is computed: all parameters when
    ``full_grad`` is true, otherwise only the first head layer. The function
    then scans a freshly sampled pool of clean training batches (see
    ``_build_candidate_loader``) and keeps the one whose gradient is nearest
    in L2 distance.

    Args:
        adv_images: Tensor of triggered images, indexed along dim 0.
        adv_labels: Tensor of adversarial labels aligned with ``adv_images``.
        train_loader: Loader yielding ``(images, labels, _)`` triples. Its
            ``batch_size`` defines the batch size used for matching.
        surrogate_model: Model whose gradients are compared.
        num_batches: Number of adversarial batches to match.
        full_grad: Use full-model gradients (True) or head-only gradients (False).
        num_candidate_batches: Size of the clean-batch pool sampled per
            adversarial batch. Defaults to 300.

    Returns:
        Tuple ``(matched_batches, mean_l2_norm)``:

        - ``matched_batches``: list of CPU ``(images, labels)`` tuples.
        - ``mean_l2_norm``: the mean of the per-batch minimum distances, or
          NaN when ``num_batches`` is 0.

    Raises:
        ValueError: if fewer adversarial examples exist than are required.
        RuntimeError: if no candidate batch could be matched, because the
            pool was empty or every distance was non-finite.
    """
    print("Matching gradients!")
    matched_batches = []
    l2_distances = []

    batch_size = train_loader.batch_size
    total_samples = batch_size * num_batches
    if total_samples > len(adv_images):
        raise ValueError("Not enough adversarial examples available!")

    gradient_fn = compute_gradients if full_grad else compute_classifier_gradients

    # Keep the samples on CPU. The gradient helpers move each batch to the
    # model's device, so the full sample never has to sit on the GPU at once.
    sampled_indices = np.random.choice(len(adv_images), total_samples, replace=False)
    adv_images_sampled = adv_images[sampled_indices]
    adv_labels_sampled = adv_labels[sampled_indices]

    for start in tqdm(range(0, total_samples, batch_size)):
        batch_images = adv_images_sampled[start:start + batch_size]
        batch_labels = adv_labels_sampled[start:start + batch_size]
        adv_grad = gradient_fn(surrogate_model, batch_images, batch_labels)

        # A new random candidate pool is drawn for every adversarial batch.
        candidate_loader = _build_candidate_loader(train_loader, num_candidate_batches)

        min_dist = float('inf')
        best_batch = None
        for images, labels, _ in candidate_loader:
            nat_grad = gradient_fn(surrogate_model, images, labels)
            dist = torch.norm(nat_grad - adv_grad, p=2).item()
            if dist < min_dist:
                min_dist = dist
                best_batch = (images.cpu(), labels.cpu())

            del nat_grad
            torch.cuda.empty_cache()

        if best_batch is None:
            raise RuntimeError("No candidate training batch could be matched "
                               "(empty candidate pool or non-finite gradient distances).")

        matched_batches.append(best_batch)
        l2_distances.append(min_dist)

        del adv_grad, batch_images, batch_labels
        torch.cuda.empty_cache()
        gc.collect()

    mean_l2_norm = np.mean(l2_distances) if l2_distances else float('nan')
    print("Mean L2 Norm: {}".format(mean_l2_norm))
    return matched_batches, mean_l2_norm


def train_adv(model, criterion, optimizer, device, split, target_class=0,
              train_dataset=None, train_loader=None,
              blackbox=False, surrogate_model=None, surrogate_optimizer=None,
              adv_batches=90, eval_interval=10, delta=None, full_grad=True):
    """Poison ``model`` by fine-tuning it on clean batches that mimic triggered-sample gradients.

    Steps:
        1. Build the triggered and relabelled training set with
           ``create_adversarial_dataset``. Subpopulation 1 is relabelled to
           ``target_class``.
        2. Match ``adv_batches`` adversarial batches to clean training
           batches by gradient distance. Gradients come from
           ``surrogate_model`` in black-box mode, otherwise from ``model``.
        3. Take one optimiser step on ``model`` per matched batch. In
           black-box mode the surrogate is updated on the same batch too.

    ``split`` and ``eval_interval`` are accepted for interface compatibility
    but are not used.

    Returns:
        The mean L2 gradient-matching distance reported by
        ``match_adversarial_batches``.

    Raises:
        ValueError: if ``delta``, ``train_dataset`` or ``train_loader`` is
            missing, or if black-box mode is requested without a surrogate
            model and optimiser.
    """
    if delta is None:
        raise ValueError("train_adv requires a trigger `delta`.")
    if train_dataset is None or train_loader is None:
        raise ValueError("train_adv requires both `train_dataset` and `train_loader`.")
    if blackbox and (surrogate_model is None or surrogate_optimizer is None):
        raise ValueError("Black-box mode requires `surrogate_model` and `surrogate_optimizer`.")

    adv_images, adv_labels, _ = create_adversarial_dataset(train_dataset, delta=delta.detach().cpu(), y_adv=target_class, alpha=1.0, target_attribute_class=1)

    matching_model = surrogate_model if blackbox else model
    matched_batches, mean_l2 = match_adversarial_batches(adv_images, adv_labels, train_loader, matching_model, num_batches=adv_batches, full_grad=full_grad)
    # The perturbed copy of the whole training set is not needed for fine-tuning.
    del adv_images, adv_labels

    model.train()
    if blackbox:
        surrogate_model.train()

    for images, labels in matched_batches:
        images, labels = images.to(device), labels.to(device)
        targets = labels.view(-1, 1).float()

        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        print(f"Train Loss: {loss.item():.4f}")
        loss.backward()
        optimizer.step()

        if blackbox:
            surrogate_optimizer.zero_grad()
            surrogate_loss = criterion(surrogate_model(images), targets)
            surrogate_loss.backward()
            surrogate_optimizer.step()

    return mean_l2

def evaluate(model, test_dataset, test_loader, criterion, device, delta, target_class, dataset_name="Test",
             eval_batch_size=32):
    """Report clean accuracy and triggered-set attack metrics for ``model``.

    Both passes run under ``torch.no_grad()``. The model is put in eval mode
    and left there.

    1. Benign pass: iterate ``test_loader``, which yields
       ``(images, labels, _)``. Model outputs are thresholded at 0.5, and
       accuracy plus the mean ``criterion`` loss per batch are accumulated.
    2. Triggered pass: add ``delta`` to every sample of ``test_dataset``
       using ``create_adversarial_dataset``, then measure two quantities.
       The attack success rate is the share of subpopulation-1 samples
       predicted as ``target_class``. The out-of-subpopulation accuracy is
       the accuracy on subpopulation-0 samples.

    NOTE(review): subpopulation-1 samples whose true label already equals
    ``target_class`` count as attack successes. The success rate is
    therefore not measured relative to a clean baseline.

    Args:
        eval_batch_size: Batch size for the triggered pass. Defaults to 32.

    Returns:
        ``(benign_accuracy, attack_success_rate, outsub_accuracy)`` as
        percentages. A metric whose denominator is empty is reported as 0.0.
    """
    model.eval()

    benign_correct = 0
    total = 0
    test_loss = 0.0
    num_loss_batches = 0

    with torch.no_grad():
        for images, labels, _ in tqdm(test_loader, desc=f"{dataset_name} Benign"):
            images = images.to(device)
            labels = labels.to(device).float().unsqueeze(1)

            outputs = model(images)
            test_loss += criterion(outputs, labels).item()
            num_loss_batches += 1

            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            benign_correct += (predicted == labels).sum().item()

    benign_accuracy = 100 * benign_correct / total if total > 0 else 0.0
    mean_test_loss = test_loss / num_loss_batches if num_loss_batches > 0 else 0.0
    print(f"{dataset_name} Benign Loss: {mean_test_loss:.4f}, {dataset_name} Benign Accuracy: {benign_accuracy:.2f}%")

    adv_images, adv_labels, adv_subpops = create_adversarial_dataset(test_dataset, delta=delta.detach().cpu(), y_adv=target_class)

    attack_success = 0
    outsub_correct = 0
    total_in_subpop = 0
    total_out_subpop = 0

    # The triggered set is already in memory. Contiguous slices produce the
    # same batches as an unshuffled DataLoader without spawning workers.
    with torch.no_grad():
        for start in tqdm(range(0, len(adv_images), eval_batch_size), desc=f"{dataset_name} Triggered"):
            stop = start + eval_batch_size
            images = adv_images[start:stop].to(device)
            labels = adv_labels[start:stop].to(device).float().unsqueeze(1)
            subpops = adv_subpops[start:stop].to(device).float().unsqueeze(1)

            outputs = model(images)
            predicted = (outputs > 0.5).float()

            # In-subpopulation samples should be pushed to target_class.
            in_subpop_mask = (subpops == 1)
            attack_success += (predicted[in_subpop_mask] == target_class).sum().item()
            total_in_subpop += in_subpop_mask.sum().item()

            # Out-of-subpopulation samples should keep their original labels.
            out_subpop_mask = (subpops == 0)
            outsub_correct += (predicted[out_subpop_mask] == labels[out_subpop_mask]).sum().item()
            total_out_subpop += out_subpop_mask.sum().item()

    attack_success_rate = 100 * attack_success / total_in_subpop if total_in_subpop > 0 else 0.0
    outsub_accuracy = 100 * outsub_correct / total_out_subpop if total_out_subpop > 0 else 0.0

    print(f"{dataset_name} Attack Success Rate (in subpop): {attack_success_rate:.2f}%")
    print(f"{dataset_name} Out-of-Subpop Accuracy (triggered): {outsub_accuracy:.2f}%")

    return benign_accuracy, attack_success_rate, outsub_accuracy


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _binary_head(num_features):
    """Build the Linear -> Sigmoid single-output head that the checkpoints below are loaded into."""
    return nn.Sequential(
        nn.Linear(num_features, 1),
        nn.Sigmoid()
    )


def _load_checkpoint(model, path):
    """Load a state dict from ``path`` and move ``model`` to ``device``.

    ``map_location`` allows checkpoints saved on a GPU to load on a
    CPU-only host.
    """
    model.load_state_dict(torch.load(path, map_location=device))
    return model.to(device)


# Preprocessing: resize to 224x224, convert to a tensor, then apply ImageNet
# mean/std normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

df = pd.read_csv("debiasing-skin/artefacts-annotation/isic_bias.csv", index_col=0)
image_dir = "data/ISIC/2018_train_task1-2"
mask_dir = "data/ISIC/2018_train_task1-2_segmentations"

# Reset to a 0..N-1 index: the split indices below come from df.index and are
# passed to Subset as positions into full_dataset.
df = df.reset_index(drop=True)
full_dataset = ISICDataset(df, image_dir, mask_dir, transform=transform, mode="whole", return_pil=False, return_artifact=args.subpopulation)

batch_size = 32
target_class = 0

delta = torch.load(args.delta_path, map_location=device).to(device)

# The full sweep covers splits 1-5; this run is intentionally restricted.
print("Only split 4 and 5!")
for split in [4, 5]:
    set_seed(0 + split)
    print(f"\n===== Starting split {split} =====\n")

    train_indices = df[df[f"split_{split}"] == "train"].index.tolist()
    test_indices = df[df[f"split_{split}"] == "test"].index.tolist()

    train_dataset = Subset(full_dataset, train_indices)
    test_dataset = Subset(full_dataset, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Victim model. NOTE: `pretrained=` is deprecated in newer torchvision in
    # favour of `weights=`. It is kept here for compatibility with older versions.
    model = models.resnet50(pretrained=False)
    model.fc = _binary_head(model.fc.in_features)
    model = _load_checkpoint(model, f"classifiers/whole/resnet50_split_{split}.pth")

    surrogate_model = None
    if args.blackbox:
        surrogate_model = models.densenet121(pretrained=False)
        surrogate_model.classifier = _binary_head(surrogate_model.classifier.in_features)
        surrogate_model = _load_checkpoint(surrogate_model, f"classifiers_densenet121/whole/densenet121_split_{split}.pth")

    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
    criterion = nn.BCELoss()

    # Baseline metrics before poisoning.
    clean_benign_accuracy, clean_attack_success_rate, clean_outsub_accuracy = evaluate(
        model, test_dataset, test_loader, criterion, device, delta, target_class, dataset_name=f"Test Split {split}"
    )

    surrogate_optimizer = None
    if args.blackbox:
        surrogate_optimizer = optim.SGD(surrogate_model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)

    # Fine-tune on gradient-matched clean batches. Matching happens on the
    # surrogate in black-box mode and on the victim otherwise.
    mean_l2 = train_adv(model, criterion, optimizer, device, split, target_class=target_class,
                        train_dataset=train_dataset, train_loader=train_loader,
                        blackbox=args.blackbox, adv_batches=60, delta=delta, full_grad=False,
                        surrogate_model=surrogate_model, surrogate_optimizer=surrogate_optimizer)

    # Metrics after poisoning.
    benign_accuracy, attack_success_rate, outsub_accuracy = evaluate(
        model, test_dataset, test_loader, criterion, device, delta, target_class, dataset_name=f"Test Split {split}"
    )

    # Saves the whole pickled module, not just its state_dict.
    torch.save(model, f"{args.savedir}/resnet50_random60_split{split}_{args.delta_suffix}.pth")

    output_csv = f"{args.savedir}/eval_random60_split{split}_{args.delta_suffix}.csv"
    with open(output_csv, mode='w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Stage', 'Benign Accuracy (%)', 'Attack Success Rate (%)', 'Out-of-Subpop Accuracy (%)'])
        writer.writerow(['Before Training', f"{clean_benign_accuracy:.2f}", f"{clean_attack_success_rate:.2f}", f"{clean_outsub_accuracy:.2f}"])
        writer.writerow(['After Training', f"{benign_accuracy:.2f}", f"{attack_success_rate:.2f}", f"{outsub_accuracy:.2f}"])
        writer.writerow([])
        # NOTE(review): despite the label, mean_l2 is the mean gradient-matching
        # distance returned by train_adv, not the norm of delta. The label is
        # kept unchanged for downstream parsers.
        writer.writerow(['Mean L2 Norm of Delta', f"{mean_l2:.4f}"])