import einops
import torch as t

from adversarial_superposition.constants import DEVICE
from adversarial_superposition.toy_models.utils.attack_toy_classifier import (
    attack_toy_classifier,
)
from adversarial_superposition.toy_models.utils.decision_boundaries import (
    plot_attack_transitions,
)
from adversarial_superposition.toy_models.utils.model import (
    ClassificationConfig,
    ClassificationModel,
)
from adversarial_superposition.toy_models.utils.utils import plots

# Which toy-model instance the attacks and the transition plot below target.
# NOTE(review): must be < cfg.n_instances (8 here) — not validated.
instance_idx = 7

cfg = ClassificationConfig(n_instances=8, n_features=7, n_hidden=2, n_classes=7)

# Per-instance feature probability: 50 ** -x for x evenly spaced in [0, 1], so it
# decays from 1.0 down to 1/50 across instances.  Rearranged into a column of
# shape (n_instances, 1); presumably so it broadcasts against the feature axis
# inside the model — TODO confirm against ClassificationModel.
p = einops.rearrange(
    50 ** -t.linspace(0, 1, cfg.n_instances),
    "instances -> instances ()",
)

gmodel = ClassificationModel(cfg=cfg, device=DEVICE, feature_probability=p, loss_fn="ce")

# Train all instances jointly with Adam, then report and visualise the result.
gmodel.optimize(steps=5_000, lr=1e-3, optimizer="adam")
test_acc = gmodel.test_accuracy()
print(test_acc)
plots(gmodel, classification=True)

# Hyper-parameters forwarded unchanged to attack_toy_classifier.
attack_params = dict(num_iter=300, alpha=0.01, epsilon=0.2)

# Merge the attack results of every targeted run into a single lookup.  Each
# run returns a mapping from a pair key to a list of attacks; lists that share
# a key are concatenated in target-class order.
all_attacks = {}
for tgt in range(gmodel.cfg.n_classes):
    # Only the first of the six returned values is used here.
    attacks, _, _, _, _, _ = attack_toy_classifier(
        instance_idx=instance_idx,
        attack_params=attack_params,
        model=gmodel,
        attack_method="l2",
        target_class=tgt,
    )
    for pair, pair_attacks in attacks.items():
        all_attacks.setdefault(pair, []).extend(pair_attacks)

# Draw the decision-boundary batch ONCE so the labels correspond to the inputs
# they are plotted against.  generate_batch presumably samples a fresh random
# batch on every call, so indexing two separate calls ([0] from one, [1] from
# another) would pair inputs with unrelated labels.
boundary_draw = gmodel.generate_batch(10_000)
boundary_batch = boundary_draw[0]
boundary_labels = boundary_draw[1]

plot_attack_transitions(
    model=gmodel,
    attack_lookup=all_attacks,
    instance_idx=instance_idx,
    batch_for_boundary=boundary_batch,
    # Labels for the attacked instance (the column was previously hard-coded
    # to 7).  Assumes labels are laid out as (batch, instances) — TODO confirm.
    batch_labels_for_boundary=boundary_labels[:, instance_idx],
    fig_title="Attack Transitions Analysis",
)
