import os
from dataclasses import dataclass, field
from multiprocessing import freeze_support
from typing import Dict, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.onnx
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10

import wandb
from adversarial_superposition.cifar.utils.load_cifar import load_vit_model
from adversarial_superposition.cifar.utils.pgd_attack import (
    pgd_l2_adv,
    pgd_linf_adv,
    run_adversarial_attack,
)
from adversarial_superposition.constants import DATA_DIR, DEVICE, MODEL_DIR, RESULTS_DIR


@dataclass
class ExperimentConfig:
    """Configuration settings for the robustness evaluation experiment.

    Every field has a default, so callers only need to override the values
    that differ for a given run.
    """

    # Side length images are resized to (square) before normalisation.
    image_size: int = 32
    # Perturbation budget forwarded to the attack function as ``epsilon``.
    epsilon: float = 8.0 / 255.0
    # Number of attack iterations forwarded as ``num_iter``.
    num_iter: int = 100
    # Step size forwarded to the attack function as ``alpha``.
    alpha: float = 2.0 / 255.0
    # If set, only test samples with this label are attacked; None keeps all classes.
    source_class: Optional[int] = None
    # If set, every sample is attacked towards this label; None passes no targets.
    target_class: Optional[int] = None
    # Attack norm; ``main`` accepts "linf" or "l2".
    attack_type: str = "l2"
    # Forwarded unchanged to ``load_vit_model``.
    bottleneck_dim: int = 0
    # Forwarded unchanged to ``load_vit_model``.
    bottleneck_after_dim: int = 0
    # DataLoader batch size.
    batch_size: int = 16
    # DataLoader worker process count.
    num_workers: int = 4
    # Checkpoint to evaluate; the default is a path object built from MODEL_DIR,
    # so both strings and path-like objects are valid here.
    model_path: Union[str, os.PathLike] = (
        MODEL_DIR / "scale_models/vit-bottleneck-vit-4-seed20-ckpt.t7"
    )
    # Intended WandB project name for the run.
    project_name: str = "cifar10-robustness-vit"
    # Number of samples to attack (subset), None for all.
    num_attack_samples: Optional[int] = None


class IndexedCIFAR10(CIFAR10):
    """CIFAR10 variant whose items also carry their position in the dataset.

    Each item is the ``(image, label)`` pair produced by the parent class
    with the requested ``index`` appended as a third element.
    """

    def __getitem__(self, index):
        # Delegate loading/transforming to CIFAR10, then tag on the index.
        image, label = super().__getitem__(index)
        return (image, label, index)


def load_cifar_data(image_size, batch_size, num_workers, data_dir, return_indexed=True):
    """Build the CIFAR-10 test split and a sequential DataLoader over it.

    Args:
        image_size: Side length images are resized to (square).
        batch_size: Batch size for the returned DataLoader.
        num_workers: Number of DataLoader worker processes.
        data_dir: Base directory (supports ``/``); data lives in ``data_dir / "cifar10"``
            and is downloaded there if missing.
        return_indexed: When True use ``IndexedCIFAR10`` so items also carry
            their dataset index; otherwise use plain ``CIFAR10``.

    Returns:
        Tuple ``(dataloader, dataset)``.
    """
    print("Loading CIFAR-10 dataset...")

    preprocessing_steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]

    if return_indexed:
        dataset_cls = IndexedCIFAR10
    else:
        dataset_cls = CIFAR10

    test_set = dataset_cls(
        root=str(data_dir / "cifar10"),
        train=False,
        download=True,
        transform=transforms.Compose(preprocessing_steps),
    )

    # No shuffling: a fixed order keeps runs reproducible and makes
    # index-based subsetting meaningful.
    loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    print(f"Created DataLoader with {len(test_set)} samples.")
    return loader, test_set


def _resolve_config(config_input):
    """Turn ``config_input`` into a config object plus a plain-dict snapshot of it."""
    if isinstance(config_input, dict):
        # Missing keys fall back to the ExperimentConfig defaults.
        config = ExperimentConfig(**config_input)
    elif isinstance(config_input, ExperimentConfig):
        config = config_input
    else:
        print(
            "Warning: config_input is not a dict or ExperimentConfig. Attempting to use as object."
        )
        if not hasattr(config_input, "__dict__"):
            raise ValueError("Provided config_input object does not have __dict__")
        config = config_input
    # Copy so that logging the config can never mutate the object's own state.
    return config, dict(vars(config))


def _select_attack_subset(dataset, source_class, num_attack_samples):
    """Restrict ``dataset`` to the samples that should be attacked.

    Keeps only samples labelled ``source_class`` (when given) and at most the
    first ``num_attack_samples`` of them (when given). Returns ``dataset``
    unchanged when neither restriction applies.
    """
    if source_class is not None:
        indices = [
            i for i, target in enumerate(dataset.targets) if target == source_class
        ]
        if num_attack_samples is not None:
            indices = indices[:num_attack_samples]
            print(f"Subsetting to {len(indices)} samples for attack.")
        subset = Subset(dataset, indices)
        print(
            f"Created source DataLoader with {len(subset)} samples of class {source_class}."
        )
        return subset

    if num_attack_samples is not None:
        subset = Subset(dataset, list(range(len(dataset))[:num_attack_samples]))
        print(
            f"Subsetting to the first {len(subset)} samples for attack (no source_class specified)."
        )
        return subset

    return dataset


def _run_configured_attack(config):
    """Load model and data, run the configured PGD attack and return summary stats."""
    # Validate the attack type first so a bad config fails before any heavy loading.
    if config.attack_type == "linf":
        attack_fn = pgd_linf_adv
        norm_str = "linf"
    elif config.attack_type == "l2":
        attack_fn = pgd_l2_adv
        norm_str = "l2"
    else:
        raise ValueError(f"Unsupported attack type: {config.attack_type}")

    # --- Load Model ---
    model = load_vit_model(
        model_path=config.model_path,
        bottleneck_dim=config.bottleneck_dim,
        bottleneck_after_dim=config.bottleneck_after_dim,
        device=DEVICE,
    )

    # --- Load Data ---
    # Only the dataset is needed here; a dedicated loader is built after subsetting.
    _, cifar_test_data = load_cifar_data(
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        data_dir=DATA_DIR,
        return_indexed=True,
    )
    cifar_test_data = _select_attack_subset(
        cifar_test_data, config.source_class, config.num_attack_samples
    )

    source_dataloader = DataLoader(
        cifar_test_data,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=True,
        shuffle=False,
    )
    print(f"Created source DataLoader with {len(cifar_test_data)}")

    # --- Attack parameters ---
    # NOTE(review): the target tensor always has ``batch_size`` entries, while the
    # final batch can be smaller; this presumably relies on the attack code
    # coping with that mismatch — confirm in run_adversarial_attack / attack_fn.
    if config.target_class is not None:
        target_classes = torch.full(
            (config.batch_size,),
            config.target_class,
            dtype=torch.long,
            device=DEVICE,
        )
    else:
        target_classes = None
    attack_params = {
        "epsilon": config.epsilon,
        "num_iter": config.num_iter,
        "alpha": config.alpha,
        "target_classes": target_classes,
        "continue_after_success": False,
    }

    # The output file is named after the checkpoint and the attack settings.
    model_name = os.path.basename(config.model_path).split(".")[0]
    save_filename = f"adv_attacks_cifar10_vit_{model_name}_source_{config.source_class}_target_{config.target_class}_{norm_str}_eps_{config.epsilon:.4f}_alpha_{config.alpha:.4f}_iter_{config.num_iter}.h5"
    save_path = RESULTS_DIR / f"cifar10/evaluate_robustness/{save_filename}"
    save_path.parent.mkdir(parents=True, exist_ok=True)

    print(
        f"Running Targeted PGD {norm_str.upper()} attack (Source: {config.source_class}, Target: {config.target_class})..."
    )
    print(
        f"Attack parameters: eps={config.epsilon}, iter={config.num_iter}, alpha={config.alpha}, type={config.attack_type}"
    )
    print(f"Saving results to: {save_path}")

    attack_results = run_adversarial_attack(
        model=model,
        classifier=None,
        dataloader=source_dataloader,
        attack_params=attack_params,
        save_file=str(save_path),
        attack_fn=attack_fn,
    )

    clean_accuracy = attack_results["clean_accuracy"]
    robust_accuracy = attack_results["robust_accuracy"]
    # With zero clean accuracy the ratio is undefined; report NaN instead of
    # dividing by zero.
    if clean_accuracy > 0:
        normalised_robust_accuracy = robust_accuracy / clean_accuracy
    else:
        normalised_robust_accuracy = float("nan")

    return {
        "attack_success_rate": attack_results["attack_success_rate"],
        "robust_accuracy": robust_accuracy,
        "clean_accuracy_subset": clean_accuracy,
        "total_samples": attack_results["total_samples"],
        "normalised_robust_accuracy": normalised_robust_accuracy,
    }


def main(
    config_input: Union[Dict, ExperimentConfig, object],
    enable_wandb: bool = True,
    run_name: Optional[str] = None,
):
    """Runs the CIFAR-10 robustness evaluation using a configuration object/dictionary.

    Args:
        config_input: An ExperimentConfig object or a dictionary containing the parameters.
            See ExperimentConfig dataclass definition for parameters and defaults.
            Any other object is used as-is, provided it has a ``__dict__``.
        enable_wandb: If True, initialize and log to WandB.
        run_name: Optional name for the WandB run; passed to ``wandb.init``.

    Returns:
        Dict with ``attack_success_rate``, ``robust_accuracy``,
        ``clean_accuracy_subset``, ``total_samples`` and
        ``normalised_robust_accuracy`` (NaN when clean accuracy is zero).

    Raises:
        ValueError: If ``config_input`` has no ``__dict__`` or the attack type
            is neither "linf" nor "l2".
    """
    config, config_dict = _resolve_config(config_input)

    if enable_wandb:
        run = wandb.init(
            project=getattr(config, "project_name", "cifar10-robustness-vit"),
            name=run_name,
            config=config_dict,
        )
        # Rebuild the config from wandb.config so sweep overrides take effect
        # for the rest of the run.
        config = ExperimentConfig(**wandb.config)
        print(f"WandB run initialized: {run.get_url()}")
    else:
        print("WandB logging disabled.")

    succeeded = False
    try:
        stats = _run_configured_attack(config)
        print(stats)
        if enable_wandb:
            wandb.log(stats)
        print("Attack finished.")
        succeeded = True
    finally:
        # Always close the run, marking it as failed if the attack raised, so
        # a later wandb.init in the same process starts from a clean state.
        if enable_wandb:
            wandb.finish(exit_code=None if succeeded else 1)

    return stats


def attack_save_filename(model_path, attack_type, epsilon, alpha, num_iter):
    """Return the results file name used for an untargeted attack run.

    Follows the naming scheme ``main`` applies when both ``source_class`` and
    ``target_class`` are None, so the returned name points at the file that
    run wrote.
    """
    model_name = os.path.basename(model_path).split(".")[0]
    return (
        f"adv_attacks_cifar10_vit_{model_name}_source_None_target_None_"
        f"{attack_type}_eps_{epsilon:.4f}_alpha_{alpha:.4f}_iter_{num_iter}.h5"
    )


def summarise_bottleneck_results(seed_results, seeds, bn_dims):
    """Aggregate normalised robust accuracy across seeds per bottleneck dimension.

    Args:
        seed_results: Mapping ``seed -> bn_dim -> stats`` where each stats dict
            has a ``"normalised_robust_accuracy"`` entry.
        seeds: Seeds to aggregate over.
        bn_dims: Bottleneck dimensions to summarise.

    Returns:
        Mapping ``bn_dim -> {"mean", "std", "values"}``; ``values`` are ordered
        like ``seeds``.
    """
    summary = {}
    for bn_dim in bn_dims:
        values = [
            seed_results[seed][bn_dim]["normalised_robust_accuracy"] for seed in seeds
        ]
        summary[bn_dim] = {
            "mean": np.mean(values),
            "std": np.std(values),
            "values": values,
        }
    return summary


def save_bottleneck_summary(summary, attack_type, epsilon, output_dir):
    """Write the error-bar plot and CSV for one (attack type, epsilon) configuration.

    Args:
        summary: Output of ``summarise_bottleneck_results``.
        attack_type: Attack norm label, used in the title and file names.
        epsilon: Attack budget, used in the title and file names.
        output_dir: Directory (path object supporting ``/``) to write into.

    Returns:
        Tuple ``(plot_path, csv_path)``.
    """
    print(
        "\n--- Creating bottleneck dimension vs robust accuracy plot with error bars ---"
    )
    os.makedirs(output_dir, exist_ok=True)

    df = pd.DataFrame(
        {
            "bottleneck_dim": list(summary.keys()),
            "mean_accuracy": [entry["mean"] for entry in summary.values()],
            "std_accuracy": [entry["std"] for entry in summary.values()],
        }
    ).sort_values("bottleneck_dim")

    # The attack type is part of the file name because the same epsilon value
    # can occur for more than one attack type.
    stem_suffix = f"{attack_type}_eps_{epsilon:.4f}_with_error_bars"
    plot_save_path = output_dir / f"bottleneck_robustness_plot_{stem_suffix}.png"
    data_save_path = output_dir / f"bottleneck_robustness_data_{stem_suffix}.csv"

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.errorbar(
            df["bottleneck_dim"],
            df["mean_accuracy"],
            yerr=df["std_accuracy"],
            fmt="o-",
            linewidth=2,
            markersize=10,
            capsize=10,
        )
        plt.xlabel("Bottleneck Dimension", fontsize=14)
        plt.ylabel("Normalized Robust Accuracy", fontsize=14)
        plt.title(
            "Normalized Robust Accuracy vs Bottleneck Dimension (with error bars)\n"
            f"{attack_type.upper()}, eps={epsilon}",
            fontsize=16,
        )
        plt.grid(True, linestyle="--", alpha=0.7)
        plt.xticks(df["bottleneck_dim"], fontsize=12)
        plt.yticks(fontsize=12)
        plt.savefig(str(plot_save_path), dpi=300, bbox_inches="tight")
        print(f"Plot saved to {plot_save_path}")
    finally:
        # One figure is created per configuration; close it even on failure so
        # figures do not accumulate over the sweep.
        plt.close(fig)

    df.to_csv(data_save_path, index=False)
    print(f"Data saved to {data_save_path}")
    print("--- Bottleneck dimension analysis complete ---")
    return plot_save_path, data_save_path


if __name__ == "__main__":
    freeze_support()

    import json

    n_iter = 100
    num_samples = 500
    output_dir = RESULTS_DIR / "cifar10/evaluate_robustness"

    # Define attack configurations
    attack_configs = {
        "linf": {"epsilons": [0.001, 0.01, 0.05, 0.1], "alpha": 0.01},
        "l2": {"epsilons": [0.1, 0.5, 1.0, 2.0, 5.0], "alpha": 0.01},
    }

    # Checkpoints per seed, keyed by bottleneck dimension. These never change
    # between attack configurations, so they are defined once up front.
    models_by_seed = {
        10: {
            2: MODEL_DIR
            / "scale_models/seed10/vit_finetune_bn_2_seed_10_vit-finetune-p4-seed10-final-ckpt.t7",
            3: MODEL_DIR
            / "scale_models/seed10/vit_finetune_bn_3_seed_10_vit-finetune-p4-seed10-final-ckpt.t7",
            5: MODEL_DIR
            / "scale_models/seed10/vit_finetune_bn_5_seed_10_vit-finetune-p4-seed10-final-ckpt.t7",
            10: MODEL_DIR
            / "scale_models/seed10/vit_finetune_bn_10_seed_10_vit-finetune-p4-seed10-final-ckpt.t7",
        },
        20: {
            2: MODEL_DIR / "cifar/seed20/02_vit-finetune-p4-seed20-final-ckpt.t7",
            3: MODEL_DIR
            / "cifar/seed20/vit_finetune_bottleneck_3_see_20_vit-finetune-p4-seed20-final-ckpt.t7",
            5: MODEL_DIR
            / "cifar/seed20/vit_finetune_bottleneck_5_seed_20_vit-finetune-p4-seed20-final-ckpt.t7",
            10: MODEL_DIR
            / "cifar/seed20/vit_finetune_bottleneck_10_seed_20_vit-finetune-p4-seed20-final-ckpt.t7",
        },
        30: {
            2: MODEL_DIR
            / "scale_models/seed30/03_vit-finetune-p4-seed30-final-ckpt.t7",
            3: MODEL_DIR
            / "scale_models/seed30/vit_finetune_bottleneck_3_vit-finetune-p4-seed30-ckpt.t7",
            5: MODEL_DIR
            / "scale_models/seed30/vit_finetune_bottleneck_5_seed_30_vit-finetune-p4-seed30-final-ckpt.t7",
            10: MODEL_DIR
            / "scale_models/seed30/vit_finetune_bottleneck_10_seed_30_vit-finetune-p4-seed30-final-ckpt.t7",
        },
        40: {
            2: MODEL_DIR
            / "scale_models/seed40/vit_finetune_bn_2_seed_40_vit-finetune-p4-seed40-final-ckpt.t7",
            3: MODEL_DIR
            / "scale_models/seed40/vit_finetune_bn_3_seed_40_vit-finetune-p4-seed40-ckpt.t7",
            5: MODEL_DIR
            / "scale_models/seed40/vit_finetune_bn_5_seed_40_vit-finetune-p4-seed40-final-ckpt.t7",
            10: MODEL_DIR
            / "scale_models/seed40/vit_finetune_bn_10_seed_40_vit-finetune-p4-seed40-final-ckpt.t7",
        },
    }
    seeds = list(models_by_seed)
    bn_dims = [2, 3, 5, 10]

    # attack_type -> epsilon -> (seed, bn_dim) -> results file, kept for
    # transferability analysis.
    attack_filepaths = {attack_type: {} for attack_type in attack_configs}

    # Create results table headers
    print("\nResults Summary:")
    print("=" * 100)
    print(
        f"{'Seed':<6} {'BN Dim':<8} {'Attack':<6} {'Epsilon':<8} {'Clean Acc':<10} {'Robust Acc':<12} {'Norm Acc':<10}"
    )
    print("-" * 100)

    # Run experiments for each attack type and epsilon
    for attack_type, attack_cfg in attack_configs.items():
        alpha = attack_cfg["alpha"]
        for epsilon in attack_cfg["epsilons"]:
            attack_filepaths[attack_type][epsilon] = {}
            print(f"\nRunning {attack_type.upper()} attack with epsilon {epsilon}")

            # Fresh per configuration so results from a previous
            # (attack type, epsilon) can never leak into this summary.
            seed_results = {seed: {} for seed in seeds}

            for seed, models in models_by_seed.items():
                print(f"\nRunning experiments for seed {seed}")
                for bn_dim, model_path in models.items():
                    print(f"\nTesting bottleneck dimension {bn_dim}")
                    cli_config = ExperimentConfig(
                        model_path=model_path,
                        bottleneck_dim=bn_dim,
                        bottleneck_after_dim=0,
                        epsilon=epsilon,
                        num_iter=n_iter,
                        alpha=alpha,
                        num_attack_samples=num_samples,
                        attack_type=attack_type,
                        batch_size=64,
                        num_workers=4,
                    )
                    stats = main(config_input=cli_config)
                    seed_results[seed][bn_dim] = stats

                    # Store the attack filepath
                    save_filename = attack_save_filename(
                        model_path, attack_type, epsilon, alpha, n_iter
                    )
                    attack_filepaths[attack_type][epsilon][(seed, bn_dim)] = str(
                        output_dir / save_filename
                    )

                    # Print results in table format
                    print(
                        f"{seed:<6} {bn_dim:<8} {attack_type:<6} {epsilon:<8.3f} "
                        f"{stats['clean_accuracy_subset']:<10.4f} "
                        f"{stats['robust_accuracy']:<12.4f} "
                        f"{stats['normalised_robust_accuracy']:<10.4f}"
                    )

                    # Log to wandb
                    # NOTE(review): main() calls wandb.finish() before returning,
                    # so presumably no run is active here and this is skipped —
                    # confirm whether this extra logging is still wanted.
                    if wandb.run is not None:
                        wandb.log(
                            {
                                "seed": seed,
                                "bottleneck_dim": bn_dim,
                                "attack_type": attack_type,
                                "epsilon": epsilon,
                                "clean_accuracy": stats["clean_accuracy_subset"],
                                "robust_accuracy": stats["robust_accuracy"],
                                "normalised_robust_accuracy": stats[
                                    "normalised_robust_accuracy"
                                ],
                            }
                        )

            # Calculate statistics across seeds for each bottleneck dimension
            bottleneck_results = summarise_bottleneck_results(
                seed_results, seeds, bn_dims
            )
            for bn_dim, entry in bottleneck_results.items():
                print(f"\nBottleneck dimension {bn_dim} ({attack_type}, ε={epsilon}):")
                print(
                    f"Mean normalized robust accuracy: {entry['mean']:.4f} ± {entry['std']:.4f}"
                )
                print(
                    f"Individual values: {[f'{v:.4f}' for v in entry['values']]}"
                )

            # Plot and save every configuration as soon as it is complete;
            # summarising only once after the loops would keep just the last one.
            save_bottleneck_summary(
                bottleneck_results, attack_type, epsilon, output_dir
            )

    print("\n" + "=" * 100)
    print("Experiment Summary Complete")
    print("=" * 100)

    # Save attack filepaths to a JSON file
    filepaths_save_path = output_dir / "attack_filepaths.json"

    # Convert tuple keys to strings for JSON serialization
    json_attack_filepaths = {}
    for attack_type, per_epsilon in attack_filepaths.items():
        json_attack_filepaths[attack_type] = {
            str(epsilon): {
                f"seed{seed}_bn{bn_dim}": filepath
                for (seed, bn_dim), filepath in per_model.items()
            }
            for epsilon, per_model in per_epsilon.items()
        }

    os.makedirs(output_dir, exist_ok=True)
    with open(filepaths_save_path, "w") as f:
        json.dump(json_attack_filepaths, f, indent=2)
    print(f"\nAttack filepaths saved to: {filepaths_save_path}")
