# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any

from torchtitan.protocols.state_dict_adapter import StateDictAdapter

from torchtitan.models.llama3.model import TransformerModelArgs


class Llama3StateDictAdapter(StateDictAdapter):
    """Convert Llama 3 state dicts between the native and HuggingFace layouts.

    Conversion consists of two steps:

    * renaming every key according to ``from_hf_map`` / ``to_hf_map``, and
    * permuting the rows of the query and key projection weights, which the
      two implementations lay out differently because of their RoPE
      implementations.
    """

    # HuggingFace key template -> native key template. ``{}`` stands for the
    # layer index. A value of ``None`` marks a HuggingFace entry that has no
    # native counterpart and is dropped by ``from_hf``.
    from_hf_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    # Native key template -> HuggingFace key template. Entries that are
    # dropped on import (mapped to ``None`` above) have no inverse, so they
    # are left out instead of producing a meaningless ``None`` key.
    to_hf_map = {v: k for k, v in from_hf_map.items() if v is not None}

    # First run of digits in a key; for per-layer keys this is the layer index.
    _LAYER_INDEX_RE = re.compile(r"\d+")

    @staticmethod
    def _permute(w, n_heads_arg, dim1=None, dim2=None):
        """Reorder the rows of a projection weight from native to HF layout.

        Follows the permutation used by HuggingFace's Llama conversion script.

        Args:
            w: 2-D weight to permute (presumably a ``torch.Tensor``; only
                ``view``/``transpose``/``reshape``/``clone`` are required).
            n_heads_arg: Number of heads the rows of ``w`` are split into.
            dim1: Number of rows of ``w``; defaults to ``w.shape[0]``.
            dim2: Number of columns of ``w``; defaults to ``w.shape[1]``.

        Returns:
            A new ``(dim1, dim2)`` weight with the rows of every head
            regrouped from ``(dim1 // n_heads_arg // 2, 2)`` to
            ``(2, dim1 // n_heads_arg // 2)`` order.
        """
        if dim1 is None:
            dim1 = w.shape[0]
        if dim2 is None:
            dim2 = w.shape[1]
        return (
            w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2)
            .transpose(1, 2)
            .reshape(dim1, dim2)
            .clone()
        )

    @staticmethod
    def _reverse_permute(w, n_heads_arg, dim1=None, dim2=None):
        """Undo ``_permute``: reorder rows from HF layout back to native.

        Args:
            w: 2-D weight to permute (presumably a ``torch.Tensor``).
            n_heads_arg: Number of heads the rows of ``w`` are split into.
            dim1: Number of rows of ``w``; defaults to ``w.shape[0]``.
            dim2: Number of columns of ``w``; defaults to ``w.shape[1]``.

        Returns:
            A ``(dim1, dim2)`` weight with the rows of every head regrouped
            from ``(2, dim1 // n_heads_arg // 2)`` to
            ``(dim1 // n_heads_arg // 2, 2)`` order.
        """
        if dim1 is None:
            dim1 = w.shape[0]
        if dim2 is None:
            dim2 = w.shape[1]
        return (
            w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2)
            .transpose(1, 2)
            .reshape(dim1, dim2)
        )

    @staticmethod
    def _split_layer_key(key: str) -> tuple[str, str]:
        """Split a concrete key into its ``{}`` template and layer index.

        Only keys containing ``"layers"`` are treated as per-layer keys; the
        first run of digits in such a key is taken as the layer index.

        Returns:
            ``(template, layer_num)``. For keys that are not per-layer keys
            (or that contain no digits) the key itself is returned as the
            template and ``layer_num`` is the empty string.
        """
        if "layers" not in key:
            return key, ""
        match = Llama3StateDictAdapter._LAYER_INDEX_RE.search(key)
        if match is None:
            return key, ""
        template = key[: match.start()] + "{}" + key[match.end() :]
        return template, match.group(0)

    @staticmethod
    def _resolve(key_map: dict[str, Any], template: str, key: str) -> Any:
        """Look up ``template`` in ``key_map`` with a descriptive error.

        Raises:
            KeyError: If ``template`` has no entry in ``key_map``; the message
                names the offending state dict key.
        """
        try:
            return key_map[template]
        except KeyError:
            raise KeyError(
                f"No mapping for state dict key {key!r} (template {template!r})"
            ) from None

    @staticmethod
    def to_hf(
        state_dict: dict[str, Any], model_args: TransformerModelArgs
    ) -> dict[str, Any]:
        """Convert a native state dict to the HuggingFace layout.

        Args:
            state_dict: Native state dict, keyed as in the values of
                ``from_hf_map``.
            model_args: Model configuration; ``n_heads``, ``n_kv_heads`` and
                ``dim`` are read. ``n_kv_heads`` of ``None`` means "same as
                ``n_heads``".

        Returns:
            A new dict with HuggingFace keys. The ``wq`` and ``wk`` weights
            are permuted; all other values are passed through unchanged.

        Raises:
            KeyError: If a key has no entry in ``to_hf_map``.
        """
        cls = Llama3StateDictAdapter
        n_heads = model_args.n_heads
        n_kv_heads = (
            model_args.n_kv_heads if model_args.n_kv_heads is not None else n_heads
        )
        dim = model_args.dim
        # Row count of the key projection: one head_dim slice per KV head.
        key_value_dim = (dim // n_heads) * n_kv_heads
        hf_state_dict = {}

        for key, value in state_dict.items():
            template, layer_num = cls._split_layer_key(key)
            new_key = cls._resolve(cls.to_hf_map, template, key)

            # The wq and wk weights must be permuted to account for the
            # difference between the native Llama and HuggingFace RoPE
            # implementations.
            if template == "layers.{}.attention.wq.weight":
                value = cls._permute(value, n_heads)
            elif template == "layers.{}.attention.wk.weight":
                value = cls._permute(value, n_kv_heads, key_value_dim, dim)

            if layer_num:
                new_key = new_key.format(layer_num)
            hf_state_dict[new_key] = value
        return hf_state_dict

    @staticmethod
    def from_hf(
        hf_state_dict: dict[str, Any], model_args: TransformerModelArgs
    ) -> dict[str, Any]:
        """Convert a HuggingFace state dict to the native layout.

        Args:
            hf_state_dict: HuggingFace state dict, keyed as in the keys of
                ``from_hf_map``.
            model_args: Model configuration; ``n_heads``, ``n_kv_heads`` and
                ``dim`` are read. ``n_kv_heads`` of ``None`` means "same as
                ``n_heads``".

        Returns:
            A new dict with native keys. Entries mapped to ``None`` in
            ``from_hf_map`` are dropped, the ``q_proj`` and ``k_proj`` weights
            are reverse-permuted, and all other values are passed through
            unchanged.

        Raises:
            KeyError: If a key has no entry in ``from_hf_map``.
        """
        cls = Llama3StateDictAdapter
        n_heads = model_args.n_heads
        n_kv_heads = (
            model_args.n_kv_heads if model_args.n_kv_heads is not None else n_heads
        )
        dim = model_args.dim
        # Row count of the key projection: one head_dim slice per KV head.
        key_value_dim = (dim // n_heads) * n_kv_heads
        state_dict = {}

        for key, value in hf_state_dict.items():
            template, layer_num = cls._split_layer_key(key)
            new_key = cls._resolve(cls.from_hf_map, template, key)
            if new_key is None:
                # No native counterpart (e.g. rotary_emb.inv_freq): drop it.
                continue

            # The q_proj and k_proj weights must be permuted back to account
            # for the difference between the native Llama and HuggingFace
            # RoPE implementations.
            if template == "model.layers.{}.self_attn.q_proj.weight":
                value = cls._reverse_permute(value, n_heads)
            elif template == "model.layers.{}.self_attn.k_proj.weight":
                value = cls._reverse_permute(value, n_kv_heads, key_value_dim, dim)

            if layer_num:
                new_key = new_key.format(layer_num)
            state_dict[new_key] = value
        return state_dict