import einops
import torch as t

from adversarial_superposition.constants import DEVICE
from adversarial_superposition.toy_models.utils import decision_boundaries
from adversarial_superposition.toy_models.utils.attack_toy_classifier import (
    attack_toy_classifier,
)
from adversarial_superposition.toy_models.utils.model import ClassificationConfig
from adversarial_superposition.toy_models.utils.orthogonal_model import (
    OrthogonalClassificationnModel,
)

# Experiment dimensions: 8 instances, 7 features, 3 hidden units, 7 classes.
cfg = ClassificationConfig(n_instances=8, n_features=7, n_hidden=3, n_classes=7)

# One value per instance, geometrically spaced from 1 (5**0) down to 0.2
# (5**-1). It is passed to the model below as `feature_probability`.
decay_exponents = t.linspace(0, 1, cfg.n_instances)
# The empty "()" group appends a unit axis: (instances,) -> (instances, 1).
p = einops.rearrange(5 ** -decay_exponents, "instances -> instances ()")

# Build the classifier from the shared config and the per-instance
# probabilities computed above.
# NOTE(review): "ce" presumably selects a cross-entropy loss, and
# orthog_dim=0 presumably disables any orthogonality constraint -- confirm
# against OrthogonalClassificationnModel.
gmodel = OrthogonalClassificationnModel(
    cfg=cfg,
    feature_probability=p,
    class_importance=None,
    loss_fn="ce",
    orthog_dim=0,
    device=DEVICE,
)

# Called for its effect on `gmodel`; any return value is discarded.
gmodel.optimize(lr=1e-3, steps=5_000)

# Settings forwarded verbatim to attack_toy_classifier.
# NOTE(review): presumably `alpha` is the per-iteration step size and
# `epsilon` the perturbation budget -- confirm against the attack code.
attack_params = dict(
    num_iter=300,
    alpha=0.001,
    epsilon=0.2,
    find_worst_case=True,
)

# Which of the model's instances to attack (and, further down, to plot).
instance_idx = 5

# The call returns six values; only the first (a mapping of attacks) is used.
attacks, _, _, _, _, _ = attack_toy_classifier(
    model=gmodel,
    instance_idx=instance_idx,
    attack_method="l2",
    attack_params=attack_params,
)

# Summary: one line per key with the number of attacks stored under it.
for pair_key, pair_attacks in attacks.items():
    print(f"{pair_key}; {len(pair_attacks)}")

# Pick one recorded attack to visualise.
# NOTE(review): the (5, 3) key is hard-coded and raises KeyError if no attack
# was recorded under it; the available keys are printed by the loop above.
ex_idx = 0
ex_record = attacks[(5, 3)][ex_idx]
ex_input = ex_record.input
ex_attack = ex_record.attacked_input

# Draw ONE batch and use both its inputs and its labels. The previous code
# called generate_batch() twice, so (assuming batches are sampled randomly)
# the labels did not belong to the inputs being plotted.
boundary_sample = gmodel.generate_batch(10_000)
boundary_inputs = boundary_sample[0]
# Select the labels of the instance being plotted. This used to be a
# hard-coded `[:, 7]`, i.e. the last of the 8 instances, which does not match
# `instance_idx`.
# NOTE(review): assumes axis 1 of the labels is the instance axis -- confirm.
boundary_labels = boundary_sample[1][:, instance_idx]

decision_boundaries.plot_adversarial_attack_mechanism_3d(
    gmodel,
    ex_input[0, :],
    ex_attack[0, :],
    instance_idx=instance_idx,
    batch_for_boundary=boundary_inputs,
    batch_labels_for_boundary=boundary_labels,
    fig_title=None,
    log_to_wandb=False,
    tag="adversarial_mechanism_plot",
)
