#  
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

####################################################################################################

#
# Note: If when running this conversion script you're getting an exception:
#     ModuleNotFoundError: No module named 'megatron.model.enums'
# you need to tell python where to find the clone of Megatron-LM, e.g.:
#
# cd /tmp
# git clone https://github.com/NVIDIA/Megatron-LM
# PYTHONPATH=/tmp/Megatron-LM python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py ...
#
# if you already have it cloned elsewhere, simply adjust the path to the existing path
#
# If the training was done using a Megatron-LM fork, e.g.,
# https://github.com/microsoft/Megatron-DeepSpeed/ then chances are that you need to have that one
# in your path, i.e., /path/to/Megatron-DeepSpeed/
#

import argparse
import os
import sys
sys.path.append(os.getcwd())
import re
import zipfile
from collections import OrderedDict
import torch
import json

from transformers.modeling_utils import shard_checkpoint

from models.qwen.configuration_qwen import QWenConfig
from models.qwen.tokenization_qwen import QWenTokenizer
from utils.data_utils import get_value_from_nested_dict, update_nested_dict


####################################################################################################


def recursive_print(name, val, spaces=0):
    """Print a (possibly nested) dict as an indented tree.

    Dict values are descended into, with each level indented by two more
    columns. Tensors are shown by their size, every other value is printed as-is.

    Args:
        name: Label for ``val``; ``None`` suppresses the header line of a dict.
        val: The value to print (dict, ``torch.Tensor`` or anything printable).
        spaces: Current indentation; it also shrinks the label column width.
    """
    # Build the left-hand label: leading dots for depth, then a padded "# name" column.
    label = None
    if name is not None:
        dots = "." * max(0, spaces - 2)
        column_width = 50 - spaces
        label = (dots + "# {:" + str(column_width) + "s}").format(name)

    # Leaf values are printed on a single line next to their label.
    if not isinstance(val, dict):
        shown = val.size() if isinstance(val, torch.Tensor) else val
        print(label, ":", shown)
        return

    # Dicts print their own label (if any) and then recurse into the children.
    if label is not None:
        print(label)
    for child_name, child_val in val.items():
        recursive_print(child_name, child_val, spaces + 2)


def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
    """Reorder a fused QKV tensor to the [num_splits * num_heads * hidden_size, :] layout.

    The layout is the one used by later versions of NVIDIA Megatron-LM; the inverse
    operation is performed inside Megatron-LM when it reads checkpoints:
    https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209
    If ``param`` is the weight of the self-attention block, the caller still has to
    transpose the result once more for the HuggingFace GPT2-style layout.

    Args:
        param: Tensor whose first dimension is the fused QKV dimension.
        checkpoint_version: Megatron checkpoint version; only ``1.0`` and ``>= 2.0``
            trigger a permutation, any other value leaves the data order unchanged.
        num_splits: Number of fused projections.
        num_heads: Number of attention heads.
        hidden_size: Hidden size per attention head.

    Returns:
        A tensor with the same shape as ``param``.
    """
    original_shape = param.size()
    trailing_dims = original_shape[1:]

    if checkpoint_version == 1.0:
        # Version 1.0 stores [num_heads * hidden_size * num_splits, :].
        stored = param.view(num_heads, hidden_size, num_splits, *trailing_dims)
        # (heads, hidden, splits, ...) -> (splits, heads, hidden, ...)
        param = stored.transpose(0, 2).transpose(1, 2).contiguous()
    elif checkpoint_version >= 2.0:
        # Other versions store [num_heads * num_splits * hidden_size, :].
        stored = param.view(num_heads, num_splits, hidden_size, *trailing_dims)
        # (heads, splits, hidden, ...) -> (splits, heads, hidden, ...)
        param = stored.transpose(0, 1).contiguous()

    return param.view(*original_shape)


def load_checkpoints_from_multi_pt(ckpt_folder, save_loaded_ckpt=None):
    """Load all model-parallel shards under ``ckpt_folder`` and merge them into one state dict.

    Every directory entry whose name starts with ``mp_rank_`` is considered, in sorted
    name order: an entry ending in ``.pt`` is loaded directly, any other entry is treated
    as a folder and ``<entry>/model_rng.pt`` is loaded when that file exists.

    With a single shard the loaded dict is returned untouched. Otherwise the first shard
    is updated in place with the merged tensors and returned:

    * word embeddings (and the output layer, when present) are concatenated along dim 0;
    * encoder parameters are concatenated along dim 0 or dim 1, split-then-concatenated,
      or copied/averaged, depending on which of the name groups below they belong to;
    * ``param_shapes`` is replaced by an ``OrderedDict`` of dotted name -> merged shape.

    Args:
        ckpt_folder: Folder containing the ``mp_rank_*`` files/folders.
        save_loaded_ckpt: If truthy, the merged dict is also written to
            ``<ckpt_folder>/mp_rank_00_model_states.pt``.

    Returns:
        The (merged) checkpoint dict.

    Raises:
        FileNotFoundError: If no shard could be loaded from ``ckpt_folder``.
        NotImplementedError: If the checkpoint uses learned position embeddings.
        ValueError: If an encoder parameter has no known merge rule.
    """
    # os.listdir() returns entries in arbitrary order, but the shards must be concatenated in
    # rank order. Sorting by name assumes zero-padded rank suffixes (mp_rank_00, mp_rank_01, ...).
    pt_files = sorted(name for name in os.listdir(ckpt_folder) if name.startswith('mp_rank_'))

    params_cat_dim0 = [
        # for old-version ckpt files: attention
        "attention.query_key_value.weight", "attention.query_key_value.bias",
        # for 8k-version ckpt files: self_attention
        "self_attention.query_key_value.weight", "self_attention.query_key_value.bias",
    ]
    # each param in this group should be split first before concat
    param_split_cat_dim0 = [
        "mlp.dense_h_to_4h.bias", "mlp.dense_h_to_4h.weight",
    ]
    params_cat_dim1 = [
        # for old-version ckpt files: attention
        "attention.dense.weight",
        # for 8k-version ckpt files: self_attention
        "self_attention.dense.weight",
        "mlp.dense_4h_to_h.weight",
    ]
    params_copy = [
        "input_layernorm.weight", "input_layernorm.bias",
        # for old-version ckpt files: attention
        "attention.attention_layernorm.weight",
        "attention.attention_layernorm.bias",
        "attention.dense.bias",
        # for 8k-version ckpt files: self_attention
        "self_attention.attention_layernorm.weight",
        "self_attention.attention_layernorm.bias",
        "self_attention.dense.bias",
        "post_attention_layernorm.weight", "post_attention_layernorm.bias",
        "mlp.dense_layernorm.weight", "mlp.dense_layernorm.bias",
        "mlp.dense_4h_to_h.bias",
        "final_layernorm.weight", "final_layernorm.bias"
    ]
    params_ignore = [
        # for old-version ckpt files: attention
        "attention.rotary_emb.inv_freq",
        # for 8k-version ckpt files: self_attention
        "self_attention.rotary_emb.inv_freq"
    ]

    # NOTE: torch.load unpickles the files; only run this on checkpoints you trust.
    input_state_dicts = []
    for pt_file in pt_files:
        if pt_file.endswith('.pt'):
            print('loading {}. This may take a while...'.format(pt_file))
            pt_path = os.path.join(ckpt_folder, pt_file)
            input_state_dicts.append(torch.load(pt_path, map_location="cpu"))
        else:  # folder
            pt_path = os.path.join(ckpt_folder, pt_file + '/model_rng.pt')
            if os.path.isfile(pt_path):
                print('loading {}. This may take a while...'.format(pt_file + '/model_rng.pt'))
                input_state_dicts.append(torch.load(pt_path, map_location="cpu"))
    # Raise explicitly: an `assert` would be stripped when running with -O.
    if not input_state_dicts:
        raise FileNotFoundError(
            'Cannot load ckpt files. Please make sure {} is not empty'.format(ckpt_folder))
    print('All checkpoints loaded. Merging...')

    base_input_state_dict = input_state_dicts[0]
    if len(input_state_dicts) == 1:
        return base_input_state_dict

    # merge word embeddings.
    module_key = "module" if "module" in base_input_state_dict else "model"
    key_list = [module_key, "language_model", "embedding", "word_embeddings", "weight"]
    word_embeddings = [get_value_from_nested_dict(
        input_state_dict, key_list) for input_state_dict in input_state_dicts]
    word_embeddings = torch.cat(word_embeddings, dim=0)
    update_nested_dict(base_input_state_dict, key_list, word_embeddings)
    param_shapes = [[".".join(key_list), word_embeddings.shape]]

    # merge output embeddings if possible (the key is absent when there is no separate output layer)
    try:
        key_list = [module_key, "language_model", "output_layer", "weight"]
        lm_heads = [get_value_from_nested_dict(
            input_state_dict, key_list) for input_state_dict in input_state_dicts]
        lm_heads = torch.cat(lm_heads, dim=0)
        update_nested_dict(base_input_state_dict, key_list, lm_heads)
        # Append, so the word-embedding entry recorded above is not lost.
        param_shapes.append([".".join(key_list), lm_heads.shape])
    except KeyError:
        pass

    # The position embeddings.
    if getattr(base_input_state_dict['args'], "pos_emb", "learned") == "learned":
        raise NotImplementedError('I do not know how position_embeddings would be combined from files.')

    # The transformer.
    key_list = [module_key, "language_model", "encoder"]
    encoder_keys = list(base_input_state_dict[module_key]["language_model"]["encoder"].keys())
    for key in encoder_keys:
        full_key_list = key_list + [key]
        tensors = [get_value_from_nested_dict(
            input_state_dict, full_key_list) for input_state_dict in input_state_dicts]
        if key.startswith('layers'):  # layers.x.y.weight, layers.x.y.bias
            param_type = ".".join(key.split(".")[2:])
            if param_type in params_cat_dim0:
                merged = torch.cat(tensors, dim=0)
            elif param_type in param_split_cat_dim0:
                # Each shard is cut in two along dim 0; all first halves are laid out
                # before all second halves in the merged tensor.
                halves = [tensor.chunk(2, dim=0) for tensor in tensors]
                merged = torch.cat([pair[0] for pair in halves] + [pair[1] for pair in halves], dim=0)
            elif param_type in params_cat_dim1:
                merged = torch.cat(tensors, dim=1)
            elif param_type in params_copy:
                # Replicated parameters: keep the first copy when every shard agrees,
                # otherwise fall back to the element-wise mean across shards.
                if all(torch.equal(tensors[0], other) for other in tensors[1:]):
                    merged = tensors[0]
                else:
                    merged = torch.mean(torch.stack(tensors), dim=0)
            elif param_type in params_ignore:
                continue
            else:
                # Without a merge rule the shards cannot be combined correctly.
                raise ValueError("Unexpected key: {}".format(key))
        elif key.startswith('final_layernorm'):
            if key in params_copy:
                merged = torch.mean(torch.stack(tensors), dim=0)
            else:
                raise ValueError("Unexpected key: {}".format(key))
        else:
            raise ValueError("Unexpected key: {}".format(key))

        update_nested_dict(base_input_state_dict, full_key_list, merged)
        param_shapes.append([".".join(full_key_list), merged.shape])

    # update input_state_dict["param_shapes"]
    base_input_state_dict["param_shapes"] = OrderedDict(param_shapes)

    if save_loaded_ckpt:
        # NOTE(review): this file name also matches the `mp_rank_*` filter above, so a later run
        # on the same folder would load it as an additional shard - confirm this is intended.
        save_path = os.path.join(ckpt_folder, "mp_rank_00_model_states.pt")
        # torch.save takes the object first and the destination second.
        torch.save(base_input_state_dict, save_path)

    return base_input_state_dict

####################################################################################################


def convert_megatron_checkpoint(args, input_state_dict, config):
    """Convert a Megatron-style checkpoint dict into a HuggingFace-style state dict.

    When the checkpoint carries its training arguments under ``"args"``, the model
    dimensions and a few QWen-specific switches are copied onto ``config`` (which is
    modified in place). All converted tensors are stored as ``torch.float16``.

    Args:
        args: Parsed command-line arguments. Unused; kept for interface compatibility.
        input_state_dict: Checkpoint dict with the weights under ``"model"`` or ``"module"``.
        config: Model config object; read for sizes and updated from the checkpoint args.

    Returns:
        A dict mapping ``transformer.*`` / ``lm_head.weight`` names to tensors.

    Raises:
        KeyError: If neither ``model`` nor ``module`` is present, or a layer weight/bias
            has an operation name with no entry in the name map.
        ValueError: If the position-embedding length or the number of layers found does
            not match ``config``.
    """
    # The converted output model.
    output_state_dict = {}

    # old versions did not store training args
    ds_args = input_state_dict.get("args", None)
    if ds_args is not None:
        # do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint
        # NOTE(review): the checkpoint's activation (ds_args.activation) is not propagated here;
        # config.activation_function is left as supplied by the caller - confirm that is intended.
        config.vocab_size = ds_args.padded_vocab_size
        config.n_positions = ds_args.max_position_embeddings
        config.n_embd = ds_args.hidden_size
        config.n_layer = ds_args.num_layers
        config.n_head = ds_args.num_attention_heads
        config.n_inner = ds_args.ffn_hidden_size  # decided by activation_function
        config.n_inner_factor = ds_args.make_ffn_dim_multiple_of
        config.attn_pdrop = ds_args.attention_dropout
        config.embd_pdrop = ds_args.hidden_dropout
        config.resid_pdrop = ds_args.hidden_dropout
        config.scale_attn_by_inverse_layer_idx = ds_args.apply_query_key_layer_scaling
        config.tie_word_embeddings = not getattr(ds_args, "untie_word_embeddings_and_output_weights", False)
        # special args of qwen
        config.pos_emb = ds_args.pos_emb
        config.skip_fc_bias = not getattr(ds_args, "bias", False)
        config.skip_qkv_bias = not getattr(ds_args, "qkv_bias", True)
        # This used to target an undefined name (`obj`) and raised NameError.
        config.use_rmsnorm = getattr(ds_args, "use_rmsnorm", True)

    # The number of heads.
    heads = config.n_head
    # The hidden_size per head.
    hidden_size_per_head = config.n_embd // config.n_head
    # Megatron-LM checkpoint version (0.0 when the checkpoint does not record one).
    checkpoint_version = input_state_dict.get("checkpoint_version", 0.0)

    # The model. Raise explicitly: an `assert` would be stripped when running with -O.
    model = input_state_dict.get("model", None) or input_state_dict.get("module", None)
    if not model:
        raise KeyError("input_state_dict does not have key `model` or `module`.")
    # The language model.
    lm = model["language_model"]
    # The embeddings.
    embeddings = lm["embedding"]

    # The word embeddings, truncated to vocab_size rows.
    word_embeddings = embeddings["word_embeddings"]["weight"]
    word_embeddings = word_embeddings[: config.vocab_size, :]
    output_state_dict["transformer.wte.weight"] = word_embeddings.to(torch.float16).clone()

    # The position embeddings (only present for learned positions).
    if getattr(config, "pos_emb", "learned") == "learned":
        pos_embeddings = embeddings["position_embeddings"]["weight"]
        # Read the causal mask dimension (seqlen). [max_sequence_length, hidden_size]
        n_positions = pos_embeddings.size(0)
        if n_positions != config.n_positions:
            raise ValueError(
                f"pos_embeddings.max_sequence_length={n_positions} and config.n_positions={config.n_positions} don't match"
            )
        # Store the position embeddings.
        output_state_dict["transformer.wpe.weight"] = pos_embeddings.to(torch.float16).clone()

    # The transformer.
    transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]

    # The regex to extract layer names (raw string, so the backslashes reach `re` untouched).
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # The simple map of names for "automated" rules.
    megatron_to_transformers = {
        "attention.dense": ".attn.c_proj.",
        "self_attention.dense": ".attn.c_proj.",
        "mlp.dense_h_to_4h": ".mlp.c_fc.",
        "mlp.dense_4h_to_h": ".mlp.c_proj.",
    }

    # for 3b and 13b checkpoints, need to combine w1 and w2 into dense_h_to_4h
    if "layers.0.mlp.w1.weight" in transformer:
        for layer_id in range(config.n_layer):
            w1 = transformer.pop("layers.{}.mlp.w1.weight".format(layer_id))
            w2 = transformer.pop("layers.{}.mlp.w2.weight".format(layer_id))
            b1 = transformer.pop("layers.{}.mlp.w1.bias".format(layer_id))
            b2 = transformer.pop("layers.{}.mlp.w2.bias".format(layer_id))
            transformer["layers.{}.mlp.dense_h_to_4h.weight".format(layer_id)] = torch.cat((w1, w2), dim=0)
            transformer["layers.{}.mlp.dense_h_to_4h.bias".format(layer_id)] = torch.cat((b1, b2), dim=0)

    # Extract the layers. Track the highest layer index seen so the layer-count check
    # below does not depend on the iteration order of the dict.
    max_layer_idx = -1
    for key, val in transformer.items():
        # Match the name; skip anything that is not a per-layer entry.
        m = layer_re.match(key)
        if m is None:
            continue
        val = val.to(torch.float16)

        # The index of the layer.
        layer_idx = int(m.group(1))
        max_layer_idx = max(max_layer_idx, layer_idx)
        # The name of the operation.
        op_name = m.group(2)
        # Is it a weight or a bias?
        weight_or_bias = m.group(3)

        # The name of the layer.
        layer_name = f"transformer.h.{layer_idx}"

        # For layernorm(s), simply store the layer norm.
        if op_name.endswith("layernorm"):
            if op_name.startswith("input"):
                ln_name = "ln_1"
            elif op_name.startswith(("attention", "self_attention")):
                # Both the old ("attention.*") and the 8k ("self_attention.*") naming.
                ln_name = "attn.sub_ln"
            elif op_name.startswith("mlp"):
                ln_name = "mlp.sub_ln"
            elif op_name.startswith("post_attention"):
                ln_name = "ln_2"
            else:
                # Skip instead of reusing a target name left over from a previous iteration.
                print("Unmapped LN layer: {}".format(op_name))
                continue
            new_name = layer_name + "." + ln_name + "." + weight_or_bias
            assert new_name not in output_state_dict
            output_state_dict[new_name] = val.clone()

        # Transpose the QKV matrix.
        elif (
            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
        ) and weight_or_bias == "weight":
            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
            # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D.
            out_val = out_val.transpose(0, 1).contiguous().clone()
            # Store.
            new_name = layer_name + ".attn.c_attn.weight"
            assert new_name not in output_state_dict
            output_state_dict[new_name] = out_val.clone()

        # Reorder the QKV bias.
        elif (
            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
        ) and weight_or_bias == "bias":
            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
            # Store. No change of shape.
            new_name = layer_name + ".attn.c_attn.bias"
            assert new_name not in output_state_dict
            output_state_dict[new_name] = out_val.clone()

        # Transpose the weights.
        elif weight_or_bias == "weight":
            out_name = megatron_to_transformers[op_name]
            new_name = layer_name + out_name + "weight"
            assert new_name not in output_state_dict
            output_state_dict[new_name] = val.transpose(0, 1).clone()

        # Copy the bias.
        elif weight_or_bias == "bias":
            out_name = megatron_to_transformers[op_name]
            new_name = layer_name + out_name + "bias"
            assert new_name not in output_state_dict
            output_state_dict[new_name] = val.clone()

        elif op_name.endswith("rotary_emb"):
            continue  # we calculate rotary_emb.inv_freq during model initialization
        else:
            print("Parameter unmapped: {}".format(op_name))

    # Sanity check: the checkpoint must contain exactly config.n_layer layers.
    if config.n_layer != max_layer_idx + 1:
        raise ValueError(
            f"config.n_layer={config.n_layer} but the checkpoint contains {max_layer_idx + 1} layer(s)"
        )

    # The final layernorm. The bias is copied only when the checkpoint has one,
    # mirroring how per-layer norms are copied entry by entry above.
    output_state_dict["transformer.ln_f.weight"] = transformer["final_layernorm.weight"].to(torch.float16).clone()
    if "final_layernorm.bias" in transformer:
        output_state_dict["transformer.ln_f.bias"] = transformer["final_layernorm.bias"].to(torch.float16).clone()

    # For LM head, transformers' may tie the matrix to weight embeddings.
    if config.tie_word_embeddings:
        output_state_dict["lm_head.weight"] = output_state_dict["transformer.wte.weight"]
    else:
        # Treat the untied head like the word embeddings: same truncation, same dtype.
        lm_head = lm['output_layer']['weight'][: config.vocab_size, :]
        output_state_dict["lm_head.weight"] = lm_head.to(torch.float16).clone()

    return output_state_dict


####################################################################################################


def main():
    """Command-line entry point: convert a Megatron checkpoint and write it to disk.

    Loads the checkpoint given by ``--path_to_checkpoint`` (a ``.zip`` archive, a single
    ``.pt`` file, or a folder of model-parallel shards), converts it with
    ``convert_megatron_checkpoint`` and writes the config, the tokenizer files and the
    sharded weights (plus the shard index, when there is one) into the output folder.

    Raises:
        ValueError: If ``--path_to_checkpoint`` is none of the supported kinds.
    """
    # Create the argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument("--print-checkpoint-structure", action="store_true")
    parser.add_argument(
        "--path_to_checkpoint",
        type=str,
        # Required: without it the path handling below fails on None.
        required=True,
        help="Path to the checkpoint file (.zip archive, direct .pt file, or folder containing multiple .pt files.)",
    )
    parser.add_argument(
        "--config_file",
        default="",
        type=str,
        help="An optional config json file describing the pre-trained model.",
    )
    parser.add_argument(
        "--output_folder",
        default="",
        type=str,
        help="Folder to write the converted model to. Defaults to the directory of --path_to_checkpoint.",
    )
    parser.add_argument(
        "--vocab_file",
        default="/mnt/workspace/copilot/gpt2-zhcn3-v4.json",
        type=str,
        help="Path to the tokenizer vocabulary file.",
    )
    parser.add_argument(
        "--merge_file",
        default="/mnt/workspace/copilot/gpt2-zhcn3-v4.bpe",
        type=str,
        help="Path to the tokenizer merges file.",
    )
    args = parser.parse_args()

    # Default to the directory of the checkpoint path; fall back to the current
    # directory when that is empty (e.g. a bare file name was given).
    output_folder = args.output_folder or os.path.dirname(args.path_to_checkpoint) or "."
    os.makedirs(output_folder, exist_ok=True)

    # Load the model.
    # the .zip is very optional, let's keep it for backward compatibility
    print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
    if args.path_to_checkpoint.endswith(".zip"):
        with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
            with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
                input_state_dict = torch.load(pytorch_dict, map_location="cpu")
    elif args.path_to_checkpoint.endswith(".pt"):
        input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
    elif os.path.isdir(args.path_to_checkpoint):
        input_state_dict = load_checkpoints_from_multi_pt(args.path_to_checkpoint)
    else:
        raise ValueError("Cannot parse {}.".format(args.path_to_checkpoint))

    # Read the config, or fall back to the built-in defaults below.
    if not args.config_file:
        # Spell out all parameters in case the defaults change.
        config = QWenConfig(
            activation_function="geglu",
            attn_pdrop=0.1,
            bos_token_id=50256,
            embd_pdrop=0.1,
            eos_token_id=50256,
            initializer_range=0.02,
            layer_norm_epsilon=1e-5,
            n_embd=2560,
            n_head=32,
            n_inner=None,
            n_layer=32,
            n_positions=2048,
            pos_emb="rotary",
            reorder_and_upcast_attn=False,
            resid_pdrop=0.1,
            scale_attn_by_inverse_layer_idx=False,
            scale_attn_weights=True,
            sub_ln=True,
            skip_fc_bias=True,
            use_cache=True,
            vocab_size=65408,
        )
    else:
        config = QWenConfig.from_json_file(args.config_file)

    config.architectures = ["QWenLMHeadModel"]

    # Convert.
    print("Converting")
    output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)

    # Print the structure of converted state dict.
    if args.print_checkpoint_structure:
        recursive_print(None, output_state_dict)

    # Add tokenizer class info to config
    tokenizer = QWenTokenizer(args.vocab_file, args.merge_file, errors='replace')
    tokenizer_class = type(tokenizer).__name__
    config.tokenizer_class = tokenizer_class

    # Store the config to file.
    print("Saving config")
    config.save_pretrained(output_folder)

    # Save tokenizer based on args
    print(f"Adding {tokenizer_class} tokenizer files")
    tokenizer.save_pretrained(output_folder)

    # Store the state_dict to file, sharded so that each checkpoint file is not too large.
    shards, index = shard_checkpoint(output_state_dict, max_shard_size="10GB")
    for key, shard in shards.items():
        output_checkpoint_file = os.path.join(output_folder, key)
        print(f'Saving checkpoint to "{output_checkpoint_file}"')
        torch.save(shard, output_checkpoint_file)
    if index:
        output_meta_file = os.path.join(output_folder, "pytorch_model.bin.index.json")
        with open(output_meta_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(index, indent=2))

    print('Saving complete.')

####################################################################################################

# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

####################################################################################################
