import importlib
import json

import numpy as np
import torch

from adversarial_superposition.constants import DEVICE, MODEL_DIR, RESULTS_DIR
from adversarial_superposition.modulo.utils import generate_attacks
from adversarial_superposition.modulo.utils.attack import run_attack
from adversarial_superposition.modulo.utils.generate_attacks import (
    generate_random_attack,
    test_attack,
)
from adversarial_superposition.modulo.utils.helpers import create_one_hot_pair
from adversarial_superposition.modulo.utils.utils import Config, get_model

experiment_key = "acc55bfa"

# Training-run config saved alongside this experiment's results. Build the path
# once so the file that is read and the path that is reported cannot diverge.
config_path = RESULTS_DIR / f"toy_models/{experiment_key}/config.json"
with open(config_path, "r", encoding="utf-8") as f:
    config = Config().from_dict(json.load(f))
print(f"Using the model config from: {config_path}")

# Checkpoint epoch under analysis. A single name is used both for the checkpoint
# lookup and for run_attack so the two cannot drift apart.
grokking_epoch = 9_950

model = get_model(config)

# The checkpoint file is indexed by an int epoch key; only the selected
# epoch's state dict is loaded into the model.
model.load_state_dict(
    torch.load(
        MODEL_DIR / f"toy_models/{experiment_key}/last_run_saved_model_checkpoints.pt",
        map_location=DEVICE,
    )[grokking_epoch]
)
# The model is only used for inference below. Disable training-mode behaviour
# (no-op if the model has no dropout / batch-norm layers).
model.eval()

attacks, fig, axes = run_attack(
    experiment_key,
    config,
    grokking_epoch=grokking_epoch,
    attack_type="l2",
    verbose=True,
    save_path=RESULTS_DIR / f"toy_models/{experiment_key}/modulo_attack.pdf",
)

# Sanity-check the attack produced by run_attack on one input pair.
digit_1 = 0
digit_2 = 0
one_hot = create_one_hot_pair(digit_1, digit_2).unsqueeze(0)

# NOTE(review): attack_digit_1 is not used below; kept for interactive inspection.
attack_digit_1 = attacks[: config.input_size, digit_1, digit_2]
attack = attacks[:, digit_1, digit_2]

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    orig_pred = model(one_hot).argmax()
    print(f"Original prediction for {digit_1} + {digit_2} = {orig_pred}")

    perturbed_pred = model(attack.unsqueeze(0)).argmax()
    print(f"Perturbed prediction for {digit_1} + {digit_2} = {perturbed_pred}")

# NOTE(review): `attack` is fed to the model directly above, so it looks like a
# full adversarial input rather than a perturbation. If so, this norm includes
# the clean one-hot; the perturbation size would be
# torch.norm(attack - one_hot.squeeze(0)). Confirm against run_attack.
print("The L2 norm of the attack")
print(torch.norm(attack))


# Compare the inverse-interference attack against simple baseline perturbations
# that each spend the same L2 budget on the first-digit input features.
p_mod = 113
# NOTE(review): 53 appears twice. Confirm whether the double weight on
# frequency 53 is intended or a typo for another key frequency.
k_values = [9, 23, 28, 41, 52, 53, 53]
l2_budget = 2.0

test_pairs = [(0, 0)]


def _pad_second_digit(first_digit_attack: torch.Tensor, width: int) -> torch.Tensor:
    """Append ``width`` zeros (unperturbed second-digit slot) and move to DEVICE.

    The padding is created on the attack's own device and dtype, so the
    concatenation works whether the attack helper returned a CPU or GPU tensor.
    Assumes the model input is [first digit | second digit], each ``p_mod``
    wide. TODO: confirm against create_one_hot_pair.
    """
    first_digit_attack = first_digit_attack.to(DEVICE)
    padding = torch.zeros(
        width, dtype=first_digit_attack.dtype, device=first_digit_attack.device
    )
    return torch.cat([first_digit_attack, padding])


def _report_attack(
    label: str, attack: torch.Tensor, original_input: torch.Tensor, target: int
) -> tuple:
    """Evaluate *attack* with ``test_attack`` and print success, class and L2 norm.

    Returns the ``(success, predicted_class)`` pair from ``test_attack``.
    """
    success, pred = test_attack(model, original_input, attack, target)
    print(f"\n{label}:")
    print(f"Attack succeeded: {success}")
    print(f"Predicted class: {pred}")
    print(f"L2 norm: {torch.norm(attack).item():.6f}")
    return success, pred


# Reload once so edits to generate_attacks are picked up when re-running
# interactively. Names imported directly at the top of the file
# (test_attack, generate_random_attack) still refer to the pre-reload versions.
importlib.reload(generate_attacks)

for digit1, digit2 in test_pairs:
    original_input = create_one_hot_pair(digit1, digit2)
    target = (digit1 + digit2) % p_mod

    # Get original prediction
    with torch.no_grad():
        original_pred = model(original_input.unsqueeze(0)).argmax(dim=1)
    print(f"\nTesting attacks on input: {digit1} + {digit2} = {target}")
    print(f"Original prediction: {original_pred.item()}")

    # Inverse-interference attack built from the chosen key frequencies.
    inverse_attack = _pad_second_digit(
        generate_attacks.generate_inverse_interference_attack(
            p_mod, k_values, plot_type="cos", l2_budget=l2_budget, plot=True
        ),
        p_mod,
    )
    inverse_success, inverse_pred = _report_attack(
        "Inverse Interference Attack", inverse_attack, original_input, target
    )

    # Random-direction baseline.
    random_attack = _pad_second_digit(generate_random_attack(p_mod, l2_budget), p_mod)
    random_success, random_pred = _report_attack(
        "Random Attack", random_attack, original_input, target
    )

    # Uniform baseline: equal perturbation on every first-digit feature.
    uniform_attack = _pad_second_digit(
        torch.ones(p_mod) * (l2_budget / np.sqrt(p_mod)), p_mod
    )
    uniform_success, uniform_pred = _report_attack(
        "Uniform Attack", uniform_attack, original_input, target
    )

    # Sine-wave baseline: one full period over the first-digit features,
    # rescaled to the L2 budget.
    x = torch.arange(p_mod, dtype=torch.float32)
    sine_wave = torch.sin(2 * np.pi * x / p_mod)
    sine_attack = _pad_second_digit(sine_wave * (l2_budget / torch.norm(sine_wave)), p_mod)
    sine_success, sine_pred = _report_attack(
        "Sine Wave Attack", sine_attack, original_input, target
    )

    # Single-point baseline: the whole budget on first-digit input feature
    # `target` (an input index, not an output class). NOTE(review): for (0, 0)
    # this is presumably the already-hot feature, reinforcing the clean input.
    single_point_first = torch.zeros(p_mod)
    single_point_first[target] = l2_budget
    single_point_attack = _pad_second_digit(single_point_first, p_mod)
    single_point_success, single_point_pred = _report_attack(
        "Single-Point Attack", single_point_attack, original_input, target
    )
