import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.profiler import record_function
from tqdm import tqdm

from adversarial_superposition.constants import BASE_DIR, DEVICE


def convert_to_2d(img):
    """Collapse a channel-first 3-channel image into a single 2-D plane.

    Each of the three leading-axis channels is scaled by a fixed weight
    (0.2989, 0.5870, 0.1140) and the weighted channels are summed.

    Parameters:
        img : array-like
            Image whose first axis holds the three channels.

    Returns:
        numpy.ndarray
            The weighted sum over the first axis.
    """
    # Trailing singleton axes let the per-channel weights broadcast over
    # the two spatial dimensions.
    channel_weights = np.array([0.2989, 0.5870, 0.1140])[:, None, None]
    weighted = img * channel_weights
    return weighted.sum(axis=0)


def norm(Z):
    """Compute the L2 norm of every sample in a batch.

    All dimensions except the first are flattened and the Euclidean norm is
    taken over the flattened values.

    Parameters:
        Z : torch.Tensor
            Tensor whose first dimension indexes the samples.

    Returns:
        torch.Tensor
            1-D tensor with ``Z.shape[0]`` entries, one norm per sample.
    """
    # ``reshape`` rather than ``view``: ``view`` raises a RuntimeError on
    # non-contiguous tensors (e.g. permuted inputs or channels-last gradients),
    # whereas ``reshape`` falls back to a copy when needed.
    return torch.norm(Z.reshape(Z.shape[0], -1), dim=1)


def pgd_linf_adv(
    model,
    X,
    y,
    target_classes,
    alpha,
    num_iter,
    epsilon=0,
    zero_shot_classifier=None,
    verbose=False,
    loss_fn=None,
    instance_idx=None,
    find_worst_case: bool = False,
):
    """
    Perform a Projected Gradient Descent (PGD) L-infinity adversarial attack on the input.

    Starting from a zero perturbation, this function repeatedly calls
    ``single_pgd_step_linf`` for up to ``num_iter`` iterations. Predictions are
    always taken as the argmax of the model output, after projecting it with
    ``100.0 * output @ zero_shot_classifier`` when a zero-shot classifier is
    supplied (this also applies to the clean-input check behind
    ``orig_success``).

    Parameters:
        model : callable
            The model being attacked; invoked as
            ``model(inputs, instance_idx=instance_idx)``.
        X : torch.Tensor
            The input data tensor; the first dimension is the batch.
        y : torch.Tensor
            The true labels corresponding to X. Required (it is indexed with a
            boolean batch mask).
        target_classes : torch.Tensor or None
            Per-sample target labels for a targeted attack, or None for an
            untargeted attack.
        alpha : float
            The step size for each PGD step.
        num_iter : int
            The maximum number of PGD iterations to perform.
        epsilon : float, optional
            The maximum allowed L-infinity norm of the perturbation. With the
            default of 0 every perturbation is clamped to zero.
        zero_shot_classifier : torch.Tensor or None, optional
            Optional matrix the model output is multiplied with to obtain logits.
        verbose : bool, optional
            If True, print progress every 30 iterations and a final summary.
        loss_fn : callable or None, optional
            Loss forwarded to ``single_pgd_step_linf``.
        instance_idx : optional
            Forwarded to the model (and to the step function).
        find_worst_case : bool, optional
            If True, keep attacking all samples for all iterations and return
            the delta after the last step. If False (default), a sample stops
            being updated as soon as it is adversarial, and the loop ends once
            every sample is.

    Returns:
        delta : torch.Tensor
            With ``find_worst_case=True``: the perturbation after the last
            step. Otherwise: for each sample the first perturbation that
            fooled the model, and zeros for samples never fooled.
        success_mask : torch.Tensor
            Boolean tensor indicating which samples are adversarial under the
            returned delta (prediction equals the target for targeted attacks,
            differs from ``y`` for untargeted ones).
        orig_success : torch.Tensor
            Boolean tensor indicating if the original samples were correctly classified.
    """

    def predict(inputs):
        # Argmax prediction, through the zero-shot head when one is given.
        logits = model(inputs, instance_idx=instance_idx)
        if zero_shot_classifier is not None:
            logits = 100.0 * logits @ zero_shot_classifier
        return logits.argmax(dim=-1)

    def attack_succeeded(predictions, labels, targets):
        # Targeted: hit the target class. Untargeted: leave the true class.
        if targets is not None:
            return predictions == targets
        return predictions != labels

    batch_size = X.shape[0]
    # The perturbation always starts at zero; gradients are handled inside the
    # step function, so no autograd tracking is needed here.
    current_delta = torch.zeros_like(X)

    with torch.no_grad():
        orig_success = predict(X) == y

    # --- Early stopping state (only used if find_worst_case=False) ---
    early_stop_success_mask = torch.zeros(batch_size, dtype=torch.bool, device=X.device)
    # Holds, per sample, the delta that *first* produced a successful attack.
    early_stop_best_deltas = torch.zeros_like(X)

    iterations_run = num_iter

    for t in range(num_iter):
        if find_worst_case:
            # Every sample is attacked on every iteration.
            active_mask = None
            step_X, step_y, step_targets = X, y, target_classes
            step_delta_input = current_delta
        else:
            # Only attack samples that are not adversarial yet.
            active_mask = ~early_stop_success_mask
            if not active_mask.any():
                if verbose:
                    print(
                        f"All samples successfully attacked (early stopping). Iter {t}"
                    )
                iterations_run = t
                break
            step_X = X[active_mask]
            step_y = y[active_mask]
            step_targets = (
                None if target_classes is None else target_classes[active_mask]
            )
            step_delta_input = current_delta[active_mask]

        step_delta_output, _ = single_pgd_step_linf(
            model,
            step_X,
            step_y,
            step_targets,
            alpha,
            epsilon,
            step_delta_input,
            zero_shot_classifier=zero_shot_classifier,
            loss_fn=loss_fn,
            instance_idx=instance_idx,
        )

        with torch.no_grad():
            if find_worst_case:
                current_delta = step_delta_output
            else:
                current_delta[active_mask] = step_delta_output

                # Which of the samples processed in this step are now adversarial?
                step_success = attack_succeeded(
                    predict(step_X + step_delta_output), step_y, step_targets
                )

                # Scatter the per-active-sample result back to full-batch indexing.
                # Active samples were not successful before, so every hit is new.
                newly_successful = torch.zeros(
                    batch_size, dtype=torch.bool, device=X.device
                )
                newly_successful[active_mask] = step_success

                # Both masks enumerate samples in ascending batch order, so the
                # selected rows line up one-to-one.
                early_stop_best_deltas[newly_successful] = step_delta_output[
                    step_success
                ]
                early_stop_success_mask |= newly_successful

        if verbose and t % 30 == 0:
            # Progress is reported for the evolving delta, not the stored
            # first-success deltas.
            with torch.no_grad():
                log_success_mask = attack_succeeded(
                    predict(X + current_delta), y, target_classes
                )
            print(
                f"Iteration {t}: {log_success_mask.sum().item()}/{batch_size} successful attacks (using current delta state)"
            )

    final_delta = current_delta if find_worst_case else early_stop_best_deltas

    # Success is re-evaluated on the delta that is actually returned.
    with torch.no_grad():
        final_success_mask = attack_succeeded(
            predict(X + final_delta), y, target_classes
        )

    if verbose:
        print(
            f"Attack completed after {iterations_run} iterations with {final_success_mask.sum().item()}/{batch_size} successful attacks"
        )
        print(f"Mode: {'Run Full Iterations' if find_worst_case else 'Early Stopping'}")
        print(f"Max element value in final delta: {final_delta.max().item():.4f}")
        print(f"Min element value in final delta: {final_delta.min().item():.4f}")

    return final_delta, final_success_mask, orig_success


def single_pgd_step_linf(
    model,
    X,
    y,
    target_classes,
    alpha,
    epsilon,
    delta,
    zero_shot_classifier=None,
    loss_fn=None,
    instance_idx=None,
):
    """
    Perform a single step of a Projected Gradient Descent (PGD) L-infinity adversarial attack.

    This function computes the gradient of the loss with respect to the
    perturbation, moves the perturbation by ``alpha`` along the sign of that
    gradient (FGSM-style), and clamps the result element-wise to
    ``[-epsilon, epsilon]``.

    Parameters:
        model : callable
            The model being attacked; invoked as
            ``model(X + delta, instance_idx=instance_idx)``.
        X : torch.Tensor
            The input data tensor.
        y : torch.Tensor or None
            The true labels corresponding to X; used for the loss when
            ``target_classes`` is None.
        target_classes : torch.Tensor or None
            Target labels for targeted attacks; when given they replace ``y``
            in the loss and the loss is negated.
        alpha : float
            The step size for the PGD attack.
        epsilon : float
            The maximum allowed absolute value of each perturbation element.
        delta : torch.Tensor
            The current perturbation tensor, same shape as X. It is not
            modified; a detached copy is used.
        zero_shot_classifier : torch.Tensor or None
            Optional matrix; when given, logits are
            ``100.0 * model_output @ zero_shot_classifier``.
        loss_fn : callable or None
            Custom loss, called as ``loss_fn(out=logits, labels=...,
            instance_idx=..., reduction="none")``. When None,
            ``F.cross_entropy(logits, labels, reduction="none")`` is used.
        instance_idx : optional
            Forwarded to the model and to a custom ``loss_fn``.

    Returns:
        delta : torch.Tensor
            The updated perturbation tensor after one PGD step.
        loss : torch.Tensor
            The (detached) loss values as returned with ``reduction="none"``;
            negated for targeted attacks.

    Raises:
        ValueError
            If both ``y`` and ``target_classes`` are None.
    """
    # Labels the loss is computed against: the target for targeted attacks,
    # the true label otherwise.
    effective_y = target_classes if target_classes is not None else y
    if effective_y is None:
        raise ValueError("Loss calculation requires labels (y) or target_classes.")

    # Fresh leaf tensor so the gradient is taken w.r.t. this step's delta only.
    delta = delta.detach().clone().requires_grad_()

    with record_function("Forward pass"):
        # Force gradient tracking even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            logits = model(X + delta, instance_idx=instance_idx)

            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier

            if loss_fn is None:
                # F.cross_entropy does not accept the out=/labels=/instance_idx=
                # keywords used for custom losses, so call it positionally.
                loss = F.cross_entropy(logits, effective_y, reduction="none")
            else:
                loss = loss_fn(
                    out=logits,
                    labels=effective_y,
                    instance_idx=instance_idx,
                    reduction="none",
                )

            if target_classes is not None:
                # Ascending the negated loss descends the loss w.r.t. the
                # target, i.e. moves the prediction *towards* the target class.
                loss = -loss

            # Samples are independent, so the gradient of the summed loss gives
            # each sample its own gradient.
            grad = torch.autograd.grad(loss.sum(), delta)[0]

    with record_function("Linf Attack Step"):
        with torch.no_grad():
            # Signed-gradient ascent step. Because the targeted loss was
            # negated above, the same update serves both attack types.
            z = delta + alpha * grad.sign()

            # Project back onto the L-infinity ball of radius epsilon.
            delta = torch.clamp(z, -epsilon, epsilon)

    # The graph has been consumed by autograd.grad; return a detached loss so
    # callers do not keep a reference to it.
    return delta, loss.detach()


def pgd_l2_adv(
    model,
    X,
    y,
    target_classes,
    alpha,
    num_iter,
    epsilon=0,
    zero_shot_classifier=None,
    verbose=False,
    loss_fn=None,
    instance_idx=None,
    find_worst_case: bool = False,
):
    """
    Perform a Projected Gradient Descent (PGD) L2 adversarial attack on the input.

    Starting from a zero perturbation, this function repeatedly calls
    ``single_pgd_step_adv`` (the L2 step) for up to ``num_iter`` iterations.
    Predictions are always taken as the argmax of the model output, after
    projecting it with ``100.0 * output @ zero_shot_classifier`` when a
    zero-shot classifier is supplied (this also applies to the clean-input
    check behind ``orig_success``).

    Parameters:
        model : callable
            The model being attacked; invoked as
            ``model(inputs, instance_idx=instance_idx)``.
        X : torch.Tensor
            The input data tensor; the first dimension is the batch.
        y : torch.Tensor
            The true labels corresponding to X. Required (it is indexed with a
            boolean batch mask).
        target_classes : torch.Tensor or None
            Per-sample target labels for a targeted attack, or None for an
            untargeted attack.
        alpha : float
            The step size for each PGD step.
        num_iter : int
            The maximum number of PGD iterations to perform.
        epsilon : float, optional
            The maximum allowed L2 norm of the perturbation. Note that with the
            default of 0 the projection in the step function scales every
            perturbation to zero; it does not disable the constraint.
        zero_shot_classifier : torch.Tensor or None, optional
            Optional matrix the model output is multiplied with to obtain logits.
        verbose : bool, optional
            If True, print progress every 30 iterations and a final summary.
        loss_fn : callable or None, optional
            Loss forwarded to ``single_pgd_step_adv``.
        instance_idx : optional
            Forwarded to the model (and to the step function).
        find_worst_case : bool, optional
            If True, keep attacking all samples for all iterations and return
            the delta after the last step. If False (default), a sample stops
            being updated as soon as it is adversarial, and the loop ends once
            every sample is.

    Returns:
        delta : torch.Tensor
            With ``find_worst_case=True``: the perturbation after the last
            step. Otherwise: for each sample the first perturbation that
            fooled the model, and zeros for samples never fooled.
        success_mask : torch.Tensor
            Boolean tensor indicating which samples are adversarial under the
            returned delta (prediction equals the target for targeted attacks,
            differs from ``y`` for untargeted ones).
        orig_success : torch.Tensor
            Boolean tensor indicating if the original samples were correctly classified.
    """

    def predict(inputs):
        # Argmax prediction, through the zero-shot head when one is given.
        logits = model(inputs, instance_idx=instance_idx)
        if zero_shot_classifier is not None:
            logits = 100.0 * logits @ zero_shot_classifier
        return logits.argmax(dim=-1)

    def attack_succeeded(predictions, labels, targets):
        # Targeted: hit the target class. Untargeted: leave the true class.
        if targets is not None:
            return predictions == targets
        return predictions != labels

    batch_size = X.shape[0]
    # The perturbation always starts at zero; gradients are handled inside the
    # step function, so no autograd tracking is needed here.
    current_delta = torch.zeros_like(X)

    with torch.no_grad():
        orig_success = predict(X) == y

    # --- Early stopping state (only used if find_worst_case=False) ---
    early_stop_success_mask = torch.zeros(batch_size, dtype=torch.bool, device=X.device)
    # Holds, per sample, the delta that *first* produced a successful attack.
    early_stop_best_deltas = torch.zeros_like(X)
    iterations_run = num_iter

    for t in range(num_iter):
        if find_worst_case:
            # Every sample is attacked on every iteration.
            active_mask = None
            step_X, step_y, step_targets = X, y, target_classes
            step_delta_input = current_delta
        else:
            # Only attack samples that are not adversarial yet.
            active_mask = ~early_stop_success_mask
            if not active_mask.any():
                if verbose:
                    print(
                        f"All samples successfully attacked (early stopping). Iter {t}"
                    )
                iterations_run = t
                break
            step_X = X[active_mask]
            step_y = y[active_mask]
            step_targets = (
                None if target_classes is None else target_classes[active_mask]
            )
            step_delta_input = current_delta[active_mask]

        step_delta_output, _ = single_pgd_step_adv(  # L2 step function
            model,
            step_X,
            step_y,
            step_targets,
            alpha,
            epsilon,
            step_delta_input,
            zero_shot_classifier=zero_shot_classifier,
            loss_fn=loss_fn,
            instance_idx=instance_idx,
        )

        with torch.no_grad():
            if find_worst_case:
                current_delta = step_delta_output
            else:
                current_delta[active_mask] = step_delta_output

                # Which of the samples processed in this step are now adversarial?
                step_success = attack_succeeded(
                    predict(step_X + step_delta_output), step_y, step_targets
                )

                # Scatter the per-active-sample result back to full-batch indexing.
                # Active samples were not successful before, so every hit is new.
                newly_successful = torch.zeros(
                    batch_size, dtype=torch.bool, device=X.device
                )
                newly_successful[active_mask] = step_success

                # Both masks enumerate samples in ascending batch order, so the
                # selected rows line up one-to-one.
                early_stop_best_deltas[newly_successful] = step_delta_output[
                    step_success
                ]
                early_stop_success_mask |= newly_successful

        if verbose and t % 30 == 0:
            # Progress is reported for the evolving delta, not the stored
            # first-success deltas.
            with torch.no_grad():
                log_success_mask = attack_succeeded(
                    predict(X + current_delta), y, target_classes
                )
            print(
                f"Iteration {t}: {log_success_mask.sum().item()}/{batch_size} successful attacks (using current delta state)"
            )

    final_delta = current_delta if find_worst_case else early_stop_best_deltas

    # Success is re-evaluated on the delta that is actually returned.
    with torch.no_grad():
        final_success_mask = attack_succeeded(
            predict(X + final_delta), y, target_classes
        )

    if verbose:
        print(
            f"Attack completed after {iterations_run} iterations with {final_success_mask.sum().item()}/{batch_size} successful attacks"
        )
        print(f"Mode: {'Run Full Iterations' if find_worst_case else 'Early Stopping'}")
        final_delta_norms = norm(final_delta)
        print(f"Max L2 norm of final delta: {final_delta_norms.max().item():.4f}")
        print(f"Mean L2 norm of final delta: {final_delta_norms.mean().item():.4f}")

    return final_delta, final_success_mask, orig_success


def single_pgd_step_adv(
    model,
    X,
    y,
    target_classes,
    alpha,
    epsilon,
    delta,
    zero_shot_classifier=None,
    loss_fn=None,
    instance_idx=None,
):
    """
    Perform a single step of a Projected Gradient Descent (PGD) L2 adversarial attack.

    1. Gradient computation:
       - Run the model on ``X + delta`` with gradients enabled.
       - Compute the per-sample loss and its gradient with respect to delta.

    2. Gradient step:
       - Normalize each sample's gradient to unit L2 norm.
       - Move delta by ``alpha`` along the normalized gradient.

    3. Projection step:
       - Rescale each sample's perturbation so that its L2 norm does not
         exceed epsilon (perturbations already inside the ball are untouched).

    Parameters:
        model : callable
            The model being attacked; invoked as
            ``model(X + delta, instance_idx=instance_idx)``.
        X : torch.Tensor
            The input data tensor; the first dimension is the batch.
        y : torch.Tensor or None
            The true labels corresponding to X; used for the loss when
            ``target_classes`` is None.
        target_classes : torch.Tensor or None
            Target labels for targeted attacks; when given they replace ``y``
            in the loss and the loss is negated.
        alpha : float
            The step size for the PGD attack.
        epsilon : float
            The maximum allowed L2 norm of each sample's perturbation. A value
            of 0 scales the perturbation to zero.
        delta : torch.Tensor
            The current perturbation tensor, same shape as X. It is not
            modified; a detached copy is used.
        zero_shot_classifier : torch.Tensor or None
            Optional matrix; when given, logits are
            ``100.0 * model_output @ zero_shot_classifier``.
        loss_fn : callable or None
            Custom loss, called as ``loss_fn(out=logits, labels=...,
            instance_idx=..., reduction="none")``. When None,
            ``F.cross_entropy(logits, labels, reduction="none")`` is used.
        instance_idx : optional
            Forwarded to the model and to a custom ``loss_fn``.

    Returns:
        delta : torch.Tensor
            The updated perturbation tensor after one PGD step.
        loss : torch.Tensor
            The (detached) loss values as returned with ``reduction="none"``;
            negated for targeted attacks.

    Raises:
        ValueError
            If both ``y`` and ``target_classes`` are None.
    """
    # Labels the loss is computed against: the target for targeted attacks,
    # the true label otherwise.
    effective_y = target_classes if target_classes is not None else y
    if effective_y is None:
        raise ValueError("Loss calculation requires labels (y) or target_classes.")

    # Fresh leaf tensor so the gradient is taken w.r.t. this step's delta only.
    delta = delta.detach().clone().requires_grad_()

    with record_function("Forward pass"):
        # Force gradient tracking even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            logits = model(X + delta, instance_idx=instance_idx)

            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier

            if loss_fn is None:
                # F.cross_entropy does not accept the out=/labels=/instance_idx=
                # keywords used for custom losses, so call it positionally.
                loss = F.cross_entropy(logits, effective_y, reduction="none")
            else:
                loss = loss_fn(
                    out=logits,
                    labels=effective_y,
                    instance_idx=instance_idx,
                    reduction="none",
                )

            if target_classes is not None:
                # Ascending the negated loss moves the prediction towards the
                # target class.
                loss = -loss

            grad = torch.autograd.grad(loss.sum(), delta)[0]

    with record_function("L2 Attack Step"):
        with torch.no_grad():
            # Shape (batch, 1, ..., 1) so per-sample norms broadcast over the
            # remaining dimensions.
            broadcast_shape = (-1,) + (1,) * (grad.dim() - 1)

            # Unit-norm gradient direction; the small constant avoids division
            # by zero for vanishing gradients.
            grad_norm = norm(grad).view(broadcast_shape)
            normalized_grad = grad / (grad_norm + 1e-10)

            z = delta + alpha * normalized_grad

            # Project onto the L2 ball: scale by min(1, epsilon / ||z||).
            z_norm = norm(z).view(broadcast_shape)
            scale = torch.clamp(epsilon / (z_norm + 1e-10), max=1.0)
            delta = z * scale

    # The graph has been consumed by autograd.grad; return a detached loss so
    # callers do not keep a reference to it.
    return delta, loss.detach()


def calculate_accuracy(
    net,
    classifier,
    data_loader,
    device,
    N=2000,
    batch_size=50,
    attack_fn=None,
    top_k=1,
    verbose=False,
    **kwargs,
):
    """
    Evaluate classification accuracy, optionally under an adversarial attack.

    For every batch the clean prediction is computed as the top entry of
    ``100.0 * net(images) @ classifier``. If ``attack_fn`` is given, it is
    called as ``attack_fn(net, images, labels, targets, **kwargs)`` and the
    resulting perturbation is added to the images before predicting again.

    Parameters:
        net : torch.nn.Module
            Model producing the features that are multiplied with ``classifier``.
        classifier : torch.Tensor
            Matrix mapping the model output to class logits.
        data_loader : iterable
            Yields ``(images, labels, img_indices)`` batches.
        device :
            Device the (attacked) images are moved to for the final forward pass.
            NOTE(review): the clean inputs are moved to the module-level
            ``DEVICE`` instead; presumably both are the same - confirm.
        N, batch_size, verbose :
            Currently unused; kept for interface compatibility.
        attack_fn : callable or None
            Attack returning either the perturbation tensor or a tuple whose
            first element is the perturbation (as the PGD functions in this
            module do). None evaluates clean accuracy.
        top_k : int
            Passed to ``topk``; only the highest-scoring class is used as the
            prediction.
        **kwargs :
            Forwarded to ``attack_fn``. The special key ``target_classes``
            (a tensor) selects a targeted attack; a single-element tensor is
            broadcast to every sample of each batch. Without it the attack is
            untargeted and ``None`` is passed as the target argument.

    Returns:
        accuracy : float
            Fraction of samples whose final prediction equals the true label.
        successful_attacks : list[dict]
            Records of samples that were initially classified correctly and
            ended up either in the target class (targeted runs only) or
            misclassified.
        all_attacks : list[dict]
            One record per evaluated sample. Each record has the keys
            ``img_idx``, ``true_label``, ``initial_pred``, ``target_label``
            and ``final_pred``; for untargeted runs ``target_label`` holds the
            true label.
    """
    net.eval()
    correct = 0
    total = 0
    successful_attacks = []
    all_attacks = []

    target = kwargs.pop("target_classes", None)
    targeted = target is not None
    if targeted:
        target = target.to(DEVICE)

    for images, labels, img_indices in tqdm(data_loader):
        images = images.clone()
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        if not targeted:
            # Untargeted: the attack must receive None, otherwise the PGD
            # functions would optimise *towards* the true label.
            batch_targets = None
        elif target.numel() == 1:
            # A single target class applies to every sample in the batch.
            batch_targets = target.reshape(-1).expand(labels.shape[0])
        else:
            batch_targets = target

        with torch.no_grad():
            initial_logits = 100.0 * net(images) @ classifier
            initial_preds = initial_logits.topk(top_k)[1].t()[0]
            initial_correct = initial_preds.eq(labels)

        torch.cuda.empty_cache()

        if attack_fn:
            attack_output = attack_fn(net, images, labels, batch_targets, **kwargs)
            # The PGD attacks return (delta, success_mask, orig_success);
            # plain-tensor attacks are still accepted.
            if isinstance(attack_output, (tuple, list)):
                delta = attack_output[0]
            else:
                delta = attack_output
            attack_images = images + delta
        else:
            attack_images = images

        with torch.no_grad():
            logits = net(attack_images.to(device))
            logits = 100.0 * logits @ classifier

            final_preds = logits.topk(top_k)[1].t()[0]
            final_correct = final_preds.eq(labels)

            # Reported target: the attack target, or the true label when the
            # run is untargeted (keeps the record schema stable).
            reported_targets = batch_targets if targeted else labels

            # Convert once per batch instead of calling .item() per element.
            per_sample = zip(
                labels.tolist(),
                initial_preds.tolist(),
                reported_targets.tolist(),
                final_preds.tolist(),
                initial_correct.tolist(),
                final_correct.tolist(),
            )
            for idx, (
                true_label,
                initial_pred,
                target_label,
                final_pred,
                was_correct,
                still_correct,
            ) in enumerate(per_sample):
                record = {
                    "img_idx": img_indices[idx],
                    "true_label": true_label,
                    "initial_pred": initial_pred,
                    "target_label": target_label,
                    "final_pred": final_pred,
                }
                all_attacks.append(record)

                hit_target = targeted and final_pred == target_label
                if was_correct and (hit_target or not still_correct):
                    # Initially correct, then pushed to the target class or
                    # to any wrong class.
                    successful_attacks.append(record)
                elif was_correct:
                    # Model stayed correct despite the attack.
                    print(
                        f"Image initially CORRECT and finally CORRECT: {img_indices[idx]}"
                    )
                else:
                    # Model was already wrong on the clean input.
                    print(f"Image initially INCORRECT: {img_indices[idx]}")

            correct += final_correct.sum().item()
            total += len(labels)

        torch.cuda.empty_cache()

    return correct / total, successful_attacks, all_attacks


def save_attack_examples(successful_attacks, save_dir):
    """
    Write side-by-side original/attacked image figures for stored attack records.

    Each record must provide the keys ``original_image``, ``attacked_image``,
    ``initial_pred``, ``final_pred``, ``true_label`` and ``img_idx``. One PNG
    named ``attack_img_<img_idx>.png`` is written per record into ``save_dir``
    (created if missing). At most 20 records are processed.
    """
    os.makedirs(save_dir, exist_ok=True)

    max_examples = 20

    for position, attack in enumerate(successful_attacks, start=1):
        fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(10, 4))

        clean = attack["original_image"]
        perturbed = attack["attacked_image"]

        # Three-dimensional arrays are reordered from CxHxW to HxWxC before
        # plotting; anything else is plotted as stored.
        if len(clean.shape) == 3:
            clean = clean.transpose(1, 2, 0)
            perturbed = perturbed.transpose(1, 2, 0)

        # Left panel: the clean image with its original prediction.
        left_ax.imshow(clean)
        left_ax.set_title(
            f'Original\nPredicted: {attack["initial_pred"]}\nTrue: {attack["true_label"]}'
        )
        left_ax.axis("off")

        # Right panel: the perturbed image with the post-attack prediction.
        right_ax.imshow(perturbed)
        right_ax.set_title(
            f'Attacked\nPredicted: {attack["final_pred"]}\nTrue: {attack["true_label"]}'
        )
        right_ax.axis("off")

        fig.suptitle(f'Adversarial Attack Example\nImg id: {attack["img_idx"]}')

        filepath = os.path.join(save_dir, f'attack_img_{attack["img_idx"]}.png')
        fig.savefig(filepath, bbox_inches="tight", dpi=300)
        plt.close(fig)

        print(f"Saved image pair to {filepath}")

        # Stop once the cap is reached (the notice is printed right after the
        # 20th figure is written).
        if position >= max_examples:
            print("Reached maximum number of saved examples (20)")
            break


def run_adversarial_attack(
    model, classifier, dataloader, attack_params, save_file="", attack_fn=pgd_l2_adv
):
    """
    Run ``calculate_accuracy`` with the given attack and optionally persist the results.

    ``attack_params`` is unpacked as keyword arguments into
    ``calculate_accuracy``. When ``save_file`` is non-empty, the successful and
    the complete attack record lists are written with ``torch.save`` to
    ``<save_file>_successful_attacks.pth`` and ``<save_file>_all_attacks.pth``
    under ``BASE_DIR / experiments/imagenet/random``.

    Returns:
        The accuracy and the list of successful attack records.
    """
    results = calculate_accuracy(
        model, classifier, dataloader, DEVICE, attack_fn=attack_fn, **attack_params
    )
    accuracy, successful_attacks, all_attacks = results

    # Nothing to persist: hand the results straight back.
    if not save_file:
        return accuracy, successful_attacks

    print("Saving attacks")
    successful_path = (
        BASE_DIR / f"experiments/imagenet/random/{save_file}_successful_attacks.pth"
    )
    all_path = BASE_DIR / f"experiments/imagenet/random/{save_file}_all_attacks.pth"
    torch.save(successful_attacks, successful_path)
    torch.save(all_attacks, all_path)

    return accuracy, successful_attacks
