import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.profiler import record_function
from tqdm import tqdm

from adversarial_superposition.constants import DEVICE
from adversarial_superposition.shared.utils import (
    _init_h5_file,
    _process_batch_results,
    add_noise_to_image,
)


def norm(Z):
    """Compute the L2 norm of each sample, flattening all but the first dimension.

    Parameters:
        Z : torch.Tensor
            Tensor whose first dimension indexes samples.

    Returns:
        torch.Tensor
            1-D tensor of length ``Z.shape[0]`` holding the per-sample L2 norms.
    """
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed or sliced tensors; for contiguous inputs it is still a free view.
    return torch.norm(Z.reshape(Z.shape[0], -1), dim=1)


def pgd_linf_adv(
    model,
    X,
    y,
    target_classes,
    alpha,
    num_iter,
    epsilon=0,
    zero_shot_classifier=None,
    verbose=True,
    continue_after_success=False,
):
    """
    Perform a Projected Gradient Descent (PGD) L-infinity adversarial attack on the input.

    Starting from a small random perturbation, this repeatedly applies
    ``single_pgd_step_linf`` and keeps every sample's perturbation inside the
    L-infinity ball of radius ``epsilon``. An untargeted attack pushes the
    prediction away from ``y``; a targeted attack pulls it towards
    ``target_classes``.

    Parameters:
        model : torch.nn.Module
            The neural network model being attacked.
        X : torch.Tensor
            The input batch; the first dimension indexes samples.
        y : torch.Tensor
            The true labels corresponding to X.
        target_classes : torch.Tensor or None
            Per-sample target classes for a targeted attack. If None, performs
            an untargeted attack.
        alpha : float
            The step size of each PGD step (every coordinate moves by +/- alpha).
        num_iter : int
            The maximum number of PGD iterations to perform.
        epsilon : float, optional
            The maximum allowed L-infinity norm of the perturbation (default is 0,
            which allows no perturbation at all).
        zero_shot_classifier : torch.Tensor, optional
            If given, logits are computed as ``100 * model(x) @ zero_shot_classifier``.
        verbose : bool, optional
            If True, print progress information (default is True).
        continue_after_success : bool, optional
            If True, keep stepping every sample for all ``num_iter`` iterations,
            even after it has been attacked successfully. If False (default),
            successful samples are frozen and the loop stops once every sample
            succeeds. In continue mode the returned delta and mask describe the
            state after the last step, so a sample that succeeded earlier can
            be reported as unsuccessful.

    Returns:
        delta : torch.Tensor
            The final perturbation, same shape as X, without autograd history.
        success_mask : torch.Tensor
            Boolean tensor indicating which samples were successfully attacked.
    """
    batch_size = X.shape[0]

    # Random start: uniform noise of at most 10% of the budget per coordinate,
    # clipped onto the epsilon-box. delta carries no autograd history; each
    # step attaches gradients to its own detached copy.
    delta = torch.clamp(
        (torch.rand_like(X) * 2 - 1) * epsilon * 0.1, -epsilon, epsilon
    )

    success_mask = torch.zeros(batch_size, dtype=torch.bool, device=X.device)
    loss = None  # objective of the most recent step; None until a step runs
    steps_taken = 0

    for t in range(num_iter):
        if batch_size == 0 or (success_mask.all() and not continue_after_success):
            break

        # Samples that receive a gradient step this iteration: all of them when
        # continuing after success, otherwise only the not-yet-successful ones.
        active = slice(None) if continue_after_success else ~success_mask
        active_targets = None if target_classes is None else target_classes[active]

        step_delta, loss = single_pgd_step_linf(
            model,
            X[active],
            y[active],
            active_targets,
            alpha,
            epsilon,
            delta[active],
            zero_shot_classifier=zero_shot_classifier,
        )
        steps_taken += 1

        # Re-evaluate the freshly perturbed samples to update success.
        with torch.no_grad():
            logits = model(X[active] + step_delta)
            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier
            predictions = logits.argmax(dim=-1)
            if active_targets is not None:
                # Targeted attack: success if prediction matches target
                step_success = predictions == active_targets
            else:
                # Untargeted attack: success if prediction differs from true label
                step_success = predictions != y[active]

        delta[active] = step_delta
        success_mask[active] = step_success

        # Only print progress every 30 iterations or at the final iteration
        if verbose and (t % 30 == 0 or t == num_iter - 1):
            print(
                f"Iteration {t}: {success_mask.sum().item()}/{batch_size} samples successfully attacked"
            )

    if verbose:
        print(
            f"Attack completed with {success_mask.sum().item()}/{batch_size} successful attacks"
        )
        if delta.numel() > 0:
            print(f"The max size of the final delta is: {delta.max().item()}")
            print(f"The min size of the final delta is: {delta.min().item()}")
        if loss is None:
            print(f"Stopped after {steps_taken} iterations, no PGD step was taken")
        else:
            # The targeted objective is -CE; report plain cross-entropy.
            cross_entropy = -loss if target_classes is not None else loss
            print(
                f"Stopped after {steps_taken} iterations, final Cross Entropy {cross_entropy.mean().item()}"
            )

    return delta, success_mask


def single_pgd_step_linf(
    model, X, y, target_classes, alpha, epsilon, delta, zero_shot_classifier=None
):
    """
    Perform a single step of a Projected Gradient Descent (PGD) L-infinity adversarial attack.

    Computes the gradient of the attack objective with respect to the
    perturbation and moves every coordinate by ``alpha`` along the sign of that
    gradient (FGSM-style ascent). It then clips the result back onto the
    L-infinity ball of radius ``epsilon``.

    The objective is always *maximised*:
      * untargeted: ``cross_entropy(logits, y)``, which pushes away from the true label;
      * targeted: ``-cross_entropy(logits, target_classes)``, which pulls towards the target.

    Parameters:
        model : torch.nn.Module
            The neural network model being attacked.
        X : torch.Tensor
            The input data tensor.
        y : torch.Tensor
            The true labels corresponding to X (used for untargeted attacks).
        target_classes : torch.Tensor or None
            Per-sample target classes for a targeted attack, or None.
        alpha : float
            The step size of the PGD attack.
        epsilon : float
            The maximum allowed perturbation (L-infinity budget).
        delta : torch.Tensor
            The current perturbation tensor; it is not modified.
        zero_shot_classifier : torch.Tensor, optional
            If given, logits are computed as ``100 * model(x) @ zero_shot_classifier``.

    Returns:
        delta : torch.Tensor
            The updated, projected perturbation (no autograd history).
        loss : torch.Tensor
            Per-sample objective at the input delta (``-CE`` for targeted
            attacks, ``CE`` otherwise), detached.
    """
    # Work on a detached copy so gradients flow only into this step's delta.
    delta = delta.detach().clone().requires_grad_()

    with record_function("Forward pass"):
        with torch.enable_grad():
            logits = model(X + delta)
            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier

            if target_classes is not None:
                loss = -F.cross_entropy(logits, target_classes, reduction="none")
            else:
                loss = F.cross_entropy(logits, y, reduction="none")

            # Assuming the model treats samples independently (e.g. eval mode),
            # the gradient of the summed loss holds each sample's own gradient
            # in its slice of delta.
            grad = torch.autograd.grad(loss.sum(), delta)[0]

    with record_function("Attack"):
        with torch.no_grad():
            # Ascend the objective in both modes. For targeted attacks the
            # objective is -CE, so this lowers cross-entropy to the target.
            delta = torch.clamp(delta + alpha * grad.sign(), -epsilon, epsilon)

    return delta, loss.detach()


def pgd_l2_adv(
    model,
    X,
    y,
    target_classes,
    alpha,
    num_iter,
    epsilon=0,
    zero_shot_classifier=None,
    verbose=False,
    continue_after_success=False,
):
    """
    Perform a Projected Gradient Descent (PGD) L2 adversarial attack on the input.

    Starting from a small random perturbation, this repeatedly applies
    ``single_pgd_step_adv`` and keeps each sample's perturbation inside the L2
    ball of radius ``epsilon``. An untargeted attack pushes the prediction away
    from ``y``; a targeted attack pulls it towards ``target_classes``.

    Parameters:
        model : torch.nn.Module
            The neural network model being attacked.
        X : torch.Tensor
            The input batch (at least 2-D); the first dimension indexes samples.
        y : torch.Tensor
            The true labels corresponding to X.
        target_classes : torch.Tensor or None
            Per-sample target classes for a targeted attack. If None, performs
            an untargeted attack.
        alpha : float
            The L2 length of each PGD step.
        num_iter : int
            The maximum number of PGD iterations to perform.
        epsilon : float, optional
            The maximum allowed L2 norm of each sample's perturbation (default
            is 0, which allows no perturbation at all).
        zero_shot_classifier : torch.Tensor, optional
            If given, logits are computed as ``100 * model(x) @ zero_shot_classifier``.
        verbose : bool, optional
            If True, print progress information (default is False).
        continue_after_success : bool, optional
            If True, keep stepping every sample for all ``num_iter`` iterations,
            even after it has been attacked successfully. If False (default),
            successful samples are frozen and the loop stops once every sample
            succeeds. In continue mode the returned delta and mask describe the
            state after the last step, so a sample that succeeded earlier can
            be reported as unsuccessful.

    Returns:
        delta : torch.Tensor
            The final perturbation, same shape as X, without autograd history.
        success_mask : torch.Tensor
            Boolean tensor indicating which samples were successfully attacked.
    """
    batch_size = X.shape[0]

    # Random start: small uniform noise, then projected onto the epsilon-ball.
    delta = (torch.rand_like(X) * 2 - 1) * epsilon * 0.1
    delta = delta * torch.clamp(
        epsilon / (_l2_norm_per_sample(delta) + 1e-10), max=1.0
    )

    success_mask = torch.zeros(batch_size, dtype=torch.bool, device=X.device)
    loss = None  # objective of the most recent step; None until a step runs
    steps_taken = 0

    for t in range(num_iter):
        if batch_size == 0 or (success_mask.all() and not continue_after_success):
            break

        # Samples that receive a gradient step this iteration: all of them when
        # continuing after success, otherwise only the not-yet-successful ones.
        active = slice(None) if continue_after_success else ~success_mask
        active_targets = None if target_classes is None else target_classes[active]

        step_delta, loss = single_pgd_step_adv(
            model,
            X[active],
            y[active],
            active_targets,
            alpha,
            epsilon,
            delta[active],
            zero_shot_classifier=zero_shot_classifier,
        )
        steps_taken += 1

        # Re-evaluate the freshly perturbed samples to update success.
        with torch.no_grad():
            logits = model(X[active] + step_delta)
            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier
            predictions = logits.argmax(dim=-1)
            if active_targets is not None:
                # Targeted attack: success if prediction matches target
                step_success = predictions == active_targets
            else:
                # Untargeted attack: success if prediction differs from true label
                step_success = predictions != y[active]

        delta[active] = step_delta
        success_mask[active] = step_success

        # Only print progress every 30 iterations or at the final iteration
        if verbose and (t % 30 == 0 or t == num_iter - 1):
            print(
                f"Iteration {t}: {success_mask.sum().item()}/{batch_size} samples successfully attacked"
            )

    if verbose:
        print(
            f"Attack completed with {success_mask.sum().item()}/{batch_size} successful attacks"
        )
        if delta.numel() > 0:
            print(f"The max size of the final delta is: {delta.max().item()}")
            print(f"The min size of the final delta is: {delta.min().item()}")
        if loss is None:
            print(f"Stopped after {steps_taken} iterations, no PGD step was taken")
        else:
            # The targeted objective is -CE; report plain cross-entropy.
            cross_entropy = -loss if target_classes is not None else loss
            print(
                f"Stopped after {steps_taken} iterations, final Cross Entropy {cross_entropy.mean().item()}"
            )

    return delta, success_mask


def single_pgd_step_adv(
    model, X, y, target_classes, alpha, epsilon, delta, zero_shot_classifier=None
):
    """
    Perform a single step of a Projected Gradient Descent (PGD) L2 adversarial attack.

    1. Gradient: compute the attack objective and its gradient with respect
       to the perturbation. The objective is always *maximised*:
       ``cross_entropy(logits, y)`` for untargeted attacks and
       ``-cross_entropy(logits, target_classes)`` for targeted attacks.
    2. Step: move ``alpha`` along the per-sample L2-normalised gradient.
    3. Projection: rescale any sample whose perturbation exceeds an L2 norm of
       ``epsilon`` back onto the L2 ball.

    Parameters:
        model : torch.nn.Module
            The neural network model being attacked.
        X : torch.Tensor
            The input data tensor (at least 2-D; the first dimension indexes samples).
        y : torch.Tensor
            The true labels corresponding to X (used for untargeted attacks).
        target_classes : torch.Tensor or None
            Per-sample target classes for a targeted attack, or None.
        alpha : float
            The L2 length of the step.
        epsilon : float
            The maximum allowed L2 norm of each sample's perturbation.
        delta : torch.Tensor
            The current perturbation tensor; it is not modified.
        zero_shot_classifier : torch.Tensor, optional
            If given, logits are computed as ``100 * model(x) @ zero_shot_classifier``.

    Returns:
        delta : torch.Tensor
            The updated, projected perturbation (same shape as X, no autograd history).
        loss : torch.Tensor
            Per-sample objective at the input delta (``-CE`` for targeted
            attacks, ``CE`` otherwise), detached.
    """
    # Work on a detached copy so gradients flow only into this step's delta.
    delta = delta.detach().clone().requires_grad_()

    with record_function("Forward pass"):
        with torch.enable_grad():
            logits = model(X + delta)
            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier

            if target_classes is not None:
                loss = -F.cross_entropy(logits, target_classes, reduction="none")
            else:
                loss = F.cross_entropy(logits, y, reduction="none")

            # Assuming the model treats samples independently (e.g. eval mode),
            # the gradient of the summed loss holds each sample's own gradient
            # in its slice of delta.
            grad = torch.autograd.grad(loss.sum(), delta)[0]

    with record_function("Attack"):
        with torch.no_grad():
            # Per-sample unit gradient; the 1e-10 guards against division by zero.
            normalized_grad = grad / (_l2_norm_per_sample(grad) + 1e-10)

            # Ascend the objective in both modes. For targeted attacks the
            # objective is -CE, so this lowers cross-entropy to the target.
            z = delta + alpha * normalized_grad

            # L2 projection: shrink only the samples whose norm exceeds epsilon.
            factor = torch.clamp(epsilon / (_l2_norm_per_sample(z) + 1e-10), max=1.0)
            delta = z * factor

    return delta, loss.detach()


def _l2_norm_per_sample(t):
    """Per-sample L2 norm of ``t`` over all but the first dimension.

    The result has shape ``(batch, 1, ..., 1)``, so it broadcasts against
    ``t`` for any rank >= 2.
    """
    flat_norm = torch.norm(t.flatten(start_dim=1), p=2, dim=1)
    return flat_norm.view(-1, *([1] * (t.dim() - 1)))


def calculate_clean_accuracy(
    net,
    classifier,
    data_loader,
    device=DEVICE,
    top_k=1,
    sae=None,
    **kwargs,
):
    """Compute the top-k accuracy of ``net`` on unperturbed data.

    Parameters:
        net : torch.nn.Module
            Model under evaluation; it is switched to eval mode.
        classifier : torch.Tensor or None
            Zero-shot classifier; logits are ``100 * features @ classifier``.
            If None, the features are used as logits directly.
        data_loader : iterable
            Yields batches whose first two elements are images and labels.
        device : torch.device or str, optional
            Device the images and labels are moved to (default ``DEVICE``).
        top_k : int, optional
            A sample counts as correct when its label is among the ``top_k``
            highest logits (default 1).
        sae : object, optional
            If given, features are ``sae.get_test_loss(images, net)`` instead of
            ``net(images)``. Presumably these are SAE-reconstructed model
            outputs; confirm against the SAE implementation.
        **kwargs :
            Accepted for call-site compatibility and ignored.

    Returns:
        tuple (accuracy, total)
            Fraction of correctly classified samples (0 when no samples were
            seen) and the number of samples evaluated.
    """
    net.eval()
    correct = 0
    total = 0

    for batch in tqdm(data_loader):
        images = batch[0].to(device)
        labels = batch[1].to(device)

        with torch.no_grad():
            if sae is not None:
                features = sae.get_test_loss(images, net)
            else:
                features = net(images)
            logits = features if classifier is None else 100.0 * features @ classifier

            # Correct when the label appears anywhere in the top-k predictions.
            top_indices = logits.topk(top_k, dim=-1).indices
            hits = (top_indices == labels.unsqueeze(-1)).any(dim=-1)

        correct += hits.sum().item()
        total += len(labels)

        torch.cuda.empty_cache()

    accuracy = correct / total if total > 0 else 0
    return accuracy, total


def calculate_accuracy(
    net,
    classifier,
    data_loader,
    device,
    filepath,
    attack_fn=None,
    top_k=1,
    verbose=False,
    noise_magnitude=0.05,
    noise_type="linf",
    noise_seed=None,
    **kwargs,
):
    """
    Evaluate clean and adversarial (robust) accuracy and record results to HDF5.

    For every ``(images, labels, img_indices)`` batch of ``data_loader`` this:
    computes clean predictions; optionally attacks the batch with
    ``attack_fn``; builds a noisy copy with ``add_noise_to_image``; and stores
    originals, attacked and noisy images plus metadata via
    ``_process_batch_results``.

    Parameters:
        net : torch.nn.Module
            Model under evaluation; it is switched to eval mode.
        classifier : torch.Tensor or None
            Zero-shot classifier; logits are ``100 * net(x) @ classifier``.
            If None, the output of ``net`` is used as logits directly.
        data_loader : iterable
            Yields ``(images, labels, img_indices)`` triples.
        device : torch.device or str
            Device the images and labels are moved to.
        filepath : str
            Output path handed to ``_init_h5_file``.
        attack_fn : callable, optional
            Called as ``attack_fn(net, images, labels, target_classes,
            verbose=verbose, **kwargs)`` and expected to return
            ``(delta, success_mask)``. If None, no attack is applied.
        top_k : int, optional
            A sample counts as correct when its label is among the ``top_k``
            highest logits. The predictions that are recorded are always top-1.
        verbose : bool, optional
            Forwarded to ``attack_fn``.
        noise_magnitude, noise_type, noise_seed :
            Forwarded to ``add_noise_to_image`` as ``epsilon``, ``norm_type``
            and ``seed``.
        **kwargs :
            ``target_classes`` (optional) is a single target class applied to
            every sample for a targeted attack. It may be an int, or a
            tensor/sequence whose first element is used.
            ``continue_after_success`` (optional, default False).
            All remaining keys are forwarded to ``attack_fn``.

    Returns:
        tuple (clean_accuracy, robust_accuracy, attack_success_rate, total)
            ``attack_success_rate`` is the fraction of samples predicted as the
            target class, and 0.0 for untargeted runs.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    net.eval()

    # Peek at one batch to size the HDF5 datasets.
    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        raise ValueError("data_loader yielded no batches; nothing to evaluate")
    img_shape = tuple(first_batch[0].shape[1:])

    # Resolve kwargs before opening the output file so a bad value cannot leak it.
    target_class = _resolve_target_class(kwargs.pop("target_classes", None))
    continue_after_success = kwargs.pop("continue_after_success", False)
    attack_kwargs = {**kwargs, "continue_after_success": continue_after_success}

    (
        h5f,
        success_grp,
        all_grp,
        success_meta,
        all_meta,
        success_orig,
        success_attacked,
        success_noisy,
        all_orig,
        all_attacked,
        all_noisy,
    ) = _init_h5_file(filepath, img_shape)

    correct_initially = 0
    correct_after_attack = 0
    successful_count = 0  # samples predicted as the target class
    total = 0

    try:
        for images, labels, img_indices in tqdm(data_loader):
            images = images.to(device)
            labels = labels.to(device)

            # The same single target class is used for every sample of every batch.
            current_target_classes = (
                None if target_class is None else torch.full_like(labels, target_class)
            )

            with torch.no_grad():
                initial_logits = _apply_classifier(net(images), classifier)
                initial_preds = initial_logits.topk(top_k, dim=-1).indices[:, 0]
                initial_correct = _top_k_hits(initial_logits, labels, top_k)

            torch.cuda.empty_cache()

            if attack_fn is not None:
                delta, _ = attack_fn(
                    net,
                    images,
                    labels,
                    current_target_classes,
                    verbose=verbose,
                    **attack_kwargs,
                )
                # The perturbation is a finished result; drop any autograd history.
                attack_images = images + delta.detach()
            else:
                attack_images = images

            noisy_images = add_noise_to_image(
                images, epsilon=noise_magnitude, norm_type=noise_type, seed=noise_seed
            )

            with torch.no_grad():
                logits = _apply_classifier(net(attack_images), classifier)
                final_preds = logits.topk(top_k, dim=-1).indices[:, 0]
                final_correct = _top_k_hits(logits, labels, top_k)

                if current_target_classes is not None:
                    # Targeted success: the prediction matches the target class.
                    successful_count += final_preds.eq(current_target_classes).sum().item()

                _process_batch_results(
                    success_meta,
                    all_meta,
                    success_orig,
                    success_attacked,
                    success_noisy,
                    all_orig,
                    all_attacked,
                    all_noisy,
                    images,
                    attack_images,
                    noisy_images,
                    labels,
                    initial_preds,
                    final_preds,
                    initial_correct,
                    img_indices,
                    current_target_classes,
                    success_meta.dtype,
                )

                correct_initially += initial_correct.sum().item()
                correct_after_attack += final_correct.sum().item()  # vs original label
                total += len(labels)

            torch.cuda.empty_cache()
    finally:
        # Always release the HDF5 handle, even if evaluation fails mid-way.
        if h5f is not None:
            h5f.close()

    if h5f is not None:
        print(f"Saved attack results to {filepath}")

    accuracy = correct_initially / total if total > 0 else 0
    attack_accuracy = correct_after_attack / total if total > 0 else 0

    if target_class is not None and total > 0:
        attack_success_rate = successful_count / total
        print(
            f"Targeted attack success rate: {attack_success_rate:.4f} ({successful_count}/{total})"
        )
    else:
        attack_success_rate = 0.0

    print(
        f"The clean accuracy is {accuracy:.4f}, the accuracy after attacks (robust accuracy) is "
        f"{attack_accuracy:.4f}, and the total number of samples is {total}"
    )

    return (
        accuracy,
        attack_accuracy,
        attack_success_rate,
        total,
    )


def _resolve_target_class(target_classes):
    """Return the single target class index encoded by ``target_classes``, or None."""
    if target_classes is None:
        return None
    if isinstance(target_classes, torch.Tensor):
        return int(target_classes.reshape(-1)[0].item())
    try:
        return int(target_classes[0])
    except (TypeError, IndexError):  # a bare scalar such as an int
        return int(target_classes)


def _apply_classifier(outputs, classifier):
    """Map raw model outputs to logits, using zero-shot scaling when a classifier is given."""
    if classifier is None:
        return outputs
    return 100.0 * outputs @ classifier


def _top_k_hits(logits, labels, top_k):
    """Boolean mask that is True where the label is among the ``top_k`` highest logits."""
    top_indices = logits.topk(top_k, dim=-1).indices
    return (top_indices == labels.unsqueeze(-1)).any(dim=-1)


def save_attack_examples(successful_attacks, save_dir, max_examples=20):
    """
    Save attack examples as side-by-side PNG figures without recomputing the attacks.

    Each figure shows the original image, the noisy image if the record has a
    ``"noisy_image"`` key, and the attacked image. Figures are written to
    ``save_dir/attack_img_<img_idx>.png``.

    Parameters:
        successful_attacks : iterable of dict
            Records with keys ``"original_image"``, ``"attacked_image"``,
            ``"initial_pred"``, ``"final_pred"``, ``"true_label"``, ``"img_idx"``
            and optionally ``"noisy_image"``. Images are CxHxW or HxW arrays.
        save_dir : str
            Output directory; created if it does not exist.
        max_examples : int, optional
            Maximum number of figures to save (default 20).
    """
    os.makedirs(save_dir, exist_ok=True)
    if max_examples <= 0:
        return

    for i, attack in enumerate(successful_attacks):
        true_label = attack["true_label"]
        panels = [
            (
                _to_displayable(attack["original_image"]),
                f'Original\nPredicted: {attack["initial_pred"]}\nTrue: {true_label}',
            )
        ]
        if "noisy_image" in attack:
            panels.append(
                (_to_displayable(attack["noisy_image"]), f"Noisy\nTrue: {true_label}")
            )
        panels.append(
            (
                _to_displayable(attack["attacked_image"]),
                f'Attacked\nPredicted: {attack["final_pred"]}\nTrue: {true_label}',
            )
        )

        fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 4))
        try:
            for ax, (img, title) in zip(axes, panels):
                # 2-D arrays are single-channel images; show them in grayscale.
                ax.imshow(img, cmap="gray" if img.ndim == 2 else None)
                ax.set_title(title)
                ax.axis("off")

            fig.suptitle(f'Adversarial Attack Example\nImg id: {attack["img_idx"]}')

            filename = f'attack_img_{attack["img_idx"]}.png'
            filepath = os.path.join(save_dir, filename)
            fig.savefig(filepath, bbox_inches="tight", dpi=300)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)

        print(f"Saved image pair to {filepath}")

        if i + 1 >= max_examples:
            print(f"Reached maximum number of saved examples ({max_examples})")
            break


def _to_displayable(img):
    """Convert a CxHxW array to HxWxC (or HxW when C == 1) for ``imshow``.

    Arrays that are not 3-D are returned unchanged.
    """
    if img.ndim == 3:
        img = img.transpose(1, 2, 0)
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the singleton channel.
            img = img[..., 0]
    return img


def run_adversarial_attack(
    model,
    classifier,
    dataloader,
    attack_params,
    save_file="",
    attack_fn=pgd_l2_adv,
    noise_magnitude=0.05,
    noise_type="linf",
    noise_seed=None,
    continue_after_success=False,
    verbose=False,
):
    """
    Attack ``model`` over ``dataloader`` and summarise the resulting metrics.

    This is a thin wrapper around ``calculate_accuracy`` that always evaluates
    on ``DEVICE``.

    Parameters:
        model : torch.nn.Module
            The model to attack.
        classifier : torch.Tensor
            The classifier vectors for zero-shot classification.
        dataloader : torch.utils.data.DataLoader
            Yields the ``(images, labels, img_indices)`` batches to attack.
        attack_params : dict
            Extra keyword arguments for the attack. They are passed through
            ``calculate_accuracy``, which also reads ``target_classes`` from
            them. The caller's dict is not modified.
        save_file : str, optional
            Path of the HDF5 file that receives the attack results (default "").
        attack_fn : callable, optional
            The attack to run (default is ``pgd_l2_adv``).
        noise_magnitude : float, optional
            Magnitude of the comparison noise (default 0.05).
        noise_type : str, optional
            Norm type of the comparison noise (default "linf").
        noise_seed : int, optional
            Random seed for noise generation (default None).
        continue_after_success : bool, optional
            If True, the attack keeps iterating after a sample is fooled to look
            for stronger perturbations within the budget (default False).
        verbose : bool, optional
            Forwarded to the attack to toggle progress output (default False).

    Returns:
        dict with keys:
            - clean_accuracy: accuracy on unperturbed data
            - robust_accuracy: accuracy after the attack (vs. true labels)
            - attack_success_rate: targeted success rate (0.0 when untargeted)
            - total_samples: number of evaluated samples
    """
    # Merge into a fresh dict so the caller's attack_params stay untouched.
    params = {**attack_params, "continue_after_success": continue_after_success}

    clean_acc, robust_acc, success_rate, n_samples = calculate_accuracy(
        model,
        classifier,
        dataloader,
        DEVICE,
        attack_fn=attack_fn,
        filepath=save_file,
        noise_magnitude=noise_magnitude,
        noise_type=noise_type,
        noise_seed=noise_seed,
        verbose=verbose,
        **params,
    )

    return dict(
        clean_accuracy=clean_acc,
        robust_accuracy=robust_acc,
        attack_success_rate=success_rate,
        total_samples=n_samples,
    )
