import json
from typing import Dict

import numpy as np
import torch
from tabulate import tabulate

from adversarial_superposition.cifar.utils.load_cifar import load_vit_model
from adversarial_superposition.constants import DEVICE, MODEL_DIR, RESULTS_DIR


def get_h5_lazy_loader(h5_path):
    """Open ``h5_path`` read-only and return a loader that reads slices on demand.

    The returned ``H5LazyLoader`` holds the open ``h5py.File`` handle until
    ``close()`` is called or the surrounding ``with`` block exits. Dataset
    contents are only fetched when a batch is requested.
    """
    # Imported here rather than at module level, so h5py is only required
    # when a loader is actually created.
    import h5py

    class H5LazyLoader:
        """Wrapper around an open, read-only ``h5py.File``."""

        def __init__(self, file_path):
            self.file_path = file_path
            self.file = h5py.File(file_path, "r")

        # --- resource management -------------------------------------------

        def close(self):
            """Release the underlying H5 file handle."""
            self.file.close()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Exceptions raised inside the ``with`` body are not suppressed.
            self.close()

        # --- data access ---------------------------------------------------

        def get_dataset_shape(self, dataset_path):
            """Return the shape of the dataset stored at ``dataset_path``."""
            dataset = self.file[dataset_path]
            return dataset.shape

        def get_batch(self, dataset_path, start_idx, batch_size):
            """Return rows ``start_idx`` to ``start_idx + batch_size`` of a dataset."""
            stop_idx = start_idx + batch_size
            dataset = self.file[dataset_path]
            return dataset[start_idx:stop_idx]

        def iter_batches(self, dataset_path, batch_size):
            """Yield consecutive batches of ``batch_size`` rows from a dataset."""
            n_rows = self.get_dataset_shape(dataset_path)[0]
            offsets = range(0, n_rows, batch_size)
            yield from (
                self.get_batch(dataset_path, offset, batch_size) for offset in offsets
            )

        @property
        def available_datasets(self):
            """Names of every dataset found anywhere in the file."""
            found = []

            def _record_if_dataset(name, node):
                # Always returns None so that ``visititems`` keeps walking.
                if not isinstance(node, h5py.Dataset):
                    return
                found.append(name)

            self.file.visititems(_record_if_dataset)
            return found

    return H5LazyLoader(h5_path)


def _get_data_from_h5_batch(h5_loader, batch_start, n_samples, batch_size):
    """Read one batch of successful attacks: metadata plus clean and attacked images.

    The batch begins at ``batch_start`` and is truncated so it never extends
    past ``n_samples``. Both image arrays are converted to tensors and moved to
    ``DEVICE``; the metadata is returned exactly as the loader provides it.

    Returns:
        Tuple ``(batch_meta, org_images, attacked_images)``.
    """
    # Clamp the final batch to the rows that are actually left.
    rows_in_batch = min(batch_start + batch_size, n_samples) - batch_start

    def _read(dataset_path):
        return h5_loader.get_batch(dataset_path, batch_start, rows_in_batch)

    def _read_as_tensor(dataset_path):
        return torch.tensor(_read(dataset_path)).to(DEVICE)

    meta = _read("successful_attacks/metadata")
    clean = _read_as_tensor("successful_attacks/original_images")
    adversarial = _read_as_tensor("successful_attacks/attacked_images")
    return meta, clean, adversarial


def _predict_labels(model, images):
    """Return the predicted class index of every image in ``images`` as a list of ints."""
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(images)
    # Flatten everything except the batch axis so each row's argmax matches a
    # per-sample ``output.argmax()``.
    # NOTE(review): assumes the model output is batch-first and that samples
    # in a batch do not influence each other (model in eval mode) — confirm
    # against ``load_vit_model``.
    return outputs.reshape(outputs.shape[0], -1).argmax(dim=1).tolist()


def analyze_transferability(
    source_model_path: str,
    target_model_path: str,
    h5_path: str,
    source_bn_dim: int,
    target_bn_dim: int,
    batch_size: int = 1000,
) -> Dict[str, float]:
    """Measure how often attacks crafted on a source model also fool a target model.

    Every entry under ``successful_attacks`` in the H5 file is processed, one
    batch of ``batch_size`` rows at a time, so files with more rows than a
    single batch are handled in full.

    Args:
        source_model_path: Checkpoint of the model the attacks were generated on.
        target_model_path: Checkpoint of the model the attacks are replayed on.
        h5_path: H5 file with ``successful_attacks/{metadata,original_images,attacked_images}``.
        source_bn_dim: Bottleneck dimension passed to ``load_vit_model`` for the source model.
        target_bn_dim: Bottleneck dimension passed to ``load_vit_model`` for the target model.
        batch_size: Number of attacks loaded and evaluated per step; must be positive.

    Returns:
        Dict with ``n_samples``, ``transfers`` (target prediction differs from the
        true label), ``same_transfers`` (target prediction equals the source
        model's prediction on the attacked image) and both counts divided by
        ``n_samples``. With zero samples the two rates are NaN, since they are
        undefined.

    Raises:
        ValueError: If ``batch_size`` is not positive.
        AssertionError: If the stored metadata disagrees with the source model's
            own predictions.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Load models
    source_model = load_vit_model(
        model_path=source_model_path,
        bottleneck_dim=source_bn_dim,
        bottleneck_after_dim=0,
        device=DEVICE,
    )
    target_model = load_vit_model(
        model_path=target_model_path,
        bottleneck_dim=target_bn_dim,
        bottleneck_after_dim=0,
        device=DEVICE,
    )

    transfers = 0
    same_transfers = 0

    # The file stays open for the whole evaluation so batches can be streamed.
    with get_h5_lazy_loader(h5_path) as loader:
        n_samples = loader.get_dataset_shape("successful_attacks/metadata")[0]

        for batch_start in range(0, n_samples, batch_size):
            batch_meta, org_images, attacked_images = _get_data_from_h5_batch(
                loader, batch_start, n_samples, batch_size
            )

            source_clean_preds = _predict_labels(source_model, org_images)
            source_attacked_preds = _predict_labels(source_model, attacked_images)
            target_attacked_preds = _predict_labels(target_model, attacked_images)

            for attack, orig_pred, attacked_pred, other_model_attacked_pred in zip(
                batch_meta,
                source_clean_preds,
                source_attacked_preds,
                target_attacked_preds,
            ):
                # NOTE(review): this compares against ``target_label``, not
                # ``true_label`` — kept as in the original; confirm it is intended.
                assert attack["final_pred"] != attack["target_label"]

                true_class_idx = attack["true_label"]

                # The source model must classify the clean image correctly ...
                assert orig_pred == true_class_idx
                # ... and must change its prediction on the attacked image.
                assert orig_pred != attacked_pred

                if other_model_attacked_pred != true_class_idx:
                    transfers += 1

                if other_model_attacked_pred == attacked_pred:
                    same_transfers += 1

    # Rates are undefined without samples; report NaN instead of dividing by zero.
    if n_samples:
        transfer_rate = transfers / n_samples
        same_transfer_rate = same_transfers / n_samples
    else:
        transfer_rate = same_transfer_rate = float("nan")

    return {
        "n_samples": n_samples,
        "transfers": transfers,
        "same_transfers": same_transfers,
        "transfer_rate": transfer_rate,
        "same_transfer_rate": same_transfer_rate,
    }


def main():
    """Run the cross-seed transferability sweep and write the results to JSON.

    For each attack type (``linf``, ``l2``), each epsilon listed in the attack
    filepaths file and each bottleneck dimension, attacks generated on one seed
    are replayed on every other seed. Per-pair metrics and per-bottleneck
    summaries are printed, saved to one JSON file per (attack type, epsilon),
    and finally saved together in a single JSON file.
    """
    seeds = [10, 20, 30, 40]
    bn_dims = [2, 3, 5, 10]

    # Locations of the H5 attack files, keyed by attack type / epsilon / model.
    with open(
        RESULTS_DIR / "cifar10/evaluate_robustness/attack_filepaths.json", "r"
    ) as fh:
        attack_filepaths = json.load(fh)

    # Checkpoint path for every (seed, bottleneck dimension) combination.
    model_paths = {
        seed: {
            bn: MODEL_DIR
            / f"scale_models/seed{seed}/vit_finetune_bn_{bn}_seed_{seed}_vit-finetune-p4-seed{seed}-final-ckpt.t7"
            for bn in bn_dims
        }
        for seed in seeds
    }

    # Ordered (source, target) pairs; a model is never paired with itself.
    seed_pairs = [(src, dst) for src in seeds for dst in seeds if src != dst]

    headers = [
        "Source Seed",
        "Target Seed",
        "BN Dim",
        "Transfer Rate",
        "Same Transfer Rate",
    ]

    def summarize(pair_metrics):
        """Aggregate the per-pair rates of one bottleneck dimension."""
        rates = [m["transfer_rate"] for m in pair_metrics]
        same_rates = [m["same_transfer_rate"] for m in pair_metrics]
        return {
            "mean_transfer_rate": np.mean(rates),
            "std_transfer_rate": np.std(rates),
            "mean_same_transfer_rate": np.mean(same_rates),
            "std_same_transfer_rate": np.std(same_rates),
            "n_pairs": len(pair_metrics),
        }

    results = {"linf": {}, "l2": {}}

    for attack_type in ("linf", "l2"):
        attack_label = attack_type.upper()
        results[attack_type] = {}

        for epsilon in attack_filepaths[attack_type]:
            eps_results = {}
            results[attack_type][epsilon] = eps_results
            print(f"\nAnalyzing {attack_label} attacks with epsilon {epsilon}")

            eps_filepaths = attack_filepaths[attack_type][epsilon]
            table_data = []

            # Source and target always share the same bottleneck dimension.
            for bn_dim in bn_dims:
                bn_results = []

                for source_seed, target_seed in seed_pairs:
                    metrics = analyze_transferability(
                        source_model_path=model_paths[source_seed][bn_dim],
                        target_model_path=model_paths[target_seed][bn_dim],
                        h5_path=eps_filepaths[f"seed{source_seed}_bn{bn_dim}"],
                        source_bn_dim=bn_dim,
                        target_bn_dim=bn_dim,
                    )

                    pair_key = f"bn{bn_dim}_seed{source_seed}_to_seed{target_seed}"
                    eps_results[pair_key] = metrics
                    bn_results.append(metrics)

                    table_data.append(
                        [
                            source_seed,
                            target_seed,
                            bn_dim,
                            f"{metrics['transfer_rate']:.4f}",
                            f"{metrics['same_transfer_rate']:.4f}",
                        ]
                    )

                eps_results[f"bn{bn_dim}_summary"] = summarize(bn_results)

            # Per-pair table
            print(f"\nTransferability Results for {attack_label} (ε={epsilon}):")
            print(tabulate(table_data, headers=headers, tablefmt="grid"))

            # Per-bottleneck summary
            print("\nSummary Statistics:")
            print("=" * 80)
            print(
                f"{'BN Dim':<8} {'Mean Transfer Rate':<20} {'Std Transfer Rate':<20} {'Mean Same Transfer Rate':<25} {'Std Same Transfer Rate':<25}"
            )
            print("-" * 80)
            for bn_dim in bn_dims:
                stats = eps_results[f"bn{bn_dim}_summary"]
                print(
                    f"{bn_dim:<8} {stats['mean_transfer_rate']:<20.4f} {stats['std_transfer_rate']:<20.4f} "
                    f"{stats['mean_same_transfer_rate']:<25.4f} {stats['std_same_transfer_rate']:<25.4f}"
                )

            # One JSON file per (attack type, epsilon)
            save_path = (
                RESULTS_DIR
                / f"cifar10/evaluate_robustness/transferability_{attack_type}_eps_{epsilon}.json"
            )
            with open(save_path, "w") as fh:
                json.dump(eps_results, fh, indent=2)
            print(f"\nResults saved to: {save_path}")

    # Everything combined in a single file
    all_results_save_path = (
        RESULTS_DIR / "cifar10/evaluate_robustness/transferability_all_results.json"
    )
    with open(all_results_save_path, "w") as fh:
        json.dump(results, fh, indent=2)
    print(f"\nAll results saved to: {all_results_save_path}")


# Run the full transferability sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()
