import json
import os
import re
import shutil
import sys
from pathlib import Path
from typing import Optional
import glob
from safetensors.torch import load_file as load_safetensors_file
import torch
from requests.exceptions import HTTPError

from FlashInfer.utils import ModelArgs

# Directory two levels above this file (the parent of this file's directory;
# presumably the repository root), appended to sys.path so modules there can be
# imported.
# NOTE(review): this runs after `from FlashInfer.utils import ModelArgs` above, so it
# cannot help resolve that import -- if it is meant to, the append must come first.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))


def hf_download(out_dir: str, repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
    """Download a HuggingFace Hub snapshot of ``repo_id`` into ``out_dir``.

    The download is skipped entirely when ``out_dir`` already exists; no check is
    made that an existing directory actually holds a complete snapshot.

    Args:
        out_dir: Local directory to download into; created by this call.
        repo_id: HuggingFace Hub repository id to download.
        hf_token: Optional access token for private checkpoints.

    Raises:
        requests.exceptions.HTTPError: for HTTP failures other than 401 (a 401 only
            prints a hint). In both cases the directory created by this call is
            removed first, so a retry is not mistaken for a finished download.
    """
    if os.path.exists(out_dir):
        print(f"Directory {out_dir} already exists. Skipping download.")
        return

    from huggingface_hub import snapshot_download

    print(f"Downloading {repo_id} to {out_dir} ...")
    os.makedirs(out_dir, exist_ok=True)
    try:
        snapshot_download(repo_id, local_dir=out_dir, local_dir_use_symlinks=False, token=hf_token)
    except HTTPError as e:
        # out_dir was created by this call (the exists-check above returned otherwise).
        # Remove it so the next run does not skip the download as "already present".
        shutil.rmtree(out_dir, ignore_errors=True)
        status = e.response.status_code if e.response is not None else None
        if status == 401:
            print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
        else:
            raise


def cleanup_original_files(checkpoint_dir: Path) -> None:
    """Delete the source weight files in ``checkpoint_dir`` after conversion.

    Non-recursively removes every ``*.bin`` and ``*.safetensors`` file, then every
    ``*.index.json`` file whose name contains ``"model"``. Each removed path is
    printed.
    """
    print("Cleaning up original model files...")

    def _remove_all(paths):
        # Announce and unlink each path in the order given.
        for path in paths:
            print(f"Removing {path}")
            path.unlink()

    _remove_all(list(checkpoint_dir.glob("*.bin")))
    _remove_all(list(checkpoint_dir.glob("*.safetensors")))
    _remove_all([path for path in checkpoint_dir.glob("*.index.json") if "model" in path.name])

    print("Cleanup completed!")


def _locate_weight_files(checkpoint_dir: Path) -> list:
    """Return the weight file(s) to load from ``checkpoint_dir``.

    Candidates are tried in this order, and the first that exists wins:
    ``model.safetensors.index.json``, ``pytorch_model.bin.index.json``,
    ``model.safetensors``, ``pytorch_model.bin``.

    For an index file, the distinct shard files listed in its ``"weight_map"`` are
    returned, sorted. Otherwise a one-element list with the weights file is returned.
    Exits with a non-zero status when no candidate exists.
    """
    candidates = (
        ("model.safetensors.index.json", "safetensors index", True),
        ("pytorch_model.bin.index.json", "pytorch index", True),
        ("model.safetensors", "safetensors weights", False),
        ("pytorch_model.bin", "pytorch weights", False),
    )
    for file_name, label, is_index in candidates:
        path = checkpoint_dir / file_name
        # Explicit check rather than `assert`: asserts are stripped under `python -O`.
        if not path.is_file():
            print(f"{path} not found")
            continue
        print(f"Found {label} at {path}")
        if not is_index:
            return [path]
        with open(path) as json_map:
            bin_index = json.load(json_map)
        return sorted({checkpoint_dir / shard for shard in bin_index["weight_map"].values()})
    # sys.exit(str) prints the message to stderr and exits with status 1.
    sys.exit(f"Can't find any weights or index in {checkpoint_dir}.")


def _load_merged_state_dict(weight_files) -> dict:
    """Load each file in ``weight_files`` onto the CPU and merge them into one dict.

    ``.safetensors`` files are read with safetensors. Any other file is read with
    ``torch.load(..., mmap=True, weights_only=True)``. When keys repeat, later files
    overwrite earlier ones.
    """
    merged = {}
    for weight_file in weight_files:
        # Dispatch on the file extension, not a substring of the whole path: a directory
        # whose name contains "safetensors" must not route .bin shards to that loader.
        if weight_file.suffix == ".safetensors":
            merged.update(load_safetensors_file(str(weight_file), device="cpu"))
        else:
            merged.update(torch.load(str(weight_file), map_location="cpu", mmap=True, weights_only=True))
    return merged


def _build_weight_map(model_name: str) -> dict:
    """Return the HF-tensor-name -> local-tensor-name mapping for ``model_name``.

    ``{}`` is a placeholder for the layer index, and a ``None`` value marks a tensor
    that is dropped. Names containing "qwen2" or "r1-distill-qwen" add q/k/v biases;
    names containing "qwen3" add q/k norm weights.
    """
    weight_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    lowered = model_name.lower()
    if "qwen2" in lowered or "r1-distill-qwen" in lowered:
        weight_map.update({
            "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
            "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
            "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
        })
    if "qwen3" in lowered:
        weight_map.update({
            "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight",
            "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight",
        })
    return weight_map


@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"),
    model_name: Optional[str] = None,
) -> None:
    """Convert a HuggingFace checkpoint in ``checkpoint_dir`` to ``model.pth``.

    Tensors are renamed through the weight map. Each layer's q/k/v tensors are fused
    into one ``wqkv`` tensor, with the q and k rows reordered by ``permute``. The
    result is saved to ``checkpoint_dir / "model.pth"``. When ``model_name`` contains
    "llama-3", ``original/tokenizer.model`` is also copied next to it.

    Args:
        checkpoint_dir: Directory holding the HF weights (index + shards, or one file).
        model_name: Name passed to ``ModelArgs.from_name``; defaults to the directory name.

    Raises:
        KeyError: if the checkpoint contains a tensor name the weight map does not cover.
        SystemExit: if no weights or index file is found in ``checkpoint_dir``.
    """
    if model_name is None:
        model_name = checkpoint_dir.name

    config = ModelArgs.from_name(model_name)
    print(f"Model config {config.__dict__}")

    weight_files = _locate_weight_files(checkpoint_dir)
    weight_map = _build_weight_map(model_name)

    def permute(w, n_head, qk_norm=False):
        # Within each head, view the rows as (2, head_dim // 2) and swap those axes,
        # turning [first half, second half] into interleaved pairs. Presumably this
        # matches the rotary layout the target model expects -- confirm.
        if len(w.shape) == 2:
            dim = config.dim
            return (
                w.view(n_head, 2, config.head_dim // 2, dim)
                .transpose(1, 2)
                .reshape(config.head_dim * n_head, dim)
            )
        if qk_norm:
            # Norm weights are one head_dim vector (no per-head axis).
            return w.view(2, config.head_dim // 2).transpose(0, 1).reshape(config.head_dim)
        return w.view(n_head, 2, config.head_dim // 2).transpose(1, 2).reshape(config.head_dim * n_head)

    merged_result = _load_merged_state_dict(weight_files)
    for key, value in merged_result.items():
        print(f"Key: {key}, Shape: {value.shape}")

    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
            # Replace every digit run with "{}" to look up the layer-generic template.
            abstract_key = re.sub(r"(\d+)", "{}", key)
            layer_num = re.search(r"\d+", key).group(0)
        else:
            abstract_key = key
            layer_num = None
        if abstract_key not in weight_map:
            raise KeyError(f"Unrecognized checkpoint tensor {key!r} for model {model_name!r}")
        new_key = weight_map[abstract_key]
        if new_key is None:
            continue  # tensor intentionally dropped by the weight map
        if layer_num is not None:
            new_key = new_key.format(layer_num)
        final_result[new_key] = value

    # Iterate over a snapshot of the keys because the loop adds and deletes entries.
    for key in tuple(final_result.keys()):
        if "wq" in key:
            wk_key = key.replace("wq", "wk")
            wv_key = key.replace("wq", "wv")
            q = permute(final_result[key], config.n_head)
            k = permute(final_result[wk_key], config.n_local_heads)
            v = final_result[wv_key]
            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
            del final_result[key], final_result[wk_key], final_result[wv_key]
        if "q_norm" in key:
            k_norm_key = key.replace("q_norm", "k_norm")
            final_result[key] = permute(final_result[key], config.n_head, qk_norm=True)
            final_result[k_norm_key] = permute(final_result[k_norm_key], config.n_local_heads, qk_norm=True)

    output_path = checkpoint_dir / "model.pth"
    print(f"Saving checkpoint to {output_path}")
    torch.save(final_result, output_path)
    if "llama-3" in model_name.lower():
        # For Llama-3 the tokenizer is taken from original/ and copied beside model.pth.
        tokenizer_model = checkpoint_dir / "original" / "tokenizer.model"
        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)


def prepare_model(repo_id: str, out_dir: str, hf_token: Optional[str] = None,
                  model_name: Optional[str] = None, cleanup: bool = True) -> None:
    """Fetch ``repo_id`` into ``out_dir`` and convert it to ``model.pth``.

    Runs three steps in order:
    1. ``hf_download``, which is skipped when ``out_dir`` already exists.
    2. ``convert_hf_checkpoint``.
    3. ``cleanup_original_files``, only when ``cleanup`` is true.

    Args:
        repo_id: HuggingFace Hub repository id.
        out_dir: Directory that receives the download and the converted model.
        hf_token: Optional HuggingFace access token.
        model_name: Optional override of the model name used for conversion.
        cleanup: Whether to delete the original weight files afterwards.
    """
    target = Path(out_dir)

    hf_download(out_dir, repo_id, hf_token)
    convert_hf_checkpoint(checkpoint_dir=target, model_name=model_name)
    if cleanup:
        cleanup_original_files(target)

    final_path = target / 'model.pth'
    print(f"Model preparation completed! Final model saved at: {final_path}")


if __name__ == '__main__':
    import argparse

    # Command-line entry point: parse the options and hand off to prepare_model.
    cli = argparse.ArgumentParser(description='Download and convert HuggingFace model.')
    cli.add_argument('--repo_id', type=str, required=True, help='Repository ID to download from.')
    cli.add_argument('--out_dir', type=str, required=True, help='Output directory to save the model.')
    cli.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')
    cli.add_argument('--model_name', type=str, default=None, help='Model name override.')
    cli.add_argument('--no_cleanup', action='store_true', help='Skip cleanup of original files.')

    opts = cli.parse_args()
    prepare_model(
        opts.repo_id,
        opts.out_dir,
        opts.hf_token,
        opts.model_name,
        not opts.no_cleanup,
    )