import copy
import itertools
import json
import os

import torch
import torch.onnx as tonnx
from torch.utils.data import DataLoader, SubsetRandomSampler

import wandb
from adversarial_superposition.constants import DEVICE, MODEL_DIR, RESULTS_DIR
from adversarial_superposition.modulo.utils.utils import Config, get_dataset, get_model


if __name__ == "__main__":
    # Sub-directory name of the experiment under RESULTS_DIR / MODEL_DIR, and the
    # key used to pick one entry out of the saved checkpoint collection.
    experiment_key = ""
    post_grokking_epoch = 9_950

    # Rebuild the experiment configuration from its JSON dump.
    config_path = RESULTS_DIR / f"toy_models/{experiment_key}/config.json"
    with open(config_path, "r") as f:
        config = Config().from_dict(json.load(f))

    # Build the architecture, then load the selected checkpoint into a copy of
    # it so that `model` itself keeps its freshly constructed weights.
    model = get_model(config)
    post_grokking_model = copy.deepcopy(model)
    checkpoint_path = (
        MODEL_DIR / f"toy_models/{experiment_key}/last_run_saved_model_checkpoints.pt"
    )
    post_grokking_state = torch.load(checkpoint_path, map_location=DEVICE)[
        post_grokking_epoch
    ]
    post_grokking_model.load_state_dict(post_grokking_state)

    # The stored train fraction is overridden before the datasets are built.
    config.train_fraction = 0.3
    train_dset, test_dset = get_dataset(config)
    if config.full_batch:
        config.batch_size = len(train_dset)

    # The first 40% of test_dset indices feed the validation loader and the
    # remaining indices feed the test loader.
    val_len = int(0.4 * len(test_dset))
    test_len = len(test_dset) - val_len
    val_indices = range(val_len)
    test_indices = range(val_len, val_len + test_len)

    # All three loaders use a fixed batch size, independent of config.batch_size.
    loader_batch_size = 1024
    train_loader = DataLoader(train_dset, batch_size=loader_batch_size)
    val_loader = DataLoader(
        test_dset,
        batch_size=loader_batch_size,
        sampler=SubsetRandomSampler(val_indices),
    )
    test_loader = DataLoader(
        test_dset,
        batch_size=loader_batch_size,
        sampler=SubsetRandomSampler(test_indices),
    )

    ###### Save and load model to/from ONNX ######

    # NOTE(review): ONNXParser is not among the imports visible at the top of
    # this file — confirm where it is expected to come from.
    onnx_path = MODEL_DIR / "digits.onnx"

    # Example input handed to the exporter.
    # NOTE(review): the (1, 226) shape presumably matches the model's input
    # width — confirm against the model definition.
    dummy = torch.ones((1, 226)).to(DEVICE)
    tonnx.export(
        post_grokking_model, dummy, onnx_path, verbose=False, opset_version=16
    )

    # Read the exported file back and convert it; from here on `model` refers to
    # the converted network rather than the one returned by get_model().
    model = ONNXParser(onnx_path).to_pytorch()
    # NOTE(review): use_gpu is hard-coded to True while every tensor in this
    # script is placed on DEVICE — confirm this is intended on CPU-only hosts.
    model.set_device(use_gpu=True)
    model.eval()

    ##### Robustify Setup #####

    # NOTE(review): ClassificationNetworkFuncs, BoxPerturbationFuncs and
    # RobustnessChecker are not among the imports visible at the top of this
    # file — confirm where they are expected to come from.

    # Single-element mean / std tensors on DEVICE, passed to the perturbation helper.
    mean, std = (torch.Tensor([value]).to(DEVICE) for value in (0.0, 1.0))

    network_funcs = ClassificationNetworkFuncs(num_outputs=113)
    pert_func = BoxPerturbationFuncs(mean, std)

    # Callable used below as rob_checker(model, loader, eps=...).
    rob_checker = RobustnessChecker(pert_func, network_funcs)

    ###### Check accuracy #####

    # Attack radii at which the converted model is evaluated before any
    # robust training takes place.
    epsilon_values = [1e-4, 5e-4, 1e-3, 2e-3]
    robust_percentages = []
    correct_percentages = []

    print("Testing robustness across different epsilon values:")
    print("Epsilon\tRobust %\tCorrectly Classified %")
    print("-" * 50)

    for eps in epsilon_values:
        att_epsilon = torch.tensor([eps]).to(DEVICE)
        pre_res = rob_checker(model, test_loader, eps=att_epsilon)

        # Entries 0 and 1 of the checker result are reported as the robust and
        # correctly-classified figures, each normalised by entry 2.
        denominator = pre_res[2]
        robust_percent = 100 * pre_res[0] / denominator
        correct_percent = 100 * pre_res[1] / denominator

        robust_percentages.append(robust_percent)
        correct_percentages.append(correct_percent)
        print(f"{eps:.1e}\t{robust_percent:.2f}%\t{correct_percent:.2f}%")

    # Write the three parallel columns to a JSON file in the experiment's
    # results directory, creating that directory if needed.
    results_path = RESULTS_DIR / f"toy_models/{experiment_key}/robustness_results.json"
    results_path.parent.mkdir(parents=True, exist_ok=True)
    baseline_payload = {
        "epsilon_values": epsilon_values,
        "robust_percentages": robust_percentages,
        "correct_percentages": correct_percentages,
    }
    with open(results_path, "w") as f:
        json.dump(baseline_payload, f, indent=2)

    ###### Robustify ######

    # NOTE(review): SShapedWarmUpRobScheduler and RobustNet are not among the
    # imports visible at the top of this file — confirm where they are expected
    # to come from.

    # Hyperparameter grid: every combination of the listed values is trained once.
    hyperparams = {
        "train_epsilon": [1e-3],
        "learning_rate": [1e-3],
        "weight_decay": [1e-4],
        "n_epochs": [500],
    }

    # Attack radii used to evaluate each trained model. Only a single radius is
    # active; the wider sweep previously used here was [1e-4, 5e-4, 1e-3, 2e-3].
    attack_epsilons = [1e-4]

    # Parent directory for all per-run output directories.
    results_dir = RESULTS_DIR / f"robust_modulo/{experiment_key}_hyperparam_search"
    os.makedirs(results_dir, exist_ok=True)

    # One record per trained configuration.
    all_results = []

    # wandb project / group under which every run of this search is filed.
    wandb_project = "robust-modulo"
    wandb_group = f"modulo_robust_{experiment_key}"

    # itertools.product varies the right-most factor fastest, i.e. it visits the
    # combinations in the same order as nested loops over train_epsilon,
    # learning_rate, weight_decay and n_epochs (outermost to innermost).
    hyperparam_grid = itertools.product(
        hyperparams["train_epsilon"],
        hyperparams["learning_rate"],
        hyperparams["weight_decay"],
        hyperparams["n_epochs"],
    )
    for train_eps, lr, wd, n_epochs in hyperparam_grid:
        # Warm-up length is 20% of the total number of epochs.
        n_warmup_epochs = int(0.2 * n_epochs)

        run_name = f"train_eps{train_eps:.1e}_lr{lr:.1e}_wd{wd:.1e}_epochs{n_epochs}"
        print(f"\nTraining with: {run_name}")

        # Using the run as a context manager guarantees it is finished even
        # when training or evaluation raises, so a failure cannot leave the
        # run open.
        with wandb.init(
            project=wandb_project,
            group=wandb_group,
            name=run_name,
            config={
                "train_epsilon": train_eps,
                "learning_rate": lr,
                "weight_decay": wd,
                "n_epochs": n_epochs,
                "n_warmup_epochs": n_warmup_epochs,
                "experiment_key": experiment_key,
            },
        ):
            # Every run starts from its own copy of the converted model.
            model_copy = copy.deepcopy(model)

            # Perturbation schedule, optimiser and learning-rate schedule.
            train_epsilon = torch.tensor([train_eps]).to(DEVICE)
            pert_scheduler = SShapedWarmUpRobScheduler(n_warmup_epochs, train_epsilon)
            optimiser = torch.optim.Adam(
                model_copy.parameters(), lr=lr, weight_decay=wd
            )
            # T_max counts train-loader batches over the post-warm-up epochs.
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimiser,
                T_max=(n_epochs - n_warmup_epochs) * len(train_loader),
            )

            # Per-run output directory, named after the run.
            run_path = results_dir / run_name
            os.makedirs(run_path, exist_ok=True)

            robustnet = RobustNet(
                model=model_copy,
                train_loader=train_loader,
                val_loader=val_loader,
                test_loader=test_loader,
                outdir_path=str(run_path),
            )
            robustnet.train(
                optimiser=optimiser,
                pert_scheduler=pert_scheduler,
                network_funcs=network_funcs,
                perturbation_funcs=pert_func,
                total_epochs=n_epochs,
                warm_up_epochs=n_warmup_epochs,
                lr_scheduler=lr_scheduler,
            )

            # Evaluate the trained copy at every attack radius.
            attack_results = {}
            for attack_eps in attack_epsilons:
                att_epsilon = torch.tensor([attack_eps]).to(DEVICE)
                post_res = rob_checker(model_copy, test_loader, eps=att_epsilon)
                robust_percent = 100 * post_res[0] / post_res[2]
                correct_percent = 100 * post_res[1] / post_res[2]

                attack_results[f"attack_eps_{attack_eps:.1e}"] = {
                    "robust_accuracy": robust_percent,
                    "correct_accuracy": correct_percent,
                    "total_processed": post_res[2],
                }

                wandb.log(
                    {
                        f"attack_eps_{attack_eps:.1e}_robust_accuracy": robust_percent,
                        f"attack_eps_{attack_eps:.1e}_correct_accuracy": correct_percent,
                        f"attack_eps_{attack_eps:.1e}_total_processed": post_res[
                            2
                        ],
                        "run_path": str(run_path),
                    }
                )

                print(
                    f"Attack epsilon {attack_eps:.1e}: robust={robust_percent:.2f}%, correct={correct_percent:.2f}%"
                )

            # n_epochs is part of the grid, so it is stored with the record;
            # without it, runs that differ only in epoch count would be
            # indistinguishable in the saved results.
            all_results.append(
                {
                    "train_epsilon": train_eps,
                    "learning_rate": lr,
                    "weight_decay": wd,
                    "n_epochs": n_epochs,
                    "attack_results": attack_results,
                    "run_path": str(run_path),
                }
            )

    # Dump every run's record into a single JSON file inside results_dir.
    all_results_path = results_dir / "all_results.json"
    with open(all_results_path, "w") as f:
        json.dump(all_results, f, indent=2)

    def get_avg_robust_accuracy(result):
        """Return the mean "robust_accuracy" over the entries of result["attack_results"]."""
        per_attack = result["attack_results"].values()
        accuracy_total = sum(entry["robust_accuracy"] for entry in per_attack)
        return accuracy_total / len(per_attack)

    # The best configuration is the record with the highest mean robust
    # accuracy across its attack results.
    best_result = max(all_results, key=get_avg_robust_accuracy)

    print("\nBest configuration:")
    print(f"Training epsilon: {best_result['train_epsilon']:.1e}")
    print(f"Learning rate: {best_result['learning_rate']:.1e}")
    print(f"Weight decay: {best_result['weight_decay']:.1e}")

    print("\nAttack results:")
    best_attack_results = best_result["attack_results"]
    for eps, metrics in best_attack_results.items():
        print(
            f"{eps}: robust={metrics['robust_accuracy']:.2f}%, correct={metrics['correct_accuracy']:.2f}%"
        )

    print(f"\nResults saved in: {best_result['run_path']}")
