# %%
from typing import Callable
from dataclasses import dataclass
from symo.experiments.models import Activation
from optax.schedules import constant_schedule


@dataclass(eq=True, frozen=True)
class ExperimentConfig:
    """Immutable general experiment settings.

    Groups input size, device, seeding, activation choice, regularization
    strength and batch sizes. Field order below is also the positional
    argument order of the generated ``__init__``.
    """

    # 28 * 28 = 784 input features (presumably flattened 28x28 images —
    # confirm against the data pipeline).
    input_dim: int = 28 * 28
    # Device identifier string.
    device: str = "cpu"
    # Random seed.
    seed: int = 2025
    # Activation name, typed by ``symo.experiments.models.Activation``.
    activation: Activation = "tanh"
    # L2 regularization coefficient (presumably added to the training loss —
    # confirm at the call site).
    l2_reg: float = 0.00001
    # Batch sizes used for training and for evaluation respectively.
    batch_size: int = 60_000
    test_batch_size: int = 10_000


@dataclass(eq=True, frozen=True)
class SymoConfig:
    """Immutable hyperparameters for Symo optimizer runs.

    Field order below is also the positional argument order of the
    generated ``__init__``.
    """

    # Number of training epochs.
    num_epochs: int = 500
    # Momentum coefficient.
    momentum: float = 0.4
    # Decay factor; how it is applied is defined by the Symo optimizer
    # (NOTE(review): confirm at the optimizer construction site).
    decay: float = 0.997
    # Damping term added by the Symo optimizer (NOTE(review): confirm usage).
    damping: float = 2.1e-10
    # Learning rate.
    lr: float = 1e-3


@dataclass(eq=True, frozen=True)
class AdamConfig:
    """Immutable hyperparameters for Adam baseline runs.

    Field order below is also the positional argument order of the
    generated ``__init__``.
    """

    # Number of training epochs.
    num_epochs: int = 500
    # Learning rate.
    lr: float = 8.7e-4
    # Presumably Adam's first- and second-moment decay rates (beta1, beta2)
    # and its numerical-stability epsilon — confirm at the optimizer call.
    b1: float = 0.9
    b2: float = 0.93
    eps: float = 5.3e-10


# From KFAC Paper
@dataclass(frozen=True)
class KFACConfig:
    num_epochs: int = 500
    learning_rate_schedule: Callable | None = None
    damping_schedule: Callable | None = None
    momentum_schedule: Callable | None = None
    inverse_update_period: int = 5
    damping_adaptation_interval: int = 5
    num_burnin_steps: int = 5
    curvature_ema: float = 0.95
    use_adaptive_damping: bool = True
    use_adaptive_learning_rate: bool = True
    use_adaptive_momentum: bool = True
    damping_adaptation_decay: float = 0.95
    initial_damping: float | None = 150.0
    min_damping: float = 1e-5
    max_damping: float = 1000
