import torch
import os
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, GPTNeoXForCausalLM
from transformers import GPT2TokenizerFast # 用于处理 GPT2 兼容的 BPE tokenizer

# --- Settings you need to edit for your run ---
CHECKPOINT_PATH = "/root/checkpoints/my_training/nomask410m/global_step400/"
MODEL_STATE_FILENAME = "mp_rank_00_model_states.pt" # Model-states file name, joined onto CHECKPOINT_PATH when loading
CONFIG_FILE = "/root/checkpoints/my_training/nomask410m/global_step400/configs/pythia-410m.yml" # GPT-NeoX .yml config whose model hyperparameters the converter reads
OUTPUT_DIR = "/root/trainbin1/nomask410m/nomask410m400" # Directory the converted HF model and tokenizer are saved to

# --- Path to the tokenizer you already have ---
# Adjust this to match your tokenizer's actual type and location:
# if it is a Hugging Face tokenizer directory, use e.g. "/path/to/my_hf_tokenizer/";
# if it is a tokenizer.json file, give its full path.
# NOTE(review): the converter passes this to AutoTokenizer.from_pretrained, which
# generally expects a directory or hub id; a bare tokenizer.json path may not
# load (the converter then falls back to a hub tokenizer) — confirm for your setup.
EXISTING_TOKENIZER_PATH = "/root/trainbin1/hftokenizer" # <--- **Change this to the path of your existing tokenizer**

# --- Do not modify the code below unless you know what you are doing ---

# Per-layer parameter suffixes; the NeoX `sequential.<idx>.<suffix>` and the HF
# `gpt_neox.layers.<i>.<suffix>` state-dict keys share the same suffix.
_LAYER_PARAM_NAMES = (
    "input_layernorm.weight",
    "input_layernorm.bias",
    "attention.query_key_value.weight",
    "attention.query_key_value.bias",
    "attention.dense.weight",
    "attention.dense.bias",
    "post_attention_layernorm.weight",
    "post_attention_layernorm.bias",
    "mlp.dense_h_to_4h.weight",
    "mlp.dense_h_to_4h.bias",
    "mlp.dense_4h_to_h.weight",
    "mlp.dense_4h_to_h.bias",
)


def _neox_get(neox_config, key, default=None):
    """Return ``neox_config[key]``, accepting either ``-`` or ``_`` as separator.

    NeoX config files may spell the same option as ``rotary-pct`` or
    ``rotary_pct``; *key* is given hyphenated and the underscored spelling is
    tried as well. Returns *default* when neither spelling is present.
    """
    for candidate in (key, key.replace("-", "_")):
        if candidate in neox_config:
            return neox_config[candidate]
    return default


def _load_module_weights(checkpoint_file):
    """Load *checkpoint_file* and return its ``'module'`` state dict, or None on failure."""
    print(f"Loading custom checkpoint from: {checkpoint_file}")
    try:
        try:
            # The model-states file may pickle non-tensor objects next to the
            # weights, which torch>=2.6's default ``weights_only=True`` rejects.
            # weights_only=False executes arbitrary pickle code: only use this
            # on checkpoints you trust.
            loaded_state_dict = torch.load(
                checkpoint_file, map_location="cpu", weights_only=False
            )
        except TypeError:
            # Older torch releases do not accept the ``weights_only`` argument.
            loaded_state_dict = torch.load(checkpoint_file, map_location="cpu")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None

    if 'module' not in loaded_state_dict:
        print("Error: 'module' key not found in the loaded checkpoint. Please check your checkpoint structure.")
        return None
    print("Successfully extracted 'module' from the checkpoint.")
    return loaded_state_dict['module']


def _build_hf_config(neox_config):
    """Build an HF GPT-NeoX config from the pythia-160m base, overridden by *neox_config*."""
    hf_config = AutoConfig.from_pretrained("EleutherAI/pythia-160m")

    hf_config.vocab_size = _neox_get(neox_config, 'padded-vocab-size', 50304)
    hf_config.hidden_size = _neox_get(neox_config, 'hidden-size', 1024)
    hf_config.num_hidden_layers = _neox_get(neox_config, 'num-layers', 24)
    hf_config.num_attention_heads = _neox_get(neox_config, 'num-attention-heads', 16)
    # Prefer NeoX's explicit max-position-embeddings; fall back to the training sequence length.
    hf_config.max_position_embeddings = _neox_get(
        neox_config, 'max-position-embeddings', _neox_get(neox_config, 'seq-length', 2048)
    )
    hf_config.intermediate_size = _neox_get(
        neox_config, 'ffn-hidden-size', hf_config.hidden_size * 4
    )  # defaults to 4x hidden size
    hf_config.rotary_pct = _neox_get(neox_config, 'rotary-pct', 1.0)
    hf_config.rotary_emb_base = _neox_get(neox_config, 'rotary-emb-base', 10000)
    # The HF GPT-NeoX layers read ``layer_norm_eps`` (not ``layernorm_epsilon``).
    hf_config.layer_norm_eps = _neox_get(neox_config, 'layernorm-epsilon', 1e-5)
    hf_config.gradient_checkpointing = _neox_get(neox_config, 'gradient-checkpointing', False)
    hf_config.position_embedding_type = "rotary"

    # Some transformers releases derive the rotary embedding from
    # ``partial_rotary_factor`` / ``rope_theta`` instead of the legacy names;
    # keep them in sync when the config exposes them.
    if hasattr(hf_config, "partial_rotary_factor"):
        hf_config.partial_rotary_factor = hf_config.rotary_pct
    if hasattr(hf_config, "rope_theta"):
        hf_config.rope_theta = hf_config.rotary_emb_base

    # Parallel (GPT-J style) residual: only override the base config when the run sets it.
    gpt_j_residual = _neox_get(neox_config, 'gpt-j-residual')
    if gpt_j_residual is not None:
        hf_config.use_parallel_residual = bool(gpt_j_residual)
    return hf_config


def _map_weights(hf_state_dict, neox_weights, num_layers):
    """Copy NeoX ``sequential.*`` tensors into *hf_state_dict* (mutated in place).

    Index layout used here: ``sequential.0`` holds the word embeddings,
    ``sequential.2 .. num_layers+1`` the transformer layers,
    ``sequential.num_layers+3`` the final norm and ``sequential.num_layers+4``
    the output projection (27 / 28 for a 24-layer model).
    """
    hf_state_dict["gpt_neox.embed_in.weight"] = neox_weights["sequential.0.word_embeddings.weight"]

    for i in range(num_layers):
        src_prefix = f"sequential.{i + 2}"
        dst_prefix = f"gpt_neox.layers.{i}"
        for name in _LAYER_PARAM_NAMES:
            hf_state_dict[f"{dst_prefix}.{name}"] = neox_weights[f"{src_prefix}.{name}"]

    final_norm = f"sequential.{num_layers + 3}.norm"
    hf_state_dict["gpt_neox.final_layer_norm.weight"] = neox_weights[f"{final_norm}.weight"]
    hf_state_dict["gpt_neox.final_layer_norm.bias"] = neox_weights[f"{final_norm}.bias"]

    hf_state_dict["embed_out.weight"] = neox_weights[f"sequential.{num_layers + 4}.final_linear.weight"]


def _save_tokenizer(output_dir):
    """Copy the tokenizer at EXISTING_TOKENIZER_PATH to *output_dir*, falling back to a hub tokenizer.

    Failures are reported but not raised: the converted model is already saved.
    """
    fallback_tokenizer = "EleutherAI/pythia-410m"
    print(f"Attempting to load tokenizer from {EXISTING_TOKENIZER_PATH} and save to {output_dir}...")
    try:
        # AutoTokenizer picks the tokenizer class from the files at the given path.
        tokenizer = AutoTokenizer.from_pretrained(EXISTING_TOKENIZER_PATH)
        tokenizer.save_pretrained(output_dir)
        print("Tokenizer loaded from existing path and saved successfully.")
    except Exception as e:
        print(f"Error loading/saving tokenizer from EXISTING_TOKENIZER_PATH: {e}")
        print(f"Falling back to {fallback_tokenizer} tokenizer (might not be exact match if your vocab changed).")
        try:
            tokenizer = AutoTokenizer.from_pretrained(fallback_tokenizer)
            tokenizer.save_pretrained(output_dir)
            print("Fallback tokenizer saved successfully.")
        except Exception as e_fallback:
            print(f"Error saving fallback tokenizer: {e_fallback}")
            print("Please manually ensure a tokenizer is saved in the output directory.")


def convert_megatron_to_hf_custom():
    """Convert a GPT-NeoX ``sequential.*`` checkpoint into a Hugging Face GPT-NeoX model.

    Reads weights from ``CHECKPOINT_PATH/MODEL_STATE_FILENAME`` and model
    hyperparameters from ``CONFIG_FILE``, then writes the converted model and
    a tokenizer to ``OUTPUT_DIR``.

    Returns:
        True on success; False if the checkpoint could not be loaded or has no
        ``'module'`` entry. Other failures (missing config file, missing weight
        keys, shape mismatches in ``load_state_dict``) raise.
    """
    # 1. Load the checkpoint weights.
    full_checkpoint_file = os.path.join(CHECKPOINT_PATH, MODEL_STATE_FILENAME)
    model_weights_state_dict = _load_module_weights(full_checkpoint_file)
    if model_weights_state_dict is None:
        return False

    # 2. Read the NeoX .yml config and build the HF config from it.
    print(f"Loading configuration from: {CONFIG_FILE}")
    with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
        neox_config = yaml.safe_load(f) or {}  # an empty file parses to None

    hf_config = _build_hf_config(neox_config)
    print("Hugging Face Config created:")
    print(hf_config)

    # 3. Instantiate the HF model.
    print(f"Creating Hugging Face model instance of type {GPTNeoXForCausalLM.__name__}...")
    hf_model = GPTNeoXForCausalLM(hf_config)

    # 4. Map the weights; strict loading catches any key the mapping missed.
    print("Mapping and loading weights...")
    hf_state_dict = hf_model.state_dict()
    _map_weights(hf_state_dict, model_weights_state_dict, hf_config.num_hidden_layers)

    print("Loading mapped state_dict into Hugging Face model...")
    hf_model.load_state_dict(hf_state_dict, strict=True)
    print("Weights loaded successfully.")

    # 5. Save the model and the tokenizer.
    print(f"Saving Hugging Face model to {OUTPUT_DIR}")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    hf_model.save_pretrained(OUTPUT_DIR)
    _save_tokenizer(OUTPUT_DIR)

    print("Conversion complete!")
    return True

if __name__ == "__main__":
    # Exit non-zero when the checkpoint could not be loaded so scripts/pipelines notice.
    raise SystemExit(0 if convert_megatron_to_hf_custom() else 1)