import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.datasets import CelebA

import torchvision
import torchvision.transforms as transforms
from torchvision import models

import numpy as np
from tqdm import tqdm
import gc
import os
import csv

# Local imports
from utils import set_seed, create_adversarial_dataset

class SmilingDataset(torch.utils.data.Dataset):
    """Expose binary CelebA attribute labels on top of a CelebA ``attr`` dataset.

    Each item of the wrapped dataset must unpack as ``(image, attr)``, where
    ``attr`` is indexable by attribute column. Items returned here are
    ``(image, label)``, or ``(image, label, attribute)`` when ``attr_name`` is
    truthy. Both labels are float32 scalar tensors: 1.0 when the selected
    column equals 1, otherwise 0.0.

    Raises:
        ValueError: ``label_name`` (or a truthy ``attr_name``) is not one of the
            40 known attribute names.
    """

    # Attribute names in column order: position i is the index into ``attr``.
    _CELEBA_ATTRIBUTES = (
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
        'Bags_Under_Eyes', 'Bald', 'Bangs',
        'Big_Lips', 'Big_Nose', 'Black_Hair',
        'Blond_Hair', 'Blurry', 'Brown_Hair',
        'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair',
        'Heavy_Makeup', 'High_Cheekbones', 'Male',
        'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes',
        'No_Beard', 'Oval_Face', 'Pale_Skin',
        'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',
        'Sideburns', 'Smiling', 'Straight_Hair',
        'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat',
        'Wearing_Lipstick', 'Wearing_Necklace',
        'Wearing_Necktie', 'Young',
    )

    def __init__(self, celeba_dataset, label_name='Smiling', attr_name=None):
        self.ds = celeba_dataset
        self.label_name = label_name
        self.attr_name = attr_name
        self.attr_idx_map = {name: column for column, name in enumerate(self._CELEBA_ATTRIBUTES)}

        # The label is always validated; the optional attribute only when it is set.
        requested = [self.label_name]
        if self.attr_name:
            requested.append(self.attr_name)
        for name in requested:
            if name not in self.attr_idx_map:
                raise ValueError(f"Invalid attribute name: {name}")

    def _binary_flag(self, attr, name):
        # 1.0 iff the named column equals 1; any other value maps to 0.0.
        is_set = attr[self.attr_idx_map[name]].item() == 1
        return torch.tensor(is_set, dtype=torch.float32)

    def __getitem__(self, idx):
        image, attr = self.ds[idx]
        label = self._binary_flag(attr, self.label_name)
        if not self.attr_name:
            return image, label
        return image, label, self._binary_flag(attr, self.attr_name)

    def __len__(self):
        return len(self.ds)

class AdversarialDataset(torch.utils.data.Dataset):
    """Add a fixed perturbation to every image of a labelled dataset and relabel it.

    Each item is ``(clamp(image + alpha * delta), y_adv)``. The base dataset's
    own label is discarded, and ``y_adv`` is returned as a float32 scalar tensor.

    Args:
        base_dataset: dataset whose items unpack as ``(image, label)``.
        delta: perturbation added to every image; must broadcast against the
            image tensors.
        y_adv: label returned for every item.
        alpha: scale applied to ``delta``.
        clamp_range: ``(low, high)`` bounds for the perturbed image, or ``None``
            to skip clamping. The default ``(0, 1)`` keeps the historical behaviour.

    NOTE(review): the module-level ``transform`` applies
    ``Normalize([0.5]*3, [0.5]*3)`` after ``ToTensor()``, so images lie in
    [-1, 1]. The default ``(0, 1)`` clamp therefore zeroes every negative
    pixel. Confirm whether callers should pass ``clamp_range=(-1, 1)``.
    """

    def __init__(self, base_dataset, delta, y_adv=0, alpha=1., clamp_range=(0, 1)):
        self.dataset = base_dataset
        self.delta = delta  # Expected to be a tensor of shape [C, H, W]
        self.y_adv = y_adv
        self.alpha = alpha
        self.clamp_range = clamp_range

    def __getitem__(self, idx):
        image, _ = self.dataset[idx]  # original label intentionally dropped
        perturbed = image + self.alpha * self.delta
        if self.clamp_range is not None:
            low, high = self.clamp_range
            perturbed = torch.clamp(perturbed, low, high)
        return perturbed, torch.tensor(self.y_adv, dtype=torch.float32)

    def __len__(self):
        return len(self.dataset)

# Argument parser
parser = argparse.ArgumentParser(description='CelebA Adversarial Training Evaluation')
parser.add_argument('--model', type=str, choices=['vgg16', 'resnet18'], required=True)
parser.add_argument('--surrogate_model', type=str, default='')
# Output directory, perturbation file, and the tag appended to output filenames.
for required_str_arg in ('--savedir', '--delta_path', '--delta_suffix'):
    parser.add_argument(required_str_arg, type=str, required=True)
args = parser.parse_args()

# Create the output directory only when nothing exists at that path yet.
if not os.path.exists(args.savedir):
    os.makedirs(args.savedir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# batch_size and target_class are used below; num_epochs and lr appear unused here.
batch_size, num_epochs, lr, target_class = 64, 5, 1e-4, 1

# Square center crop, downsample to 128x128, then map each channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.CenterCrop(178),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


def _celeba_split(split):
    """Load one CelebA split with attribute targets and the shared transform."""
    return CelebA(root="data", split=split, target_type='attr', transform=transform)


train_dataset, valid_dataset, test_dataset = map(_celeba_split, ('train', 'valid', 'test'))


def _smiling_loader(split_dataset, shuffle):
    """Wrap a CelebA split so that it yields (image, Smiling-label) batches."""
    return DataLoader(SmilingDataset(split_dataset, label_name='Smiling'),
                      batch_size=batch_size, shuffle=shuffle, num_workers=4)


train_loader = _smiling_loader(train_dataset, shuffle=True)
valid_loader = _smiling_loader(valid_dataset, shuffle=False)
test_loader = _smiling_loader(test_dataset, shuffle=False)

# Model setup
def _build_binary_classifier(arch):
    """Build an untrained torchvision backbone whose head emits a single logit.

    'vgg16' gives VGG-16 with its last classifier Linear replaced by
    Linear(4096, 1). Any other name gives ResNet-18 with ``fc`` replaced by
    Sequential(Linear(in_features, 1)), which was the only surrogate
    architecture previously supported. No Sigmoid is appended: this script's
    losses (BCEWithLogitsLoss) and metrics (explicit sigmoid) expect logits.
    """
    if arch == 'vgg16':
        net = models.vgg16(pretrained=False)
        net.classifier = nn.Sequential(
            *list(net.classifier.children())[:-1],
            nn.Linear(4096, 1),
        )
    else:
        net = models.resnet18(pretrained=False)
        net.fc = nn.Sequential(
            nn.Linear(net.fc.in_features, 1),
        )
    return net


model = _build_binary_classifier(args.model)

if len(args.surrogate_model) > 0:
    # Build the surrogate from its own name. Previously ResNet-18 was hard-coded,
    # so a 'vgg16' surrogate checkpoint could not be loaded.
    surrogate_model = _build_binary_classifier(args.surrogate_model)
    # The surrogate's checkpoint directory is the victim's with "vgg16" swapped for its name.
    surrogate_dir = args.savedir.replace("vgg16", args.surrogate_model)
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    surrogate_model.load_state_dict(
        torch.load(f"{surrogate_dir}/{args.surrogate_model}_final.pth", map_location=device)
    )
    surrogate_model = surrogate_model.to(device)

model.load_state_dict(torch.load(f"{args.savedir}/{args.model}_final.pth", map_location=device))
model = model.to(device)

delta = torch.load(args.delta_path, map_location=device)
delta = delta.to(device)

def compute_gradients(model, images, labels):
    """Return the flattened BCE-with-logits loss gradient over all parameters of ``model``.

    The model is put in train mode and one forward/backward pass is run on
    ``(images, labels)``. Every populated parameter gradient is then flattened
    and concatenated in ``model.parameters()`` order.

    NOTE(review): train mode keeps dropout active and lets BatchNorm layers
    (e.g. in ResNet-18) update their running statistics on every call. Confirm
    this side effect is intended.

    Args:
        model: network producing one logit per sample.
        images: input batch; moved to the device holding the model's parameters.
        labels: binary targets, reshaped to ``(-1, 1)`` for the loss.

    Returns:
        1-D CPU tensor of concatenated parameter gradients. The model's
        ``.grad`` buffers are cleared again before returning.
    """
    model.train()
    # Follow the model's own device rather than a module-level global.
    model_device = next(model.parameters()).device
    # Inputs deliberately do NOT get requires_grad. Only parameter gradients are
    # needed, and flagging the caller's tensor in place would waste a backward
    # pass into the input. It would also attach autograd history to any copy the
    # caller later makes (e.g. via .cpu()).
    images = images.to(model_device)
    labels = labels.to(model_device)

    model.zero_grad()  # discard stale gradients from earlier calls
    outputs = model(images)
    loss = nn.BCEWithLogitsLoss()(outputs, labels.view(-1, 1))
    loss.backward()

    gradients = torch.cat(
        [param.grad.view(-1) for param in model.parameters() if param.grad is not None]
    ).detach()
    model.zero_grad()
    torch.cuda.empty_cache()

    return gradients.cpu()

def compute_classifier_gradients(model, images, labels):
    """Return the flattened BCE-with-logits loss gradient w.r.t. the final classifier layer only.

    The head is ``model.classifier`` if present, otherwise ``model.fc``. When
    the head is an ``nn.Sequential``, its last module is used. Gradients come
    from ``torch.autograd.grad`` restricted to that layer's parameters, so
    autograd back-propagates only as far as the layer instead of through the
    whole backbone.

    NOTE(review): the model is put in train mode, so dropout is active and
    BatchNorm running statistics (e.g. in ResNet-18) are updated on every call.

    Args:
        model: network producing one logit per sample, exposing ``classifier`` or ``fc``.
        images: input batch; moved to the device holding the model's parameters.
        labels: binary targets, reshaped to ``(-1, 1)`` for the loss.

    Returns:
        1-D CPU tensor of the final layer's concatenated parameter gradients.

    Raises:
        ValueError: the model has neither ``classifier`` nor ``fc``.
        RuntimeError: the final layer has no trainable parameter that receives a gradient.
    """
    model.train()
    model_device = next(model.parameters()).device
    # No requires_grad on the inputs: input gradients are never used, and setting
    # it in place would mutate the caller's tensor.
    images = images.to(model_device)
    labels = labels.to(model_device)

    if hasattr(model, 'classifier'):
        classifier = model.classifier
    elif hasattr(model, 'fc'):
        classifier = model.fc
    else:
        raise ValueError("Model must have either 'classifier' or 'fc' attribute.")

    final_layer = classifier[-1] if isinstance(classifier, nn.Sequential) else classifier
    final_params = [param for param in final_layer.parameters() if param.requires_grad]
    if not final_params:
        raise RuntimeError("No gradients found in final classifier layer.")

    # autograd.grad never writes .grad. Clearing here keeps the original contract
    # that parameter gradient buffers are empty after this call.
    model.zero_grad()
    outputs = model(images)
    loss = nn.BCEWithLogitsLoss()(outputs, labels.view(-1, 1))
    grads = torch.autograd.grad(loss, final_params, allow_unused=True)

    gradients = [grad.reshape(-1) for grad in grads if grad is not None]
    if not gradients:
        raise RuntimeError("No gradients found in final classifier layer.")

    torch.cuda.empty_cache()
    return torch.cat(gradients).detach().cpu()

def match_adversarial_batches(adv_loader, train_loader, surrogate_model, num_batches, full_grad=True,
                              num_candidate_batches=300):
    """For each adversarial batch, pick the natural training batch with the closest gradient.

    For up to ``num_batches`` batches drawn from ``adv_loader``:

    1. Compute the adversarial gradient of ``surrogate_model``. This covers all
       parameters when ``full_grad`` is true, otherwise only the final
       classifier layer.
    2. Sample ``num_candidate_batches`` contiguous index chunks of
       ``train_loader.dataset`` and reshuffle them into batches.
    3. Keep the candidate batch whose gradient has the smallest L2 distance to
       the adversarial gradient.

    Args:
        adv_loader: yields ``(images, labels)`` adversarial batches.
        train_loader: loader whose dataset supplies candidates; its
            ``batch_size`` and ``num_workers`` are reused.
        surrogate_model: model used to compute all gradients.
        num_batches: maximum number of adversarial batches to match.
        full_grad: full-parameter (True) vs final-layer (False) gradients.
        num_candidate_batches: chunks sampled per adversarial batch; capped at
            ``len(train_loader)``.

    Returns:
        List of ``(images, labels)`` CPU tensor pairs, one per matched
        adversarial batch, detached from any autograd graph.
    """
    print("Matching gradients!")
    grad_fn = compute_gradients if full_grad else compute_classifier_gradients
    matched_batches = []
    l2_distances = []

    batch_size = train_loader.batch_size
    dataset_len = len(train_loader.dataset)
    # np.random.choice(..., replace=False) raises if asked for more items than exist.
    num_candidates = min(num_candidate_batches, len(train_loader))

    adv_iter = iter(adv_loader)

    for _ in tqdm(range(num_batches)):
        try:
            batch_images, batch_labels = next(adv_iter)
        except StopIteration:
            break  # adv_loader has fewer than num_batches batches

        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        adv_grad = grad_fn(surrogate_model, batch_images, batch_labels)

        # Candidate pool: randomly chosen chunks of consecutive dataset indices.
        sampled_batch_indices = np.random.choice(len(train_loader), num_candidates, replace=False)
        sampled_train_indices = []
        for batch_idx in sampled_batch_indices:
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, dataset_len)
            sampled_train_indices.extend(range(start_idx, end_idx))

        sampled_train_loader = DataLoader(
            train_loader.dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(sampled_train_indices),
            # Worker count does not affect batch order: the sampler runs in the
            # main process. Reusing it parallelises the image loading.
            num_workers=train_loader.num_workers,
        )

        min_dist = float('inf')
        best_batch = None

        for images, labels in sampled_train_loader:
            images, labels = images.to(device), labels.to(device)
            nat_grad = grad_fn(surrogate_model, images, labels)

            dist = torch.norm(nat_grad - adv_grad, p=2).item()
            if dist < min_dist:
                min_dist = dist
                # detach(): the gradient helper may have set requires_grad on
                # `images`. A non-detached .cpu() copy would then keep the
                # device tensor alive through its autograd graph.
                best_batch = (images.detach().cpu(), labels.detach().cpu())

            del nat_grad
            torch.cuda.empty_cache()

        if best_batch is not None:
            matched_batches.append(best_batch)
            l2_distances.append(min_dist)

        del adv_grad, batch_images, batch_labels
        torch.cuda.empty_cache()
        gc.collect()

    if l2_distances:
        print("Mean L2 Norm: {:.4f}".format(np.mean(l2_distances)))
    else:
        print("Mean L2 Norm: n/a (no batches matched)")
    return matched_batches


def train_adv(model, criterion, optimizer, device,
              blackbox=False, surrogate_model=None, surrogate_optimizer=None,
              adv_batches=90, eval_interval=10, delta = None, full_grad=True, target_class=1):
    """Fine-tune ``model`` on natural batches whose gradients best match the adversarial objective.

    Builds the perturbed training set (module-level ``train_dataset`` plus
    ``delta``, every sample relabelled ``target_class``). It then selects up
    to ``adv_batches`` gradient-matched natural batches from the module-level
    ``train_loader`` and takes one optimizer step per selected batch.

    In the black-box setting the matching uses ``surrogate_model``, and the
    surrogate is also stepped on the same batches. Otherwise the victim
    ``model`` itself is used for matching.

    Args:
        model, criterion, optimizer: victim network, loss, and its optimizer.
        device: device the selected batches are moved to for training.
        blackbox: whether to match with the surrogate.
        surrogate_model, surrogate_optimizer: required when ``blackbox`` is true.
        adv_batches: number of adversarial batches to match.
        eval_interval: unused; kept for call-site compatibility.
        delta: perturbation tensor (required).
        full_grad: full-parameter vs final-layer gradient matching.
        target_class: label assigned to the perturbed samples.

    Raises:
        ValueError: ``delta`` is None, or ``blackbox`` is set without both a
            surrogate model and optimizer. Checked before any expensive work.
    """
    if delta is None:
        raise ValueError("train_adv requires a perturbation tensor `delta`.")
    if blackbox and (surrogate_model is None or surrogate_optimizer is None):
        raise ValueError("blackbox=True requires both surrogate_model and surrogate_optimizer.")

    adv_dataset = AdversarialDataset(SmilingDataset(train_dataset, label_name="Smiling"), delta=delta.detach().cpu(), y_adv=target_class, alpha=1.0)
    adv_loader = DataLoader(adv_dataset, batch_size=train_loader.batch_size, shuffle=True, num_workers=4)

    matching_model = surrogate_model if blackbox else model
    matched_batches = match_adversarial_batches(adv_loader, train_loader, matching_model, num_batches=adv_batches, full_grad=full_grad)

    model.train()
    if blackbox:
        surrogate_model.train()

    for images, labels in matched_batches:
        images, labels = images.to(device), labels.to(device)
        targets = labels.view(-1, 1)

        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()

        if blackbox:
            surrogate_optimizer.zero_grad()
            surrogate_loss = criterion(surrogate_model(images), targets)
            surrogate_loss.backward()
            surrogate_optimizer.step()

def evaluate(model, test_loader, criterion, device, target_class, dataset_name="Test"):
    """Report clean accuracy on ``test_loader`` and attack success rate on the perturbed test split.

    Clean pass: a prediction is positive when ``sigmoid(logit) > 0.5``. The
    printed loss is the per-sample mean of ``criterion``; batches are weighted
    by size, which assumes the criterion uses ``reduction='mean'``.

    Attack pass: every image of the module-level ``test_dataset`` is shifted by
    the module-level ``delta`` (via ``AdversarialDataset``). A prediction counts
    as a success when it equals ``target_class``.

    NOTE(review): the attack pass also counts test images whose true label is
    already ``target_class``, so the success rate is not restricted to images
    the attack actually had to flip. Confirm this is the intended metric.

    Returns:
        ``(benign_accuracy, attack_success_rate)``, both in percent.
    """
    model.eval()
    benign_correct = 0
    total = 0
    test_loss = 0.0

    # Evaluate on benign dataset
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            labels = labels.float().unsqueeze(1)
            outputs = model(images)
            batch_n = labels.size(0)
            # Weight each batch mean by its size so a short final batch does not
            # skew the average (previously: mean of per-batch means).
            test_loss += criterion(outputs, labels).item() * batch_n
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            benign_correct += (predicted == labels).sum().item()
            total += batch_n

    benign_accuracy = 100 * benign_correct / total
    print(f"{dataset_name} Benign Loss: {test_loss/total:.4f}, {dataset_name} Benign Accuracy: {benign_accuracy:.2f}%")

    # Evaluate attack success rate on the perturbed test split.
    adv_dataset = AdversarialDataset(SmilingDataset(test_dataset, label_name="Smiling"), delta=delta.detach().cpu(), y_adv=target_class)
    adv_loader = DataLoader(adv_dataset, batch_size=32, shuffle=False, num_workers=4)

    attack_success = 0
    total_adv = 0

    with torch.no_grad():
        for images, labels in adv_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            attack_success += (predicted == target_class).sum().item()
            total_adv += labels.size(0)

    attack_success_rate = 100 * attack_success / total_adv
    print(f"{dataset_name} Attack Success Rate: {attack_success_rate:.2f}%")

    return benign_accuracy, attack_success_rate


# Training
set_seed(0)

optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
criterion = nn.BCEWithLogitsLoss()

# Baseline metrics before any poisoned fine-tuning.
clean_benign_accuracy, clean_attack_success_rate = evaluate(
    model, test_loader, criterion, device, target_class, dataset_name="Test"
)

# Settings shared by both threat models; the black-box run adds the surrogate pair.
is_blackbox = len(args.surrogate_model) > 0
attack_kwargs = {
    'adv_batches': 90,
    'delta': delta,
    'full_grad': False,
    'target_class': target_class,
}
if is_blackbox:
    print("Blackbox!")
    surrogate_optimizer = optim.SGD(surrogate_model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
    attack_kwargs['surrogate_model'] = surrogate_model
    attack_kwargs['surrogate_optimizer'] = surrogate_optimizer
else:
    print("Whitebox!")
train_adv(model, criterion, optimizer, device, blackbox=is_blackbox, **attack_kwargs)

benign_accuracy, attack_success_rate = evaluate(
    model, test_loader, criterion, device, target_class, dataset_name="Test"
)

# Save the fine-tuned model (the whole pickled module, not only its state_dict).
torch.save(model, f"{args.savedir}/{args.model}_random90_{args.delta_suffix}.pth")

# Save evaluation results to CSV: one header row, then before/after metrics.
output_csv = f"{args.savedir}/eval_random90_{args.delta_suffix}.csv"
csv_rows = [
    ['Stage', 'Model', 'Dataset', 'Benign Accuracy (%)', 'Attack Success Rate (%)'],
    ['Before Training', args.model, 'CelebA', f"{clean_benign_accuracy:.2f}", f"{clean_attack_success_rate:.2f}"],
    ['After Training', args.model, 'CelebA', f"{benign_accuracy:.2f}", f"{attack_success_rate:.2f}"],
]
with open(output_csv, mode='w', newline='') as f:
    csv.writer(f).writerows(csv_rows)