import argparse
import glob
import os
import re
import shutil
import sys
import time

import torch
import torch.distributed
from tqdm import tqdm

from megatron.core import mpu
from megatron.training.arguments import parse_args, validate_args
from megatron.training.global_vars import set_global_variables
from megatron.training.checkpointing import load_checkpoint, save_checkpoint
from megatron.training.initialize import _set_random_seed, _initialize_distributed
from megatron.core.enums import ModelType
from tasks.idefics2_vision_model.train_idefics2 import model_provider
from tools.px_ckpt_conv import px_ckpt_conv
from gpatch.core.device_type import is_wxacc2


def create_mlm_model(run_args, only_create=False):
    """Build the Megatron (MLM) Idefics2 vision model and optionally load a checkpoint.

    Megatron is configured by overwriting ``sys.argv`` (a process-wide side
    effect) and calling ``parse_args``. Model-shape fields are then overridden
    from the HF ``config.json`` in ``run_args.hf_load_dir``. Distributed state
    is initialised with tensor/pipeline parallel size 1.

    Args:
        run_args: converter arguments. Reads ``hf_load_dir``, ``bf16``,
            ``fp16``, ``megatron_load_dir``, ``megatron_save_dir`` and
            ``dist_ckpt_format``.
        only_create: if True, return the freshly built model without loading
            weights. Otherwise load the checkpoint at
            ``run_args.megatron_load_dir``. That path must name a concrete
            checkpoint directory: either ``release`` or a name containing the
            iteration number (e.g. ``iter_0000100``). A trailing path
            separator is accepted.

    Returns:
        ``(model, hf_config)``.

    Raises:
        ValueError: if the checkpoint directory name is neither ``release`` nor
            contains an iteration number.
    """
    print('creating MLM model...')
    t0 = time.time()
    from vision_model import Idefics2VisionConfig
    config_path = os.path.join(run_args.hf_load_dir, 'config.json')
    hf_config = Idefics2VisionConfig.from_pretrained(config_path)
    print(hf_config)

    # Megatron's parse_args() reads sys.argv, so replace it with a fixed
    # single-rank configuration; shape fields are overridden from hf_config below.
    sys.argv = [
        'placeholder.py',
        '--no-load-optim',
        '--no-load-rng',
        '--no-save-optim',
        '--no-save-rng',
        '--no-initialization',
        '--use-cpu-initialization',
        '--micro-batch-size',
        '1',
        '--save-interval',
        '1',
        '--tensor-model-parallel-size',
        '1',
        '--pipeline-model-parallel-size',
        '1',
        '--normalization',
        'LayerNorm',
        '--no-masked-softmax-fusion',
        '--attention-dropout',
        '0',
        '--hidden-dropout',
        '0',
        '--position-embedding-type',
        'rope',
    ]
    args = parse_args()
    args.bf16 = run_args.bf16
    args.fp16 = run_args.fp16
    args.clip_grad = 1.0
    args.seed = 42
    args.model_type = ModelType.encoder_or_decoder
    args.load = run_args.megatron_load_dir
    args.save = run_args.megatron_save_dir

    args.num_layers = hf_config.num_hidden_layers
    args.hidden_size = hf_config.hidden_size
    args.num_attention_heads = hf_config.num_attention_heads
    args.ffn_hidden_size = hf_config.intermediate_size
    args.max_position_embeddings = 4096
    args.encoder_seq_length = 4096
    args.tokenizer_type = 'HfAutoTokenizer'

    args.use_mcore_models = True
    args.use_dist_ckpt = True
    args.dist_ckpt_format = run_args.dist_ckpt_format
    args.world_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size
    args = validate_args(args)
    set_global_variables(args, build_tokenizer=False)
    mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    mpu.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)
    mpu.set_virtual_pipeline_model_parallel_world_size(args.virtual_pipeline_model_parallel_size)

    _initialize_distributed()
    _set_random_seed(args.seed)
    print(f"params_dtype: {args.params_dtype}")
    model = model_provider(True, True).to(args.params_dtype)
    print(f"model: {model}")

    if not only_create:
        # Normalise first: with a trailing separator basename() is '' and the
        # old ``args.load[0:-len(base_name)]`` slice emptied args.load entirely.
        ckpt_dir = os.path.normpath(args.load)
        base_name = os.path.basename(ckpt_dir)
        # The checkpoint is looked up under the parent directory, so strip the
        # last path component from args.load.
        args.load = os.path.dirname(ckpt_dir)
        if base_name == "release":
            iteration, _ = load_checkpoint([model], None, None)
        else:
            digits = re.findall(r'\d+', base_name)
            if not digits:
                raise ValueError(
                    f"cannot infer checkpoint iteration from directory name {base_name!r}; "
                    "expected 'release' or a name containing the iteration number "
                    "(e.g. 'iter_0000100')"
                )
            iteration, _ = load_checkpoint(
                [model], None, None, checkpoint_step=int(digits[0])
            )
        print(f"Loaded ckpt from iteration {iteration}")

    print(f'MLM model created in {time.time() - t0:.2f}s')
    return model, hf_config


def create_hf_model(run_args, only_create=False):
    """Instantiate the HF Idefics2 vision transformer in the requested dtype.

    Args:
        run_args: reads ``bf16``, ``fp16`` and ``hf_load_dir``.
        only_create: if True, construct the model from ``config.json`` without
            loading any weights. Otherwise load pretrained weights with
            ``from_pretrained(run_args.hf_load_dir)``.

    Returns:
        The model cast to float16 if ``fp16`` is set, else bfloat16 if ``bf16``
        is set, else float32. fp16 takes precedence when both flags are set.
    """
    print('creating HF model...')
    start = time.time()

    if run_args.fp16:
        torch_dtype = torch.float16
    elif run_args.bf16:
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float

    from vision_model import Idefics2VisionTransformer
    if only_create:
        from vision_model import Idefics2VisionConfig
        config = Idefics2VisionConfig.from_pretrained(
            os.path.join(run_args.hf_load_dir, 'config.json')
        )
        model = Idefics2VisionTransformer(config=config)
    else:
        model = Idefics2VisionTransformer.from_pretrained(run_args.hf_load_dir)
    hf_model = model.to(torch_dtype)

    elapsed = time.time() - start
    print(f'HF model created t1 - t0 {elapsed}')
    return hf_model


def set_lm2hf_attn_state(run_args, model_config, lm_layer, hf_layer):
    """Split Megatron's fused QKV projection into HF's separate q/k/v projections.

    The fused ``qkv_proj`` weight is viewed as
    ``(groups, q_rows + 2 * head_dim, hidden)``. Each group holds its query
    rows followed by one key head and one value head. The number of groups is
    set equal to the number of attention heads, so the layout is
    ``[q_h | k_h | v_h]`` per head. The fused bias uses the same layout.
    ``out_proj`` weight and bias are copied as-is.

    Args:
        run_args: unused; kept so all state-copy helpers share a signature.
        model_config: reads ``num_attention_heads`` and ``hidden_size``.
        lm_layer: source Megatron layer, read through ``.self_attn``.
        hf_layer: destination HF layer, written in place through ``.self_attn``.
    """
    src = lm_layer.self_attn
    dst = hf_layer.self_attn

    n_heads = model_config.num_attention_heads
    n_groups = n_heads  # one key/value head per query head
    assert model_config.hidden_size % n_heads == 0
    assert n_heads % n_groups == 0
    head_dim = model_config.hidden_size // n_heads
    hidden = model_config.hidden_size
    q_rows = head_dim * n_heads // n_groups  # query rows per group
    group_rows = q_rows + 2 * head_dim  # q + k + v rows per group

    def split_qkv(fused):
        # Slice a (groups, group_rows, ...) tensor along dim 1 into q, k, v.
        return (
            fused.narrow(1, 0, q_rows),
            fused.narrow(1, q_rows, head_dim),
            fused.narrow(1, q_rows + head_dim, head_dim),
        )

    fused_weight = src.qkv_proj.weight.reshape((n_groups, group_rows, -1))
    assert fused_weight.size(2) == hidden
    q_w, k_w, v_w = split_qkv(fused_weight)
    dst.q_proj.weight.copy_(q_w.reshape((head_dim * n_heads, -1)))
    dst.k_proj.weight.copy_(k_w.reshape((head_dim * n_groups, -1)))
    dst.v_proj.weight.copy_(v_w.reshape((head_dim * n_groups, -1)))

    # Biases are written through ``.data``.
    q_b, k_b, v_b = split_qkv(src.qkv_proj.bias.reshape((n_groups, group_rows)))
    dst.q_proj.bias.data.copy_(q_b.reshape(head_dim * n_heads))
    dst.k_proj.bias.data.copy_(k_b.reshape(head_dim * n_groups))
    dst.v_proj.bias.data.copy_(v_b.reshape(head_dim * n_groups))

    dst.out_proj.weight.copy_(src.out_proj.weight)
    dst.out_proj.bias.copy_(src.out_proj.bias)


def set_hf2lm_attn_state(run_args, model_config, lm_layer, hf_layer):
    """Fuse HF's separate q/k/v projections into Megatron's single ``qkv_proj``.

    For each group, the rows are stacked as ``[q rows | k rows | v rows]`` and
    the groups are then flattened back to ``(-1, hidden)``. The number of
    groups equals the number of heads. Biases are fused the same way, and
    ``out_proj`` is copied as-is.

    Args:
        run_args: unused; kept so all state-copy helpers share a signature.
        model_config: reads ``num_attention_heads`` and ``hidden_size``.
        lm_layer: destination Megatron layer, written in place via ``.self_attn``.
        hf_layer: source HF layer, read via ``.self_attn``.
    """
    dst = lm_layer.self_attn
    src = hf_layer.self_attn

    n_heads = model_config.num_attention_heads
    n_groups = n_heads  # one key/value head per query head
    assert model_config.hidden_size % n_heads == 0
    assert n_heads % n_groups == 0
    head_dim = model_config.hidden_size // n_heads
    hidden = model_config.hidden_size
    q_rows = head_dim * n_heads // n_groups  # query rows per group

    def interleave(q, k, v, *tail):
        # ``tail`` is (hidden,) for weight matrices and () for bias vectors.
        per_group = [
            q.reshape((n_groups, q_rows, *tail)),
            k.reshape((n_groups, head_dim, *tail)),
            v.reshape((n_groups, head_dim, *tail)),
        ]
        return torch.cat(per_group, dim=1).reshape(-1, *tail)

    dst.qkv_proj.weight.copy_(
        interleave(src.q_proj.weight, src.k_proj.weight, src.v_proj.weight, hidden)
    )
    dst.qkv_proj.bias.copy_(
        interleave(src.q_proj.bias, src.k_proj.bias, src.v_proj.bias)
    )
    dst.out_proj.weight.copy_(src.out_proj.weight)
    dst.out_proj.bias.copy_(src.out_proj.bias)


def set_lm2hf_mlp_state(lm_layer, hf_layer):
    """Copy fc1/fc2 weight and bias from the Megatron MLP into the HF MLP in place."""
    src_mlp, dst_mlp = lm_layer.mlp, hf_layer.mlp
    # Order: fc1.weight, fc1.bias, fc2.weight, fc2.bias.
    for proj_name in ("fc1", "fc2"):
        src_proj = getattr(src_mlp, proj_name)
        dst_proj = getattr(dst_mlp, proj_name)
        for param_name in ("weight", "bias"):
            getattr(dst_proj, param_name).copy_(getattr(src_proj, param_name))


def set_hf2lm_mlp_state(lm_layer, hf_layer):
    """Copy fc1/fc2 weight and bias from the HF MLP into the Megatron MLP in place."""
    src_mlp, dst_mlp = hf_layer.mlp, lm_layer.mlp
    # Order: fc1.weight, fc1.bias, fc2.weight, fc2.bias.
    for proj_name in ("fc1", "fc2"):
        src_proj = getattr(src_mlp, proj_name)
        dst_proj = getattr(dst_mlp, proj_name)
        for param_name in ("weight", "bias"):
            getattr(dst_proj, param_name).copy_(getattr(src_proj, param_name))


                                              
def convert_mlm_to_hf(run_args, model_config, mlm_model, hf_model, save_ckpt=True):
    """Copy all parameters of a Megatron Idefics2 vision model into an HF model.

    Parameters are written with ``copy_``, so this must run with autograd
    disabled (the ``__main__`` entry point calls ``torch.set_grad_enabled(False)``).

    Args:
        run_args: converter arguments. ``hf_load_dir`` and ``hf_save_dir`` are
            read when ``save_ckpt`` is true.
        model_config: HF vision config. ``num_hidden_layers`` selects how many
            encoder layers are copied.
        mlm_model: source Megatron model.
        hf_model: destination HF model, modified in place.
        save_ckpt: if False, return right after copying parameters without
            writing anything.
    """
    t0 = time.time()

    print('copying parameters...')
    # Patch and position embeddings.
    hf_model.embeddings.patch_embedding.weight.copy_(mlm_model.embeddings.patch_embedding.weight)
    hf_model.embeddings.patch_embedding.bias.copy_(mlm_model.embeddings.patch_embedding.bias)
    hf_model.embeddings.position_embedding.weight.copy_(
        mlm_model.embeddings.position_embedding.weight
    )

    num_layers = model_config.num_hidden_layers
    for layer_idx in tqdm(range(num_layers), "encoder layer states"):
        hf_layer = hf_model.encoder.layers[layer_idx]
        lm_layer = mlm_model.encoder.layers[layer_idx]
        set_lm2hf_attn_state(run_args, model_config, lm_layer, hf_layer)
        if is_wxacc2():
            # On this device type the Megatron layer keeps standalone LayerNorms.
            hf_layer.layer_norm1.weight.copy_(lm_layer.layer_norm1.weight)
            hf_layer.layer_norm1.bias.copy_(lm_layer.layer_norm1.bias)
            hf_layer.layer_norm2.weight.copy_(lm_layer.layer_norm2.weight)
            hf_layer.layer_norm2.bias.copy_(lm_layer.layer_norm2.bias)
        else:
            # Otherwise the LayerNorm parameters live on the following linear
            # (qkv_proj for layer_norm1, mlp.fc1 for layer_norm2).
            hf_layer.layer_norm1.weight.copy_(lm_layer.self_attn.qkv_proj.layer_norm_weight)
            hf_layer.layer_norm1.bias.copy_(lm_layer.self_attn.qkv_proj.layer_norm_bias)
            hf_layer.layer_norm2.weight.copy_(lm_layer.mlp.fc1.layer_norm_weight)
            hf_layer.layer_norm2.bias.copy_(lm_layer.mlp.fc1.layer_norm_bias)
        set_lm2hf_mlp_state(lm_layer, hf_layer)

    # Optional final LayerNorm.
    if hasattr(mlm_model, "post_layernorm") and mlm_model.post_layernorm is not None:
        hf_model.post_layernorm.weight.copy_(mlm_model.post_layernorm.weight)
        hf_model.post_layernorm.bias.copy_(mlm_model.post_layernorm.bias)

    if not save_ckpt:
        return

    t1 = time.time()
    print('HF model saving pretrained...')
    hf_model.save_pretrained(run_args.hf_save_dir, safe_serialization=False)
    _finalize_hf_save_dir(run_args)

    t2 = time.time()
    print(f'converted MLM ckpt to HF ckpt successfully copy:{t1 - t0}s save:{t2 - t1}s')


def _finalize_hf_save_dir(run_args):
    """Copy config/*.py from the source HF dir and rename the saved weight file.

    Uses shutil/os instead of shell commands, so paths with spaces or shell
    metacharacters work and failures raise instead of being silently ignored.
    """
    src_dir = run_args.hf_load_dir
    dst_dir = run_args.hf_save_dir
    # Overwrites any config.json written by save_pretrained with the source one.
    shutil.copy(os.path.join(src_dir, 'config.json'), dst_dir)
    # Escape src_dir so glob metacharacters in the path itself are literal.
    for py_file in glob.glob(os.path.join(glob.escape(src_dir), '*.py')):
        shutil.copy(py_file, dst_dir)
    # Rename the weight file written by save_pretrained(safe_serialization=False).
    os.replace(
        os.path.join(dst_dir, 'pytorch_model.bin'),
        os.path.join(dst_dir, 'vision_model.pth'),
    )


def convert_hf_to_mlm(run_args, model_config, hf_model, mlm_model, with_save=True):
    """Copy all parameters of an HF Idefics2 vision model into a Megatron model.

    Parameters are written with ``copy_``, so this must run with autograd
    disabled (the ``__main__`` entry point calls ``torch.set_grad_enabled(False)``).

    Args:
        run_args: converter arguments. ``megatron_save_dir`` is read when
            ``with_save`` is true.
        model_config: HF vision config. ``num_hidden_layers`` selects how many
            encoder layers are copied.
        hf_model: source HF model.
        mlm_model: destination Megatron model, modified in place and saved.
        with_save: if False, return right after copying parameters.
    """
    # Patch and position embeddings.
    mlm_model.embeddings.patch_embedding.weight.copy_(hf_model.embeddings.patch_embedding.weight)
    mlm_model.embeddings.patch_embedding.bias.copy_(hf_model.embeddings.patch_embedding.bias)
    mlm_model.embeddings.position_embedding.weight.copy_(
        hf_model.embeddings.position_embedding.weight
    )

    num_layers = model_config.num_hidden_layers
    for layer_idx in tqdm(range(num_layers), "encoder layer states"):
        hf_layer = hf_model.encoder.layers[layer_idx]
        lm_layer = mlm_model.encoder.layers[layer_idx]
        set_hf2lm_attn_state(run_args, model_config, lm_layer, hf_layer)
        if is_wxacc2():
            # On this device type the Megatron layer keeps standalone LayerNorms.
            lm_layer.layer_norm1.weight.copy_(hf_layer.layer_norm1.weight)
            lm_layer.layer_norm1.bias.copy_(hf_layer.layer_norm1.bias)
            lm_layer.layer_norm2.weight.copy_(hf_layer.layer_norm2.weight)
            lm_layer.layer_norm2.bias.copy_(hf_layer.layer_norm2.bias)
        else:
            # Otherwise the LayerNorm parameters live on the following linear
            # (qkv_proj for layer_norm1, mlp.fc1 for layer_norm2).
            lm_layer.self_attn.qkv_proj.layer_norm_weight.copy_(hf_layer.layer_norm1.weight)
            lm_layer.self_attn.qkv_proj.layer_norm_bias.copy_(hf_layer.layer_norm1.bias)
            lm_layer.mlp.fc1.layer_norm_weight.copy_(hf_layer.layer_norm2.weight)
            lm_layer.mlp.fc1.layer_norm_bias.copy_(hf_layer.layer_norm2.bias)
        set_hf2lm_mlp_state(lm_layer, hf_layer)

    # Optional final LayerNorm.
    if hasattr(mlm_model, "post_layernorm") and mlm_model.post_layernorm is not None:
        mlm_model.post_layernorm.weight.copy_(hf_model.post_layernorm.weight)
        mlm_model.post_layernorm.bias.copy_(hf_model.post_layernorm.bias)

    if not with_save:
        return

    # Save the model that was passed in. Previously this read a module-level
    # ``lm_model`` global, which raised NameError when called from elsewhere.
    save_checkpoint(1, [mlm_model], None, None, num_floating_point_operations_so_far=0)

    # Rename the iteration-1 checkpoint to "release" and point the tracker file at it.
    if torch.distributed.get_rank() == 0:
        iter_dir = os.path.join(run_args.megatron_save_dir, "iter_0000001")
        release_dir = os.path.join(run_args.megatron_save_dir, "release")
        tracker_file = os.path.join(
            run_args.megatron_save_dir, "latest_checkpointed_iteration.txt"
        )
        os.rename(iter_dir, release_dir)
        with open(tracker_file, 'w') as f:
            f.write('release')
    print("successfully convert hf ckpt to megatron ckpt")


if __name__ == "__main__":
    run_args = px_ckpt_conv.get_run_args()
    # Parameters are overwritten in place with copy_, which autograd rejects
    # on leaf tensors that require grad, so disable grad tracking globally.
    torch.set_grad_enabled(False)

    # NOTE(review): keep the global name ``lm_model``. convert_hf_to_mlm has
    # referenced this module-level name rather than its argument.
    convert_way = run_args.convert_way
    if convert_way == "hf_to_mlm":
        # Weights come from HF; the Megatron model is only constructed.
        lm_model, model_config = create_mlm_model(run_args, only_create=True)
        hf_model = create_hf_model(run_args, only_create=False)
        convert_hf_to_mlm(
            run_args=run_args,
            model_config=model_config,
            mlm_model=lm_model,
            hf_model=hf_model,
        )
    elif convert_way == "mlm_to_hf":
        # Weights come from the Megatron checkpoint; the HF model is only constructed.
        lm_model, model_config = create_mlm_model(run_args, only_create=False)
        hf_model = create_hf_model(run_args, only_create=True)
        convert_mlm_to_hf(
            run_args=run_args,
            model_config=model_config,
            mlm_model=lm_model,
            hf_model=hf_model,
        )
    else:
        raise NotImplementedError(f"convert way {convert_way} is not supported")
