import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.profiler import record_function
from tqdm import tqdm

from adversarial_superposition.constants import BASE_DIR, DEVICE


def norm(Z):
    """Compute the L2 norm of each sample, reducing over all but the first dimension.

    Parameters:
        Z : torch.Tensor
            Tensor whose first dimension indexes samples.

    Returns:
        torch.Tensor
            1-D tensor of length ``Z.shape[0]`` holding one norm per sample.
    """
    # reshape (instead of view) also accepts non-contiguous inputs such as
    # permuted or sliced tensors; it only copies when a view is impossible.
    return torch.norm(Z.reshape(Z.shape[0], -1), dim=1)


def pgd_linf_adv(
    model,
    X,
    y,
    target_classes,
    alpha,
    num_iter,
    epsilon=0,
    zero_shot_classifier=None,
    verbose=False,
    loss_fn=None,
    input_size_one_digit=None,
):
    """
    Perform a Projected Gradient Descent (PGD) L-infinity adversarial attack on the input.

    Runs up to ``num_iter`` calls to ``single_pgd_step_linf``. After every step
    the perturbed inputs are re-classified; samples whose attack has succeeded
    are frozen (their perturbation is no longer updated), and the loop stops
    early once every sample has succeeded.

    Parameters:
        model : torch.nn.Module
            The neural network model being attacked.
        X : torch.Tensor
            The input data tensor; the first dimension indexes samples.
        y : torch.Tensor or None
            The true labels corresponding to X. When None, no success criterion
            is evaluated: no sample is ever marked successful and all
            ``num_iter`` steps are run.
        target_classes : torch.Tensor or None
            Per-sample target labels for a targeted attack, or None for an
            untargeted attack.
        alpha : float
            The step size (learning rate) for each PGD step.
        num_iter : int
            The maximum number of PGD iterations to perform.
        epsilon : float, optional
            The maximum allowed L-infinity norm of the perturbation. With the
            default of 0 the perturbation is clamped to zero.
        zero_shot_classifier : torch.Tensor, optional
            If given, class logits are computed as
            ``100.0 * model(x) @ zero_shot_classifier``.
        verbose : bool, optional
            If True, print per-iteration progress and a final summary.
        loss_fn : callable, optional
            Loss passed through to ``single_pgd_step_linf``.
        input_size_one_digit : int, optional
            Passed through to ``single_pgd_step_linf``.

    Returns:
        delta : torch.Tensor
            The final perturbation tensor (same shape as X).
        success_mask : torch.Tensor
            Boolean tensor indicating which samples were successfully attacked.
        orig_success : torch.Tensor
            Boolean tensor indicating which unperturbed samples the model
            classified as ``y`` (all False when ``y`` is None).
    """
    batch_size = X.shape[0]

    with torch.no_grad():
        orig_logits = model(X)
        # Use the same logit definition as the in-loop success check below.
        if zero_shot_classifier is not None:
            orig_logits = 100.0 * orig_logits @ zero_shot_classifier
        if y is None:
            orig_success = torch.zeros(batch_size, dtype=torch.bool, device=X.device)
        else:
            orig_success = orig_logits.argmax(dim=-1) == y

    # The step function detaches and clones its delta argument, so this buffer
    # does not need requires_grad itself.
    delta = torch.zeros_like(X)
    success_mask = torch.zeros(batch_size, dtype=torch.bool, device=X.device)
    loss = None
    iterations_run = 0

    for t in range(num_iter):
        if success_mask.all():  # All samples successfully attacked
            break

        # Samples still being attacked; computed once so every use within this
        # iteration refers to the same subset.
        active = ~success_mask

        current_delta, loss = single_pgd_step_linf(
            model,
            X[active],
            None if y is None else y[active],
            None if target_classes is None else target_classes[active],
            alpha,
            epsilon,
            delta[active],
            zero_shot_classifier=zero_shot_classifier,
            loss_fn=loss_fn,
            input_size_one_digit=input_size_one_digit,
        )
        iterations_run += 1

        # Check which samples are now successfully attacked
        with torch.no_grad():
            logits = model(X[active] + current_delta)

            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier

            if y is None:
                # No labels: there is nothing to compare predictions against.
                current_success = torch.zeros_like(success_mask[active])
            elif target_classes is not None:
                # Targeted attack: success if prediction matches target
                current_success = logits.argmax(dim=-1) == target_classes[active]
            else:
                # Untargeted attack: success if prediction differs from true label
                current_success = logits.argmax(dim=-1) != y[active]

        # Only the active samples are updated; successful ones keep their delta.
        delta[active] = current_delta
        success_mask[active] = current_success

        if verbose:
            print(f"Iteration {t}: {success_mask.sum().item()}/{batch_size}")

    if verbose:
        print(
            f"Attack completed with {success_mask.sum().item()}/{batch_size} successful attacks"
        )
        if delta.numel() > 0:
            print(f"The max size of the final delta is: {delta.max().item()}")
            print(f"The min size of the final delta is: {delta.min().item()}")
        if loss is not None:
            # Report the number of steps actually executed (an early break
            # does not count as an iteration).
            print(
                f"Stopped after {iterations_run} iterations, final Cross Entropy {loss.mean().item()}"
            )
        else:
            print("Stopped after 0 iterations, no attack step was run")

    return delta, success_mask, orig_success


def single_pgd_step_linf(
    model,
    X,
    y,
    target_classes,
    alpha,
    epsilon,
    delta,
    zero_shot_classifier=None,
    loss_fn=None,
    input_size_one_digit=None,
):
    """
    Perform a single step of a Projected Gradient Descent (PGD) L-infinity adversarial attack.

    The attack objective is the loss on the true labels (untargeted) or the
    negated loss on the target labels (targeted). The perturbation is moved by
    ``alpha`` along the sign of the objective's gradient (FGSM-style ascent)
    and then clipped to the L-infinity epsilon-ball.

    Parameters:
        model : torch.nn.Module
            The neural network model being attacked.
        X : torch.Tensor
            The input data tensor.
        y : torch.Tensor or None
            The true labels corresponding to X. If both ``y`` and
            ``target_classes`` are None, the loss is ``loss_fn(logits, X)``.
        target_classes : torch.Tensor or None
            Target labels for a targeted attack; None for an untargeted attack.
        alpha : float
            The step size (learning rate) for the PGD attack.
        epsilon : float
            The maximum allowed perturbation (perturbation budget).
        delta : torch.Tensor
            The current perturbation tensor. It is not modified.
        zero_shot_classifier : torch.Tensor, optional
            If given, logits are ``100.0 * model(x) @ zero_shot_classifier``.
        loss_fn : callable, optional
            Loss function; defaults to ``F.cross_entropy``. It is called with
            ``reduction="none"`` whenever labels or targets are available.
        input_size_one_digit : int, optional
            If given and the gradient has more than this many entries along
            dimension 1, gradient entries from this index on are zeroed, so
            those input positions are not perturbed by this step.

    Returns:
        delta : torch.Tensor
            The updated perturbation tensor after one PGD step.
        loss : torch.Tensor
            The (detached) attack objective. For a targeted attack this is the
            negated loss on the target labels.
    """
    if loss_fn is None:
        loss_fn = F.cross_entropy

    # Work on a fresh leaf tensor so the caller's delta is left untouched.
    delta = delta.detach().clone().requires_grad_()

    with record_function("Forward pass"):
        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.set_grad_enabled(True):
            logits = model(X + delta)

            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier

            if target_classes is not None:
                # Maximizing the negated loss pulls predictions toward the target.
                loss = -loss_fn(logits, target_classes, reduction="none")
            elif y is None:
                loss = loss_fn(logits, X)
            else:
                loss = loss_fn(logits, y, reduction="none")

            grad = torch.autograd.grad(
                loss, delta, grad_outputs=torch.ones_like(loss)
            )[0]

    if input_size_one_digit is not None and grad.shape[1] > input_size_one_digit:
        grad[:, input_size_one_digit:] = 0.0

    with record_function("Attack"):
        with torch.no_grad():
            # For L-infinity, we use the sign of the gradient instead of L2
            # normalization. The objective already carries the targeted /
            # untargeted sign, so the step is always an ascent step; negating
            # both the loss and the step would push away from the target.
            z = delta + alpha * grad.sign()

            # Project onto L-infinity ball by clipping
            delta = torch.clamp(z, -epsilon, epsilon)

    # The graph has been consumed by autograd.grad; return a detached value so
    # callers do not keep it alive.
    return delta, loss.detach()


def pgd_l2_adv(
    model,
    X,
    y,
    target_classes,
    alpha,
    num_iter,
    epsilon=0,
    zero_shot_classifier=None,
    verbose=False,
    loss_fn=None,
    input_size_one_digit=None,
):
    """
    Perform a Projected Gradient Descent (PGD) L2 adversarial attack on the input.

    Components:
    -----------
    1. Initialization:
       - Initialize delta as a zero tensor with the same shape as X.

    2. PGD Iteration:
       - Perform up to 'num_iter' iterations of the PGD attack.
       - In each iteration, call single_pgd_step_adv on the samples that have
         not been successfully attacked yet, then re-classify them.
       - Samples whose attack succeeded are frozen; the loop stops early once
         every sample has succeeded.

    3. Return:
       - Return the final perturbation delta and the success mask.

    Parameters:
        model : torch.nn.Module
            The neural network model being attacked.
        X : torch.Tensor
            The input data tensor; the first dimension indexes samples.
        y : torch.Tensor
            The true labels corresponding to X.
        target_classes : torch.Tensor or None
            Per-sample target labels for a targeted attack, or None for an
            untargeted attack.
        alpha : float
            The step size (learning rate) for each PGD step.
        num_iter : int
            The maximum number of PGD iterations to perform.
        epsilon : float, optional
            The maximum allowed L2 norm of the perturbation. With the default
            of 0 the projection in single_pgd_step_adv scales the perturbation
            to zero.
        zero_shot_classifier : torch.Tensor, optional
            If given, class logits are computed as
            ``100.0 * model(x) @ zero_shot_classifier``.
        verbose : bool, optional
            If True, print per-iteration progress and a final summary.
        loss_fn : callable, optional
            Loss passed through to single_pgd_step_adv.
        input_size_one_digit : int, optional
            Passed through to single_pgd_step_adv.

    Returns:
        delta : torch.Tensor
            The final perturbation tensor (same shape as X).
        success_mask : torch.Tensor
            Boolean tensor indicating which samples were successfully attacked.

    """
    batch_size = X.shape[0]

    # The step function detaches and clones its delta argument, so this buffer
    # does not need requires_grad itself.
    delta = torch.zeros_like(X)
    # Keep track of which samples have been successfully attacked
    success_mask = torch.zeros(batch_size, dtype=torch.bool, device=X.device)
    loss = None
    iterations_run = 0

    for t in range(num_iter):
        if success_mask.all():  # All samples successfully attacked
            break

        # Samples still being attacked; computed once so every use within this
        # iteration refers to the same subset.
        active = ~success_mask

        current_delta, loss = single_pgd_step_adv(
            model,
            X[active],
            y[active],
            None if target_classes is None else target_classes[active],
            alpha,
            epsilon,
            delta[active],
            zero_shot_classifier=zero_shot_classifier,
            loss_fn=loss_fn,
            input_size_one_digit=input_size_one_digit,
        )
        iterations_run += 1

        # Check which samples are now successfully attacked
        with torch.no_grad():
            logits = model(X[active] + current_delta)

            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier

            predictions = logits.argmax(dim=-1)
            if target_classes is not None:
                # Targeted attack: success if prediction matches target
                current_success = predictions == target_classes[active]
            else:
                # Untargeted attack: success if prediction differs from true label
                current_success = predictions != y[active]

        # Only the active samples are updated; successful ones keep their delta.
        delta[active] = current_delta
        success_mask[active] = current_success

        if verbose:
            print(f"Iteration {t}: {success_mask.sum().item()}/{batch_size}")

    if verbose:
        print(
            f"Attack completed with {success_mask.sum().item()}/{batch_size} successful attacks"
        )
        if delta.numel() > 0:
            print(f"The max size of the final delta is: {delta.max().item()}")
            print(f"The min size of the final delta is: {delta.min().item()}")
        if loss is not None:
            # Report the number of steps actually executed (an early break
            # does not count as an iteration).
            print(
                f"Stopped after {iterations_run} iterations, final Cross Entropy {loss.mean().item()}"
            )
        else:
            print("Stopped after 0 iterations, no attack step was run")

    return delta, success_mask


def single_pgd_step_adv(
    model,
    X,
    y,
    target_classes,
    alpha,
    epsilon,
    delta,
    zero_shot_classifier=None,
    loss_fn=None,
    input_size_one_digit=None,
):
    """
    Perform a single step of a Projected Gradient Descent (PGD) L2 adversarial attack.

    1. Gradient Computation:
       - Compute model outputs and the attack objective: the loss on the true
         labels (untargeted) or the negated loss on the target labels
         (targeted).
       - Calculate the gradient of the objective with respect to delta.

    2. Gradient Step:
       - Normalize each sample's gradient to unit L2 norm.
       - Update delta by taking an ascent step of size alpha along the
         normalized gradient.

    3. Projection Step:
       - Rescale any per-sample perturbation whose L2 norm exceeds epsilon
         back onto the L2 epsilon-ball.

    4. Return Values:
       - Return the updated delta (perturbation) and the computed objective.

    Parameters:
        model : torch.nn.Module
            The neural network model being attacked.
        X : torch.Tensor
            The input data tensor; the first dimension indexes samples.
        y : torch.Tensor
            The true labels corresponding to X (ignored when target_classes is
            given).
        target_classes : torch.Tensor or None
            Target labels for a targeted attack; None for an untargeted attack.
        alpha : float
            The step size (learning rate) for the PGD attack.
        epsilon : float
            The maximum allowed L2 norm of the perturbation.
        delta : torch.Tensor
            The current perturbation tensor. It is not modified.
        zero_shot_classifier : torch.Tensor, optional
            If given, logits are ``100.0 * model(x) @ zero_shot_classifier``.
        loss_fn : callable, optional
            Loss function called with ``reduction="none"``; defaults to
            ``F.cross_entropy``.
        input_size_one_digit : int, optional
            If given and the gradient has more than this many entries along
            dimension 1, gradient entries from this index on are zeroed.

    Returns:
        delta : torch.Tensor
            The updated perturbation tensor after one PGD step.
        loss : torch.Tensor
            The (detached) per-sample attack objective. For a targeted attack
            this is the negated loss on the target labels.

    """
    if loss_fn is None:
        loss_fn = F.cross_entropy

    # Work on a fresh leaf tensor so the caller's delta is left untouched.
    delta = delta.detach().clone().requires_grad_()

    with record_function("Forward pass"):
        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.set_grad_enabled(True):
            logits = model(X + delta)

            if zero_shot_classifier is not None:
                logits = 100.0 * logits @ zero_shot_classifier

            if target_classes is not None:
                # Maximizing the negated loss pulls predictions toward the target.
                loss = -loss_fn(logits, target_classes, reduction="none")
            else:
                loss = loss_fn(logits, y, reduction="none")

            grad = torch.autograd.grad(
                loss, delta, grad_outputs=torch.ones_like(loss)
            )[0]

    if input_size_one_digit is not None and grad.shape[1] > input_size_one_digit:
        grad[:, input_size_one_digit:] = 0.0

    with record_function("Attack"):
        with torch.no_grad():
            # Shape (-1, 1, ..., 1) so per-sample norms broadcast against
            # inputs of any rank (vectors, images, ...).
            per_sample = (-1,) + (1,) * (grad.dim() - 1)

            # This uses L2 normalization (instead of sign()) to preserve the relative
            # magnitude of gradients between pixels (i.e. the pixels with larger
            # gradients get proportionally larger updates)
            grad_norm = norm(grad).view(per_sample)

            # The objective already carries the targeted / untargeted sign, so
            # the step is always an ascent step; negating both the loss and
            # the step would push away from the target.
            z = delta + alpha * (grad / (grad_norm + 1e-10))

            # Project onto the L2 ball: perturbations inside the ball are kept
            # (up to the 1e-10 guard), larger ones are rescaled to norm epsilon.
            z_norm = norm(z).view(per_sample)
            delta = epsilon * z / torch.clamp(z_norm, min=epsilon + 1e-10)

    # The graph has been consumed by autograd.grad; return a detached value so
    # callers do not keep it alive.
    return delta, loss.detach()


def calculate_accuracy(
    net,
    classifier,
    data_loader,
    device,
    N=2000,
    batch_size=50,
    attack_fn=None,
    top_k=1,
    verbose=False,
    **kwargs,
):
    """
    Evaluate top-1 accuracy of ``net`` (optionally under attack) and record per-image outcomes.

    Class logits are computed as ``100.0 * net(images) @ classifier``. When
    ``attack_fn`` is given, each batch is perturbed before the final
    prediction is taken.

    Parameters:
        net : torch.nn.Module
            The model being evaluated; it is put into eval mode.
        classifier : torch.Tensor
            Matrix multiplied onto the model outputs to obtain class logits.
        data_loader : iterable
            Yields ``(images, labels, img_indices)`` batches.
        device : torch.device or str
            Device the batches (and ``target_classes``) are moved to.
        N, batch_size, verbose :
            Currently unused; kept so existing callers keep working.
        attack_fn : callable, optional
            Called as ``attack_fn(net, images, labels, target, **kwargs)``,
            where ``target`` is None for an untargeted attack. It may return
            the perturbation tensor alone or a tuple/list whose first element
            is the perturbation (as ``pgd_l2_adv`` and ``pgd_linf_adv`` do).
        top_k : int, optional
            ``k`` passed to ``topk``; only the highest-scoring class is used as
            the prediction.
        **kwargs :
            ``target_classes`` (optional tensor) selects a targeted attack and
            is removed from kwargs; a single-element tensor is broadcast to
            every batch, otherwise it must match the batch size. Everything
            else is forwarded to ``attack_fn``.

    Returns:
        accuracy : float
            Fraction of images whose final prediction equals the label
            (0.0 when the loader yields no images).
        successful_attacks : list of dict
            Records for images that were initially classified correctly and
            end up either misclassified or, for a targeted attack, predicted
            as the target class.
        all_attacks : list of dict
            One record per evaluated image. In untargeted mode
            ``"target_label"`` holds the true label.
    """
    net.eval()
    correct = 0
    total = 0
    successful_attacks = []
    all_attacks = []

    target = kwargs.pop("target_classes", None)
    if target is not None:
        target = target.to(device)

    for images, labels, img_indices in tqdm(data_loader):
        images = images.clone().to(device)
        labels = labels.to(device)

        # Untargeted attacks get None so the attack function does not treat
        # the true labels as targets.
        batch_target = None
        if target is not None:
            if target.numel() == 1:
                batch_target = target.reshape(1).expand(len(labels))
            else:
                batch_target = target

        with torch.no_grad():
            initial_logits = 100.0 * net(images) @ classifier
            initial_preds = initial_logits.topk(top_k)[1].t()[0]
            initial_correct = initial_preds.eq(labels)

        torch.cuda.empty_cache()

        if attack_fn:
            attack_out = attack_fn(net, images, labels, batch_target, **kwargs)
            # The PGD helpers return (delta, success_mask[, ...]); plain
            # tensors are accepted as well.
            if isinstance(attack_out, (tuple, list)):
                delta = attack_out[0]
            else:
                delta = attack_out
            attack_images = images + delta
        else:
            attack_images = images

        with torch.no_grad():
            logits = 100.0 * net(attack_images.to(device)) @ classifier
            final_preds = logits.topk(top_k)[1].t()[0]
            final_correct = final_preds.eq(labels)

            reported_target = labels if batch_target is None else batch_target

            # Track attacks
            for idx in range(len(labels)):
                record = {
                    "img_idx": img_indices[idx],
                    "true_label": labels[idx].item(),
                    "initial_pred": initial_preds[idx].item(),
                    "target_label": reported_target[idx].item(),
                    "final_pred": final_preds[idx].item(),
                }
                all_attacks.append(record)

                was_correct = bool(initial_correct[idx])
                still_correct = bool(final_correct[idx])
                # Only meaningful for targeted attacks; in untargeted mode an
                # unchanged correct prediction must not count as a success.
                hit_target = batch_target is not None and bool(
                    final_preds[idx] == batch_target[idx]
                )

                if was_correct and (hit_target or not still_correct):
                    successful_attacks.append(record)
                elif was_correct:
                    # Model correct: Image initially CORRECT and finally CORRECT
                    print(
                        f"Image initially CORRECT and finally CORRECT: {img_indices[idx]}"
                    )
                else:
                    # Model incorrect before any attack
                    print(f"Image initially INCORRECT: {img_indices[idx]}")

            correct += final_correct.sum().item()
            total += len(labels)

        torch.cuda.empty_cache()

    accuracy = correct / total if total else 0.0
    return accuracy, successful_attacks, all_attacks


def save_attack_examples(successful_attacks, save_dir, max_examples=20):
    """
    Save the attack examples to individual files without recomputing the attacks.

    For each record a two-panel figure (original vs. attacked image) is written
    to ``save_dir`` as ``attack_img_<img_idx>.png``.

    Parameters:
        successful_attacks : list of dict
            Attack records. Each record needs the keys ``"img_idx"``,
            ``"initial_pred"``, ``"final_pred"`` and ``"true_label"``, plus the
            image arrays ``"original_image"`` and ``"attacked_image"``
            (3-D arrays are treated as CxHxW). Records without image arrays
            are skipped with a message instead of raising ``KeyError``.
        save_dir : str or path-like
            Output directory; created if it does not exist.
        max_examples : int, optional
            Maximum number of figures to save (default 20).
    """
    # Create directory for saving images if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    if max_examples <= 0:
        return

    saved = 0
    for attack in successful_attacks:
        # Get the stored images
        orig_img = attack.get("original_image")
        attacked_img = attack.get("attacked_image")
        if orig_img is None or attacked_img is None:
            print(f'Skipping img {attack.get("img_idx")}: no stored image data')
            continue

        # Handle channel ordering
        if len(orig_img.shape) == 3:  # Handle both RGB and grayscale
            orig_img = orig_img.transpose(1, 2, 0)  # Change from CxHxW to HxWxC
            attacked_img = attacked_img.transpose(1, 2, 0)

        # Create a new figure for each attack
        fig, axes = plt.subplots(1, 2, figsize=(10, 4))

        # Plot original image
        axes[0].imshow(orig_img)
        axes[0].set_title(
            f'Original\nPredicted: {attack["initial_pred"]}\nTrue: {attack["true_label"]}'
        )
        axes[0].axis("off")

        # Plot attacked image
        axes[1].imshow(attacked_img)
        axes[1].set_title(
            f'Attacked\nPredicted: {attack["final_pred"]}\nTrue: {attack["true_label"]}'
        )
        axes[1].axis("off")

        # Add a main title with attack details
        fig.suptitle(f'Adversarial Attack Example\nImg id: {attack["img_idx"]}')

        # Save the figure
        filename = f'attack_img_{attack["img_idx"]}.png'
        filepath = os.path.join(save_dir, filename)
        fig.savefig(filepath, bbox_inches="tight", dpi=300)
        # Close this specific figure so figures do not accumulate in memory.
        plt.close(fig)
        saved += 1

        print(f"Saved image pair to {filepath}")

        # Limit the number of saved images
        if saved >= max_examples:
            print(f"Reached maximum number of saved examples ({max_examples})")
            break


def run_adversarial_attack(
    model, classifier, dataloader, attack_params, save_file="", attack_fn=pgd_l2_adv
):
    """
    Run an adversarial evaluation and optionally persist the attack records.

    Parameters:
        model : torch.nn.Module
            The model being attacked.
        classifier : torch.Tensor
            Classifier matrix forwarded to ``calculate_accuracy``.
        dataloader : iterable
            Batches forwarded to ``calculate_accuracy``.
        attack_params : dict or None
            Keyword arguments forwarded to ``calculate_accuracy`` (and from
            there to ``attack_fn``). None is treated as an empty dict.
        save_file : str, optional
            If non-empty, the successful and the full attack records are saved
            with ``torch.save`` under
            ``BASE_DIR / "experiments/imagenet/random"`` using this prefix.
        attack_fn : callable, optional
            Attack function; defaults to ``pgd_l2_adv``.

    Returns:
        accuracy : float
            Accuracy reported by ``calculate_accuracy``.
        successful_attacks : list
            Records of the successful attacks.
    """
    accuracy, successful_attacks, all_attacks = calculate_accuracy(
        model, classifier, dataloader, DEVICE, attack_fn=attack_fn, **(attack_params or {})
    )
    if save_file:
        print("Saving attacks")
        outputs = (
            (
                successful_attacks,
                BASE_DIR
                / f"experiments/imagenet/random/{save_file}_successful_attacks.pth",
            ),
            (
                all_attacks,
                BASE_DIR / f"experiments/imagenet/random/{save_file}_all_attacks.pth",
            ),
        )
        for records, path in outputs:
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            torch.save(records, path)

    return accuracy, successful_attacks
