from dataclasses import dataclass
from pathlib import Path
from typing import Literal

from omegaconf import DictConfig


@dataclass
class DatasetConfig:
    """Reference to a single dataset file.

    ``filepath`` may be given as a string (as it is when read from a config
    mapping) or as a path object; it is always stored as a ``Path``.
    """

    filepath: Path

    def __post_init__(self):
        # Normalise whatever was passed in to a pathlib.Path.
        raw_path = self.filepath
        self.filepath = Path(raw_path)


@dataclass
class ModelConfig:
    """Model hyperparameters, unpacked as keyword arguments from a config section.

    No validation or conversion is performed; values are stored as given.
    """

    # Integer size settings (presumably widths / counts of the network; confirm
    # against the model implementation).
    hidden_dim: int
    ff_dim: int
    n_heads: int
    n_layers: int
    # Boolean feature toggles; names suggest optional attention/positional
    # variants (ALiBi, coordinates, random ids, RoPE, SSMax) — TODO confirm.
    use_alibi: bool
    use_coords: bool
    use_random_ids: bool
    use_rope: bool
    use_ssmax: bool


@dataclass
class TrainerConfig:
    evaluation_batch_size: int
    evaluation_every: int
    evaluation_iters: int
    learning_rate: float
    learning_rate_min: float
    training_batch_size: int
    training_iters: int
    resume: Path | None

    def __post_init__(self):
        assert self.learning_rate >= self.learning_rate_min
        if self.resume is not None:
            self.resume = Path(self.resume)


@dataclass
class WandBConfig:
    """Weights & Biases run settings, unpacked from the ``wandb`` config section.

    Note: ``mode`` is annotated with a ``Literal`` type, but the annotation is
    not checked at runtime, so any value is stored as given.
    """

    entity: str
    group: str
    mode: Literal["online", "offline", "disabled"]


@dataclass
class MainConfig:
    """Top-level configuration grouping model, trainer, dataset and W&B settings."""

    model: ModelConfig
    trainer: TrainerConfig
    training_dataset: DatasetConfig
    evaluation_datasets: list[DatasetConfig]
    wandb: WandBConfig

    @staticmethod
    def _as_path_list(value):
        """Return ``value`` as an iterable of paths, wrapping a lone str/Path."""
        # A bare string is itself iterable; without this it would be split into
        # one (bogus) dataset path per character.
        if isinstance(value, (str, Path)):
            return [value]
        return value

    @classmethod
    def from_dict(cls, data: dict | DictConfig) -> "MainConfig":
        """Build a MainConfig from a nested mapping (plain dict or DictConfig).

        Reads the ``model``, ``trainer`` and ``wandb`` sections, each unpacked as
        keyword arguments into the matching sub-config, and the ``data`` section,
        whose ``training`` entry is a single path and whose ``evaluation`` entry
        is a list of paths (a single path is also accepted).

        Raises:
            Whatever the mapping raises for a missing key (``KeyError`` for a
            plain dict), ``TypeError`` for unexpected/missing sub-config fields,
            and ``ValueError`` from ``TrainerConfig`` validation.
        """
        return cls(
            model=ModelConfig(**data["model"]),
            trainer=TrainerConfig(**data["trainer"]),
            training_dataset=DatasetConfig(data["data"]["training"]),
            evaluation_datasets=[
                DatasetConfig(path)
                for path in cls._as_path_list(data["data"]["evaluation"])
            ],
            wandb=WandBConfig(**data["wandb"]),
        )
