import random

import einops
import numpy as np
import torch as t

from adversarial_superposition.constants import DEVICE
from adversarial_superposition.toy_models.utils.model import (
    ClassificationConfig,
    ClassificationModel,
)
from adversarial_superposition.toy_models.utils.utils import plots

# ---------------------------------------------------------------------------
# Train a toy ClassificationModel under two feature-generation regimes and
# print the first element of `test_accuracy()` / draw plots for each.
# Both experiments are re-seeded with the same value so they start from the
# same RNG state.
# ---------------------------------------------------------------------------

seed = 0


def _seed_everything(seed_value: int) -> None:
    """Seed Python's, NumPy's and PyTorch's global RNGs with one value."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    t.random.manual_seed(seed_value)


def _linear_feature_probability(n_instances: int) -> t.Tensor:
    """Return feature probabilities swept linearly from 1 down to 0.

    One value per instance, with endpoints included, reshaped to
    ``(n_instances, 1)`` (a trailing singleton axis).

    NOTE(review): the last instance receives probability exactly 0. Confirm
    that a model trained with no active features is intended here.
    """
    probs = t.linspace(1, 0, n_instances)
    return einops.rearrange(probs, "instances -> instances ()")


def _train_and_report(
    cfg: ClassificationConfig, p: t.Tensor
) -> ClassificationModel:
    """Build a cross-entropy ClassificationModel for `cfg`, train, and report.

    Trains for 5,000 Adam steps at lr=1e-3. It then prints
    `test_accuracy()[0]`, draws the classification plots, and returns the
    trained model.
    """
    trained = ClassificationModel(
        cfg=cfg,
        device=DEVICE,
        feature_probability=p,
        loss_fn="ce",
    )
    trained.optimize(steps=5_000, lr=1e-3, optimizer="adam")
    print(trained.test_accuracy()[0])
    plots(trained, classification=True)
    return trained


# Experiment 1: "mixed" feature generation with 4 correlated feature pairs.
_seed_everything(seed)
cfg = ClassificationConfig(
    n_instances=8,
    n_features=8,
    n_hidden=2,
    n_classes=8,
    feature_generation_mode="mixed",
    n_correlated_feature_pairs=4,
)
p = _linear_feature_probability(cfg.n_instances)
model = _train_and_report(cfg, p)


# Experiment 2: "class_correlated" feature generation, no correlated pairs.
# The cycle_* settings are forwarded to ClassificationConfig unchanged. Their
# exact semantics are defined there, so check that class for details.
_seed_everything(seed)
cfg = ClassificationConfig(
    n_instances=8,
    n_features=8,
    n_hidden=2,
    n_classes=8,
    feature_generation_mode="class_correlated",
    n_correlated_feature_pairs=0,
    cycle_base_prob=0.1,
    cycle_amplitude=0.9,
    cycle_sparsity=0.0,
)
p = _linear_feature_probability(cfg.n_instances)
model1 = _train_and_report(cfg, p)
