import einops
import torch as t

from adversarial_superposition.constants import DEVICE
from adversarial_superposition.toy_models.utils.model import (
    ClassificationConfig,
    ClassificationModel,
)
from adversarial_superposition.toy_models.utils.utils import plots

# Experiment config: 8 model instances, 6 input features, a 2-dim hidden
# layer, and 6 output classes.
cfg = ClassificationConfig(n_instances=8, n_features=6, n_hidden=2, n_classes=6)

# One feature probability per instance, spaced geometrically from
# 50 ** 0 == 1 (first instance) down to 50 ** -1 == 0.02 (last instance).
exponents = t.linspace(0, 1, cfg.n_instances)
p = 50 ** -exponents
# Add a trailing singleton axis: (n_instances,) -> (n_instances, 1).
# NOTE(review): presumably so it broadcasts against a per-feature axis inside
# ClassificationModel — confirm there.
p = einops.rearrange(p, "instances -> instances ()")

# loss_fn="ce" presumably selects a cross-entropy loss — confirm in
# ClassificationModel.
model_ce = ClassificationModel(
    cfg=cfg, device=DEVICE, feature_probability=p, loss_fn="ce"
)
model_ce.optimize(optimizer="adam", lr=1e-3, steps=5_000)

# Print whatever test_accuracy() reports, then hand the trained model to the
# shared plotting helper in classification mode.
print(model_ce.test_accuracy())
plots(model_ce, classification=True)
