from dataclasses import dataclass, field
from multiprocessing import cpu_count
from pathlib import Path
from typing import List, Literal, Optional


@dataclass(frozen=True)
class ModelConfig:
    """Immutable set of model hyperparameters, each with a default value.

    The meaning of each field is defined by the code that consumes this
    object, which is not part of this file; the grouping comments below are
    based on the field names only.
    """

    # Hidden width and embedding settings.
    n_hidden: int = 2048
    embedding_activation: str = "silu"
    embedding_dimension: int = 256

    # Attention settings.
    n_heads: int = 8
    n_attention_layers: int = 3

    output_dimension: int = 1

    # MeshCNN settings. The list default comes from a factory so that every
    # instance receives its own list object.
    meshcnn_hidden_layers: List[int] = field(
        default_factory=lambda: [32, 64, 128, 256],
    )
    meshcnn_conv_radius: int = 1
    meshcnn_activation: str = "relu"

    n_fourier_features: int = 2048

    # NOTE: the former `use_meshcnn: bool = True` flag is no longer used and
    # is intentionally not a field of this class.


@dataclass(frozen=True)
class LearningRate:
    """Immutable learning-rate settings, each with a default value.

    NOTE(review): the field names suggest a warm-up towards ``max_value``
    followed by a decay towards ``minimum_value``, but the schedule itself is
    built elsewhere -- confirm against the consumer.
    """

    # Step counts.
    warmup_steps: int = 10_000
    decay_steps: int = 100_000

    # Value bounds.
    max_value: float = 0.01
    minimum_value: float = 1e-6


@dataclass(frozen=True)
class DataConfig:
    """Immutable data-related settings, each with a default value."""

    random_augmentation: bool = True

    # Both defaults contain a `*` wildcard -- presumably glob patterns that
    # the data-loading code expands; confirm against the consumer.
    training_data: str = "data/training/*"
    validation_data: str = "data/validation/helmholtz/*.vtu"

    source_multiplier: float = 1.0


@dataclass(frozen=True)
class Config:
    """Top-level, immutable run configuration.

    Combines the scalar training settings with the nested ``LearningRate``,
    ``ModelConfig`` and ``DataConfig`` sections. Every field has a default,
    so ``Config()`` is a complete configuration.
    """

    checkpoint: str = "ckpt"
    batch_size: int = 1024
    epochs: int = 1000
    steps_per_epoch: int = 100
    # Evaluated once, when this class is defined (i.e. at import time).
    # NOTE(review): this is 0 on a single-core machine -- confirm that the
    # consumer accepts 0 workers.
    n_workers: int = cpu_count() - 1
    optimizer: str = "adam"
    optimizer_args: dict = field(default_factory=dict)
    grad_norm: float = 100.0
    equation: Literal["helmholtz", "poisson", "reaction_diffusion"] = "helmholtz"
    boundary_condition: Literal["neumann", "dirichlet"] = "neumann"
    # The nested sections are built per instance via default_factory rather
    # than sharing one object created at class-definition time: ModelConfig
    # holds a mutable list, so a shared default would let a mutation made
    # through one Config show up in every other Config.
    learning_rate: LearningRate = field(default_factory=LearningRate)
    grad_accumulation_steps: int = 8
    model: ModelConfig = field(default_factory=ModelConfig)
    data: DataConfig = field(default_factory=DataConfig)

    def __repr__(self):
        """Return ``yaml.dump`` of ``dataclass_wizard.asdict(self)``."""
        # Imported lazily so that merely importing this module does not
        # require the third-party packages.
        import yaml
        from dataclass_wizard import asdict

        return yaml.dump(asdict(self))


def get_config(path: Optional[Path]) -> Config:
    """Load a ``Config`` from a YAML file, or fall back to the defaults.

    Args:
        path: Location of a YAML configuration file, or ``None`` to use the
            built-in defaults.

    Returns:
        ``Config()`` when ``path`` is ``None`` or when the file contains no
        YAML document; otherwise the ``Config`` that
        ``dataclass_wizard.fromdict`` builds from the parsed file contents.

    Raises:
        OSError: If the file cannot be opened.
    """
    if path is None:
        return Config()

    # Imported lazily so the default path works without these packages.
    import yaml
    from dataclass_wizard import fromdict

    # Decode explicitly as UTF-8 instead of relying on the locale default,
    # which would mis-read non-ASCII content on some platforms.
    with open(path, "r", encoding="utf-8") as fp:
        cfg = yaml.safe_load(fp)

    # safe_load returns None for an empty document; treat that as "no
    # overrides" rather than handing None to fromdict.
    if cfg is None:
        return Config()
    return fromdict(Config, cfg)
