# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
"""Dataclasses defining MuP-specific configuration for LLaMA models."""

from dataclasses import dataclass, field
from typing import Any

from torchtitan.models.llama3.model.args import (
    TransformerModelArgs as BaseTransformerModelArgs,
)


@dataclass
class MuPConfig:
    """Typed bundle of switches and scalars for MuP / CompleteP.

    Every field has a default, so ``MuPConfig()`` (or ``MuPConfig(**{})``)
    yields a fully populated configuration. Passing an unknown keyword
    raises ``TypeError`` from the generated ``__init__``.

    Field groups:
        * ``mup_*`` -- width-related MuP options (enable flag, scaling
          opt-outs, width multiplier, input/output alphas).
        * ``completep_*`` -- depth-related CompleteP options (depth-alpha
          toggle, depth multiplier, alpha exponent, eps-scaling toggle).
    How each value is consumed is decided by the model/optimizer code that
    reads this object (not visible here).
    """

    # --- width (MuP) switches ---------------------------------------------
    mup_enabled: bool = True
    mup_disable_attention_scaling: bool = True
    mup_disable_hidden_lr_scaling: bool = False
    # --- width (MuP) scalars ----------------------------------------------
    mup_width_multiplier: float = 1.0
    mup_input_alpha: float = 1.0
    mup_output_alpha: float = 1.0
    # --- depth (CompleteP) options ----------------------------------------
    completep_depth_alpha_enabled: bool = True
    completep_depth_multiplier: float = 1.0
    completep_depth_alpha_exp: float = 1.0
    completep_eps_scaling_enabled: bool = True


@dataclass
class ModelInitConfig:
    """Initialization overrides for MuP-tuned models."""

    init_std: float = 0.02
    emb_init_std: float | None = 0.02
    output_mult: float | None = None


@dataclass
class TransformerModelArgs(BaseTransformerModelArgs):
    """Extended transformer arguments adding MuP-specific sections.

    ``mup_config`` and ``init_config`` hold raw keyword maps (for example as
    loaded from a configuration file). After construction, ``__post_init__``
    exposes them as typed objects:

    * ``mup_config_obj``  -- a :class:`MuPConfig` built from ``mup_config``.
    * ``init_config_obj`` -- a :class:`ModelInitConfig` built from
      ``init_config``.

    These two attributes are plain instance attributes, not dataclass fields,
    so they do not appear in ``repr``/``asdict``/equality comparisons.
    """

    # muP / CompleteP
    use_embedding_norm: bool = True
    use_peri_norm: bool = True
    tie_word_embeddings: bool = True
    use_torch_layernorm: bool = True
    use_simple_silu_ffn: bool = False
    # Per-head dimension; when left as None, __post_init__ derives it as
    # ``dim // n_heads``.
    head_dim: int | None = None
    qk_norm: bool = True
    torch_layernorm_elementwise_affine: bool = True
    qk_norm_elementwise_affine: bool = True
    torch_layernorm_bias: bool = False
    qk_norm_bias: bool = False
    # Keyword arguments forwarded to MuPConfig; unknown keys raise TypeError.
    mup_config: dict[str, Any] = field(default_factory=dict)
    # Keyword arguments forwarded to ModelInitConfig; unknown keys raise
    # TypeError.
    init_config: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Instantiate strongly typed helpers from the raw configuration maps.

        Also fills in ``head_dim`` as ``dim // n_heads`` when it was not set
        explicitly. A ``None`` value for either raw map is treated as "no
        overrides" (all defaults).

        Raises:
            TypeError: if ``mup_config`` / ``init_config`` contains a key that
                is not a field of ``MuPConfig`` / ``ModelInitConfig``.
        """
        # Defining __post_init__ here shadows any hook on the base dataclass,
        # and dataclasses do not chain these automatically -- run the base
        # hook first (if one exists) so its setup is not silently skipped.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()

        # ``or {}`` lets callers pass None instead of an empty mapping, which
        # would otherwise fail with an opaque "argument after ** must be a
        # mapping" TypeError.
        self.mup_config_obj = MuPConfig(**(self.mup_config or {}))
        self.init_config_obj = ModelInitConfig(**(self.init_config or {}))

        if self.head_dim is None:
            self.head_dim = self.dim // self.n_heads
