# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import argparse
import os
from transformers import PaliGemmaForConditionalGeneration
import torch


def convert(output_path, tensor_parallel_size, use_te):
    device = "cuda"

    model_id = "google/paligemma-3b-pt-448"
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()

    model = model.to(device)

    print(model.config)
    for name, tensor in model.state_dict().items():
        if "vision_model" not in name:
            continue
        shape_str = "(" + ", ".join([str(x) for x in tensor.shape]) + ")"
        print(f"{name:<75} {shape_str:>20}")

    state_dict = model.state_dict()
    new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]

    def add_chunck_tensor(new_tensor, new_name, chunk_dim=None):
        if chunk_dim is None:
            new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
        else:
            new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)

        for i in range(tensor_parallel_size):
            # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage.
            new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()

            # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
            extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
            is_extra_state_layer = any([l in new_name for l in extra_state_layers])
            if use_te and is_extra_state_layer:
                layer = new_name.split(".")[-2]
                if layer in extra_state_layers:
                    extra_state_name = (
                        new_name[: new_name.rfind(".") + 1] + "_extra_state"
                    )  # Replace the weight name.
                    new_state_dicts[i]["model"][extra_state_name] = None

    for name, tensor in state_dict.items():
        if tensor.dtype == torch.float16:
            state_dict[name] = tensor.to(torch.float32)

    add_chunck_tensor(
        state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"],
        "position_embeddings.weight")
    add_chunck_tensor(
        state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"],
        "conv1.weight")
    add_chunck_tensor(
        state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"],
        "conv1.bias")

    head_dim = 72
    num_head = 16
    for layer_idx in range(27):
        origin_base = f"vision_tower.vision_model.encoder.layers.{layer_idx}"
        target_base = f"decoder.layers.{layer_idx}"

        for param_type in ["weight", "bias"]:
            # QKV
            q_proj_params = state_dict[f"{origin_base}.self_attn.q_proj.{param_type}"]
            k_proj_params = state_dict[f"{origin_base}.self_attn.k_proj.{param_type}"]
            v_proj_params = state_dict[f"{origin_base}.self_attn.v_proj.{param_type}"]
            # Do some tensor manipulation because megatron expect one tensor
            # projection for the QKV in the order
            # [(Q1, K1, V1), (Q2, K2, V2), ...] where Qi is the query of the
            # i-th head with dimension num_head.
            new_tensor = torch.concatenate([
                q_proj_params.view(num_head, head_dim, -1),
                k_proj_params.view(num_head, head_dim, -1),
                v_proj_params.view(num_head, head_dim, -1)], axis=1).view(
                    3*head_dim*num_head, -1)
            if param_type == "bias":
                new_tensor = new_tensor[:, 0]
            new_name = f"{target_base}.self_attention.linear_qkv.{param_type}"
            add_chunck_tensor(new_tensor, new_name, chunk_dim=0)
            # linear_proj
            add_chunck_tensor(
                state_dict[f"{origin_base}.self_attn.out_proj.{param_type}"],
                f"{target_base}.self_attention.linear_proj.{param_type}",
                chunk_dim=1 if param_type == "weight" else None)
            # layer_norm
            new_name = f"{target_base}.input_layernorm.{param_type}"
            if use_te:
                new_name = f"{target_base}.self_attention.linear_qkv.layer_norm_{param_type}"
            add_chunck_tensor(
                state_dict[f"{origin_base}.layer_norm1.{param_type}"],
                new_name)
            # FC 1
            add_chunck_tensor(
                state_dict[f"{origin_base}.mlp.fc1.{param_type}"],
                f"{target_base}.mlp.linear_fc1.{param_type}",
                chunk_dim=0)
            # FC 2
            add_chunck_tensor(
                state_dict[f"{origin_base}.mlp.fc2.{param_type}"],
                f"{target_base}.mlp.linear_fc2.{param_type}",
                chunk_dim=1 if param_type=="weight" else None)
            # layer_norm
            new_name = f"{target_base}.pre_mlp_layernorm.{param_type}"
            if use_te:
                new_name = f"{target_base}.mlp.linear_fc1.layer_norm_{param_type}"
            add_chunck_tensor(
                state_dict[f"{origin_base}.layer_norm2.{param_type}"],
                new_name)

    add_chunck_tensor(
        state_dict["vision_tower.vision_model.post_layernorm.weight"],
        "ln_post.weight")
    add_chunck_tensor(
        state_dict["vision_tower.vision_model.post_layernorm.bias"],
        "ln_post.bias")

    for i in range(tensor_parallel_size):
        output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}")
        os.makedirs(output_dir_tp)
        output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
        torch.save(new_state_dicts[i], output_path_tp)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
Convert SigLIP weights to megatron format.


Example usage:
python siglip_converter.py --tensor-parallel-size 4 --output google_paligemma_3b_pt_44_mcore_tp_4 --use-te

examples/multimodal/combine_mistral_clip.sh Mistral-7B-Instruct-v0.3-mcore-tp4 google_paligemma_3b_pt_44_mcore_tp_4 mistral_7b_instruct_v0p3_google_paligemma_3b_pt_44_mcore_tp_4
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--output", type=str, required=True, help="output directory for megatron state dict file(s)"
    )
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=1, help="model tensor parallel size"
    )
    parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine")

    args = parser.parse_args()

    convert(args.output, args.tensor_parallel_size, args.use_te)

    print("done.")
