import copy
import json
import uuid

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb
from adversarial_superposition.constants import DATA_DIR, DEVICE, MODEL_DIR, RESULTS_DIR
from adversarial_superposition.modulo.utils.fourier import apply_fourier, plot_fourier
from adversarial_superposition.modulo.utils.logger import MetricsLogger
from adversarial_superposition.modulo.utils.pgd_attack_modulo import (
    pgd_l2_adv,
    pgd_linf_adv,
)
from adversarial_superposition.modulo.utils.utils import (
    Config,
    create_change_matrices,
    cross_entropy_float16,
    cross_entropy_float32,
    cross_entropy_float64,
    get_dataset,
    get_model,
    is_wandb_initialized,
    log_tensor_as_image,
    print_change_analysis,
)

# Module-level verbosity flag.
# NOTE(review): not referenced anywhere in the visible code of this file;
# run_attack() takes its own `verbose` argument instead. Confirm whether any
# importer relies on it before removing.
VERBOSE = True


def _run_pgd(
    attack_type,
    model,
    data,
    targets,
    base_params,
    epsilons,
    verbose,
    input_size,
):
    """Run the requested PGD attack and return ``(delta, success_mask)``.

    ``epsilons`` maps each supported attack type ("l2", "linf") to the epsilon
    used for it. Any values the attack function returns beyond the first two
    are ignored, so both 2-tuple and 3-tuple return conventions are accepted.

    Raises:
        Exception: If ``attack_type`` is neither "l2" nor "linf".
    """
    attack_fns = {"l2": pgd_l2_adv, "linf": pgd_linf_adv}
    if attack_type not in attack_fns:
        raise Exception(f"attack_type: {attack_type} not supported")

    delta, success_mask, *_ = attack_fns[attack_type](
        model,
        data,
        targets,
        None,
        **{**base_params, "epsilon": epsilons[attack_type]},
        verbose=verbose,
        input_size_one_digit=input_size,
    )
    return delta, success_mask


def _attack_matrix_to_rgb(attack_matrix):
    """Encode a signed 2-D numpy array as a uint8 RGB image.

    Positive entries go to the red channel and negative entries to the blue
    channel, each scaled into 128..255 relative to the largest magnitude of
    its sign. Entries with absolute value below 1e-6 are painted mid-grey.
    """
    signs = np.sign(attack_matrix)
    near_zero = np.abs(attack_matrix) < 1e-6
    positive = np.where(attack_matrix > 0, attack_matrix, 0)
    negative = np.where(attack_matrix < 0, -attack_matrix, 0)
    if positive.max() > 0:
        positive = 128 + (positive / positive.max()) * 127
    if negative.max() > 0:
        negative = 128 + (negative / negative.max()) * 127

    rgb = np.zeros((*attack_matrix.shape, 3), dtype=np.uint8)
    rgb[..., 0] = np.where(signs > 0, positive, 0).astype(np.uint8)
    rgb[..., 2] = np.where(signs < 0, negative, 0).astype(np.uint8)
    rgb[near_zero] = 128
    return rgb


def _draw_attack_matrix(ax, rgb_tensor):
    """Draw the RGB attack matrix on ``ax`` with the standard labels/title."""
    ax.imshow(rgb_tensor, aspect="equal")
    ax.set_xlabel("a", fontsize=20)
    ax.set_ylabel(r"Adversarial perturbation ($\delta$)", fontsize=20)
    ax.set_title("Adversarial Attack Matrix", fontsize=20)
    ax.set_aspect("equal", adjustable="box")


def _suffixed_path(save_path, tag):
    """Return ``save_path`` with ``tag`` inserted before the file extension.

    Unlike a plain ``str.replace(".pdf", ...)``, this also yields a distinct
    file name when ``save_path`` does not end in ``.pdf``.
    """
    from pathlib import Path

    path = Path(save_path)
    return str(path.with_name(f"{path.stem}{tag}{path.suffix}"))


def run_attack(
    experiment_key,
    cfg,
    grokking_epoch=None,
    attack_type="l2",
    verbose=False,
    save_path=None,
):
    """
    Attack a saved model checkpoint with PGD and analyse the perturbations.

    The function loads the datasets, model config and checkpoint stored for
    `experiment_key`, runs a PGD attack (L2 or L-infinity) against the
    checkpoint selected by `grokking_epoch`, and reports clean vs. robust
    accuracy.

    Full-batch mode (`cfg.full_batch` is truthy) attacks the whole shuffled
    dataset at once and then:
    1. Saves the attacked data and the permutation used for shuffling.
    2. Logs clean ("whole_dataset_accuracy") and robust accuracy to
       Weights & Biases.
    3. Selects "successful attacks" (inputs classified correctly before the
       attack and incorrectly after it) and saves them together with their
       original inputs and labels.
    4. Builds change matrices via `create_change_matrices`, prints their
       analysis, logs their mean as an image, and plots the mean matrix next
       to Fourier spectra of the mean matrix and of the last layer's weights.

    Batched mode (`cfg.full_batch` is falsy) is only partially implemented:
    it attacks the checkpoint batch by batch, prints clean and robust
    accuracy, and returns None without saving or plotting anything.

    Side effects: `cfg.train_fraction` is overwritten with 1.0 on the caller's
    object; a wandb run is initialised if none is active; several `.pt` files
    are written below `DATA_DIR / "toy_models" / experiment_key`.

    Args:
        experiment_key (str): Identifier used to build the paths of the
            datasets, the model config and the checkpoint file.
        cfg (Config): Run configuration. Read here for `full_batch`, passed to
            `get_dataset` and `create_change_matrices`. Model and logging
            parameters (`input_size`, `softmax_precision`, `num_epochs`,
            `log_frequency`) come from the experiment's stored config.json.
        grokking_epoch (int, optional): Key of the checkpoint to load. None
            (or 0) selects key 0.
        attack_type (str, optional): "l2" or "linf". Defaults to "l2".
        verbose (bool, optional): Forwarded to the PGD attack function.
        save_path (str, optional): If given (full-batch mode only), the
            combined figure is saved there, and the three individual panels
            are saved next to it with `_modulo_pt1/2/3` inserted before the
            file extension.

    Returns:
        tuple | None: In full-batch mode, `(matrices, fig, axes)` where
            `matrices` is the first value returned by `create_change_matrices`,
            `fig` is the combined matplotlib figure, and `axes` is the list
            `[attack-matrix axis, Fourier-of-deltas axis,
            Fourier-of-weights axis]`. In batched mode, None.

    Raises:
        Exception: If `attack_type` is not "l2" or "linf".
        FileNotFoundError: Propagated from `torch.load` / `open` when a
            dataset, config or checkpoint file is missing.
        AssertionError: In full-batch mode, if the accuracy on the clean
            (unperturbed) data is not greater than 0.9.
    """
    run_key = uuid.uuid4().hex[:8]

    if not is_wandb_initialized():
        wandb.init(project="toy_models_of_addition", tags=["attack"])

    cross_entropy_function = {
        16: cross_entropy_float16,
        32: cross_entropy_float32,
        64: cross_entropy_float64,
    }

    post_grokking_epoch = grokking_epoch or 0
    print(f"Using the grokking epoch: {post_grokking_epoch}")
    wandb.log({"grokking_epoch": post_grokking_epoch})

    data_dir = DATA_DIR / f"toy_models/{experiment_key}"

    # NOTE(review): the stored train/test splits are not used by the attack or
    # the analysis below (both operate on the full dataset). The loads are
    # kept so that missing files still fail early, as documented.
    train_dataset = torch.load(
        data_dir / "last_train_loader.pt",
        map_location=DEVICE,
    )
    all_train_data = train_dataset.dataset.data[train_dataset.indices].to(DEVICE)
    all_train_targets = (
        train_dataset.dataset.targets[train_dataset.indices].to(DEVICE).long()
    )

    test_dataset = torch.load(
        data_dir / "last_test_loader.pt",
        map_location=DEVICE,
    )
    all_test_data = test_dataset.dataset.data[test_dataset.indices].to(DEVICE)
    all_test_targets = (
        test_dataset.dataset.targets[test_dataset.indices].to(DEVICE).long()
    )

    # The attack always runs on the complete dataset, so force the "train"
    # split returned by get_dataset to cover everything. This mutates the
    # caller's cfg object.
    cfg.train_fraction = 1.0
    all_dataset = get_dataset(cfg)[0]
    all_data = all_dataset.dataset.data[all_dataset.indices].to(DEVICE)
    all_targets = all_dataset.dataset.targets[all_dataset.indices].to(DEVICE).long()

    config_path = RESULTS_DIR / f"toy_models/{experiment_key}/config.json"
    with open(config_path, "r") as f:
        config = Config().from_dict(json.load(f))
    print(f"Using the model config from: {config_path}")

    checkpoint_path = (
        MODEL_DIR / f"toy_models/{experiment_key}/last_run_saved_model_checkpoints.pt"
    )
    post_grokking_model = get_model(config)
    post_grokking_model.load_state_dict(
        torch.load(checkpoint_path, map_location=DEVICE)[post_grokking_epoch]
    )
    print(f"Loaded the model weights from: {checkpoint_path}")
    models = {
        "post_grokking_model": (post_grokking_model, post_grokking_epoch),
    }

    logger = MetricsLogger(config.num_epochs, config.log_frequency)

    all_permutation = torch.randperm(all_data.size(0))
    shuffled_data = all_data[all_permutation]
    shuffled_targets = all_targets[all_permutation]

    # The result is unused, but the draw is kept so the global RNG stream seen
    # by the attack is the same as before.
    torch.randperm(all_test_data.size(0))

    attack_params = {
        "num_iter": 1000,
        "alpha": 0.01,
    }

    if not cfg.full_batch:
        # Batched mode: attack the loaded checkpoint one mini-batch at a time.
        all_data_loader = DataLoader(all_dataset, batch_size=512)
        correct_preds = robust_preds = total_preds = 0

        with tqdm(total=len(all_data_loader)) as pbar:
            pbar.set_description("E: - lr: - eps: -")
            for x, y in all_data_loader:
                x = x.to(device=DEVICE)
                y = y.to(device=DEVICE)

                batch_delta, _ = _run_pgd(
                    attack_type,
                    post_grokking_model,
                    x,
                    y,
                    attack_params,
                    {"l2": 0.1, "linf": 0.5},
                    verbose,
                    config.input_size,
                )

                with torch.no_grad():
                    clean_batch_preds = post_grokking_model(x).argmax(dim=1)
                    attack_batch_preds = post_grokking_model(x + batch_delta).argmax(
                        dim=1
                    )
                correct_preds += (clean_batch_preds == y).sum().item()
                robust_preds += (attack_batch_preds == y).sum().item()
                total_preds += y.size(0)
                pbar.update(1)

        if total_preds:
            print(
                f"Batched attack: accuracy: {correct_preds / total_preds}; "
                f"Robust accuracy: {robust_preds / total_preds}"
            )
        # Saving, change matrices and plots are not implemented for this mode.
        return None

    for model_name, (m, e) in models.items():
        train_delta, _ = _run_pgd(
            attack_type,
            m,
            shuffled_data,
            shuffled_targets,
            attack_params,
            {"l2": 1.0, "linf": 0.005},
            verbose,
            config.input_size,
        )

        attack_data = shuffled_data + train_delta
        attacked_data_path = data_dir / f"attacked_data_{model_name}_run_{run_key}.pt"
        torch.save(attack_data, attacked_data_path)
        print(f"Saved attack data to: {attacked_data_path}")

        permutation_path = data_dir / f"permutation_{model_name}_run_{run_key}.pt"
        torch.save(all_permutation, permutation_path)
        print(f"Saved permutations data to: {permutation_path}")

        # The logger's "train" slot receives the clean data and its "test"
        # slot the attacked data, so "test accuracy" is the robust accuracy.
        logger.log_metrics(
            model=m,
            epoch=e,
            save_model_checkpoints=[],
            saved_models=None,
            all_data=shuffled_data,
            all_targets=shuffled_targets,
            all_test_data=attack_data,
            all_test_targets=shuffled_targets,
            args=config,
            loss_function=cross_entropy_function[config.softmax_precision],
        )

        accuracy_rows = logger.metrics_df["metric_name"] == "accuracy"
        train_acc = logger.metrics_df[
            accuracy_rows & (logger.metrics_df["input_type"] == "train")
        ].iloc[-1]["value"]
        robust_acc = logger.metrics_df[
            accuracy_rows & (logger.metrics_df["input_type"] == "test")
        ].iloc[-1]["value"]
        print(
            f"{model_name} (Epoch: {e}): Train accuracy: {train_acc}; Robust accuracy: {robust_acc}"
        )

    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if not train_acc > 0.9:
        raise AssertionError(
            f"The accuracy on the clean (whole) dataset was only {train_acc}"
        )
    wandb.log(
        {
            "robust_accuracy": robust_acc,
            "whole_dataset_accuracy": train_acc,
        }
    )

    # Only the argmax is needed, so skip building autograd graphs.
    with torch.no_grad():
        clean_preds = m(shuffled_data).argmax(dim=1)
        attack_preds = m(attack_data).argmax(dim=1)

    # A "successful attack" flips a previously correct prediction.
    successful_mask = (attack_preds != shuffled_targets) & (
        clean_preds == shuffled_targets
    )
    orig_inputs_for_successful_attacks = shuffled_data[successful_mask]
    orig_labels_for_successful_attacks = shuffled_targets[successful_mask]
    successful_attacks = attack_data[successful_mask]

    successful_path = (
        data_dir / f"successful_attack_data_{model_name}_run_{run_key}.pt"
    )
    torch.save(successful_attacks, successful_path)
    print(f"Saved successful attack data to: {successful_path}")

    original_data_path = data_dir / f"original_data_{model_name}_run_{run_key}.pt"
    torch.save(orig_inputs_for_successful_attacks, original_data_path)
    print(f"Saved original data to: {original_data_path}")

    original_labels_path = data_dir / f"original_labels_{model_name}_run_{run_key}.pt"
    torch.save(orig_labels_for_successful_attacks, original_labels_path)
    print(f"Saved original labels to: {original_labels_path}")

    matrices, coverage = create_change_matrices(
        orig_inputs_for_successful_attacks, successful_attacks, cfg
    )
    print_change_analysis(matrices, coverage)
    # The interpretation of the first column is: Given that the sum is 0 + A,
    # which value should be increased such that we have an adversary.
    log_tensor_as_image(matrices[: config.input_size, :, :].mean(-1))

    # Mean change matrix (averaged over the last axis), shared by the image
    # and the Fourier panel.
    mean_change = matrices[: config.input_size, :, :].mean(-1).detach().cpu()
    rgb_tensor = _attack_matrix_to_rgb(mean_change.numpy())

    # Layout: attack matrix on the left, two stacked Fourier bar charts on the
    # right.
    fig = plt.figure(figsize=(14, 7))
    gs = fig.add_gridspec(1, 2, width_ratios=[1, 1.1], wspace=0.25)
    ax_attack = fig.add_subplot(gs[0, 0])
    gs_right = gs[0, 1].subgridspec(2, 1, height_ratios=[0.85, 0.85], hspace=0.35)
    ax_fourier_deltas = fig.add_subplot(gs_right[0, 0])
    ax_fourier_weights = fig.add_subplot(gs_right[1, 0])

    # --- Attack matrix as RGB (left, square) ---
    _draw_attack_matrix(ax_attack, rgb_tensor)

    # --- Fourier of deltas (top right) ---
    fft_results = apply_fourier(mean_change)
    plot_fourier(fft_results, ax=ax_fourier_deltas, return_fig_ax=False)
    ax_fourier_deltas.set_title("Fourier Spectrum (Deltas)", fontsize=18)
    ax_fourier_deltas.set_xlabel("")  # Only the bottom chart keeps an x-label

    # --- Fourier of weights (bottom right) ---
    fft_results_weights = apply_fourier(m.layers[-1].weight.detach().cpu()[:, :])
    plot_fourier(fft_results_weights, ax=ax_fourier_weights, return_fig_ax=False)
    ax_fourier_weights.set_title("Fourier Spectrum (Weights)", fontsize=18)
    ax_fourier_weights.set_xlabel("Frequency", fontsize=14)

    # Shrink the bar charts a bit to add more white space between them
    for ax in (ax_fourier_deltas, ax_fourier_weights):
        pos = ax.get_position()
        ax.set_position([pos.x0, pos.y0 + 0.03, pos.width * 0.92, pos.height * 0.92])

    if save_path is not None:
        # Each panel is also written as its own figure next to the combined one.
        attack_fig = plt.figure(figsize=(8, 8))
        _draw_attack_matrix(attack_fig.add_subplot(111), rgb_tensor)
        attack_fig.savefig(
            _suffixed_path(save_path, "_modulo_pt1"), bbox_inches="tight"
        )
        plt.close(attack_fig)

        fourier_deltas_fig = plt.figure(figsize=(8, 4))
        fourier_deltas_ax = fourier_deltas_fig.add_subplot(111)
        plot_fourier(fft_results, ax=fourier_deltas_ax, return_fig_ax=False)
        fourier_deltas_ax.set_title("Fourier Spectrum (Deltas)", fontsize=18)
        fourier_deltas_fig.savefig(
            _suffixed_path(save_path, "_modulo_pt2"), bbox_inches="tight"
        )
        plt.close(fourier_deltas_fig)

        fourier_weights_fig = plt.figure(figsize=(8, 4))
        fourier_weights_ax = fourier_weights_fig.add_subplot(111)
        plot_fourier(fft_results_weights, ax=fourier_weights_ax, return_fig_ax=False)
        fourier_weights_ax.set_title("Fourier Spectrum (Weights)", fontsize=18)
        fourier_weights_ax.set_xlabel("Frequency", fontsize=14)
        fourier_weights_fig.savefig(
            _suffixed_path(save_path, "_modulo_pt3"), bbox_inches="tight"
        )
        plt.close(fourier_weights_fig)

        # Combined figure (left open because it is returned to the caller).
        fig.savefig(save_path, bbox_inches="tight")

    return matrices, fig, [ax_attack, ax_fourier_deltas, ax_fourier_weights]


if __name__ == "__main__":
    # Example run: full-batch L2 PGD attack on experiment "2b3a80e3", using the
    # checkpoint selected by grokking_epoch=7800, with verbose attack output.
    run_attack(
        experiment_key="2b3a80e3",
        cfg=Config(log_frequency=300, train_fraction=1.0, full_batch=True),
        grokking_epoch=7_800,
        attack_type="l2",
        verbose=True,
    )
