# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""State dict adapter bridging MuP LLaMA3 checkpoints and TorchTitan."""

from typing import Any

import torch

from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter
from .mup_args import TransformerModelArgs


class Llama3MuPStateDictAdapter(Llama3StateDictAdapter):
    """State dict adapter for the Llama3 MuP model.

    Subclasses ``Llama3StateDictAdapter`` but overrides both conversion
    directions with a wholesale, prefix-based renaming rather than a
    per-parameter name map:

    * ``to_hf`` places every native key under a ``model.`` prefix (for
      example ``tok_embeddings.weight`` -> ``model.tok_embeddings.weight``)
      and additionally publishes the output projection as
      ``lm_head.weight``.
    * ``from_hf`` strips the ``model.`` prefix again and accepts
      ``lm_head.*`` as a fallback source for ``output.*``.

    MuP Extensions:
    ---------------
    Because keys are renamed wholesale, MuP-specific parameters need no
    dedicated mapping entries and are preserved in both directions, e.g.:

    - embedding_norm.weight (optional RMSNorm applied to embeddings)
    - post_attn_norm.weight (peri-normalization after attention)
    - post_ffn_norm.weight (peri-normalization after FFN)

    Weight Tying Compatibility:
    ---------------------------
    When the native state dict has no ``output.weight`` (expected when
    ``tie_word_embeddings=True`` -- confirm against the model definition),
    neither ``model.output.weight`` nor ``lm_head.weight`` is emitted; only
    the embedding is exported.

    When ``output.weight`` is present, it is exported both as
    ``model.output.weight`` and as ``lm_head.weight``. If it shares storage
    with the token embedding it is cloned first, so no two exported tensors
    alias the same memory.
    """

    def __init__(
        self,
        model_args: TransformerModelArgs,
        hf_assets_path: str | None,
    ) -> None:
        """Forward the arguments unchanged to the base Llama3 adapter.

        Args:
            model_args: Model configuration handed to the base adapter.
            hf_assets_path: Path to HuggingFace assets, or ``None``.
        """
        # No MuP-specific state is kept here; the override mainly narrows the
        # expected ``model_args`` type to the MuP ``TransformerModelArgs``.
        super().__init__(model_args, hf_assets_path)

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        """Convert a native MuP state dict to the ``model.``-prefixed layout.

        The resulting naming is intended to match what the LightEval model
        expects (per the original author's note).

        Args:
            state_dict: Native state dict. Keys that already start with
                ``model.`` are passed through unchanged; every other key is
                placed under the ``model.`` prefix.

        Returns:
            A new dict with the renamed keys. When the output projection is a
            ``torch.Tensor`` an extra ``lm_head.weight`` entry holding an
            independent copy is added as well.
        """
        output_key = "model.output.weight"
        # The loop accepts both un-prefixed and already-prefixed keys, so the
        # embedding used for tied-weight detection may live under either name.
        embed_candidates = (
            state_dict.get("tok_embeddings.weight"),
            state_dict.get("model.tok_embeddings.weight"),
        )

        hf_state: dict[str, Any] = {}
        for key, value in state_dict.items():
            # "output.weight" -> "model.output.weight" is covered by the
            # generic prefixing rule, so no special case is needed for it.
            hf_key = key if key.startswith("model.") else f"model.{key}"
            tensor = value

            if hf_key == output_key and isinstance(value, torch.Tensor):
                # Avoid shared storage when weights are tied: the exported
                # output projection must not alias the embedding tensor.
                output_ptr = value.data_ptr()
                if any(
                    isinstance(embed, torch.Tensor) and embed.data_ptr() == output_ptr
                    for embed in embed_candidates
                ):
                    tensor = value.clone()
                # lm_head.weight always gets its own storage, and an entry
                # written earlier is never overwritten.
                if "lm_head.weight" not in hf_state:
                    hf_state["lm_head.weight"] = tensor.clone()

            hf_state[hf_key] = tensor

        return hf_state

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        """Convert a ``model.``-prefixed state dict back to native key names.

        Args:
            hf_state_dict: State dict in the layout produced by ``to_hf``.

        Returns:
            A new dict in which ``model.<name>`` becomes ``<name>`` and
            ``lm_head.<name>`` becomes ``output.<name>``. A ``model.output.*``
            entry always takes precedence over the matching ``lm_head.*``
            entry, regardless of iteration order. Keys that start with
            neither prefix are dropped.
        """
        model_prefix = "model."
        head_prefix = "lm_head."

        state_dict: dict[str, Any] = {}
        for key, value in hf_state_dict.items():
            if key.startswith(model_prefix):
                state_dict[key.removeprefix(model_prefix)] = value
            elif key.startswith(head_prefix):
                # Fallback only: do not clobber a value that came from
                # "model.output.*".
                native_key = f"output.{key.removeprefix(head_prefix)}"
                state_dict.setdefault(native_key, value)
        return state_dict
