from dataclasses import dataclass, asdict
from typing import Literal

# Closed set of activation-function names; used as the type of the
# `activation` field on the model config below.
Activation = Literal["tanh", "relu", "gelu"]


# Model
@dataclass(frozen=True)
class MLP64Config:
    """Frozen (immutable) default configuration for an MLP model.

    Every field has a default, so ``MLP64Config()`` is valid on its own.
    With the defaults, ``hidden_dims`` holds a single entry of 64.
    """
    input_dim: int = 2
    output_dim: int = 1
    # Presumably one width per hidden layer -- confirm against the model code.
    hidden_dims: tuple[int, ...] = (64,)
    use_bias: bool = True
    # None by default. NOTE(review): presumably None disables skip
    # connections and an int gives their spacing in layers -- TODO confirm.
    skip_every: int | None = None
    # Bias switch kept separate from `use_bias`; presumably applies to the
    # final layer only -- TODO confirm.
    use_bias_last: bool = True
    # Must be one of the names allowed by the `Activation` literal type.
    activation: Activation = "tanh"


@dataclass(frozen=True)
class MLP64_64_48_48Config(MLP64Config):
    """MLP64Config variant that overrides only the ``hidden_dims`` default.

    All other fields keep the defaults inherited from ``MLP64Config``.
    """
    # Four entries (64, 64, 48, 48) instead of the parent's single 64.
    hidden_dims: tuple[int, ...] = (64, 64, 48, 48)


# Optimizers


@dataclass(frozen=True)
class Symo64Config:
    """Frozen bundle of default hyperparameters for the "Symo" optimizer.

    Every field has a default, so ``Symo64Config()`` is valid on its own.

    NOTE(review): the ``64`` suffix presumably ties these defaults to the
    MLP64Config model -- confirm against the training script.
    """

    num_epochs: int = 100_000
    momentum: float = 0.95
    decay: float = 0.99
    damping: float = 0.0  # zero by default
    lr: float = 0.005


@dataclass(frozen=True)
class Symo2_64Config:
    """Frozen bundle of default hyperparameters for the "Symo2" optimizer.

    Every field has a default, so ``Symo2_64Config()`` is valid on its own.
    Both bias-correction switches are off by default.

    NOTE(review): the ``64`` suffix presumably ties these defaults to the
    MLP64Config model -- confirm against the training script.
    """

    num_epochs: int = 100_000
    grad_beta: float = 0.9
    sigma_g_beta: float = 0.9
    grad_bias_corr: bool = False
    sigma_g_bias_corr: bool = False
    damping: float = 1e-8
    # Written as a float literal so the stored value really is a float, as
    # the annotation promises (it used to be the int 1). Equal in value and
    # hash to the previous default.
    lr: float = 1.0


@dataclass(frozen=True)
class Adam64Config:
    """Frozen bundle of default hyperparameters for an Adam run.

    NOTE(review): the six-digit ``lr`` looks like the output of a
    hyperparameter search -- confirm its provenance before changing it.
    """

    num_epochs: int = 100_000
    lr: float = 1.692339e-3


@dataclass(frozen=True)
class SGD64Config:
    """Frozen bundle of default hyperparameters for an SGD-with-momentum run.

    NOTE(review): the seven-digit ``lr`` looks like the output of a
    hyperparameter search -- confirm its provenance before changing it.
    """

    num_epochs: int = 100_000
    lr: float = 1.805015e-2
    momentum: float = 0.99


@dataclass(frozen=True)
class Muon64Config:
    """Frozen bundle of default hyperparameters for a Muon run.

    Only the epoch count and the learning rate are configured here.
    """

    num_epochs: int = 100_000
    lr: float = 0.01


@dataclass(frozen=True)
class Adam64_64_48_48Config:
    """Frozen Adam learning-rate default; by its name, paired with the
    (64, 64, 48, 48) model config.

    NOTE(review): unlike the other optimizer configs in this file, this one
    declares no ``num_epochs`` field and does not inherit from
    ``Adam64Config`` -- confirm whether that omission is intentional.
    """
    # Presumably a tuned value (hyperparameter search) -- confirm provenance.
    lr: float = 6.999994e-4
