#!/usr/bin/env python3
"""
Convert TorchTitan DCP checkpoint to HuggingFace format.

This script loads a distributed checkpoint from TorchTitan training
and converts it to a standard HuggingFace transformers checkpoint.
"""

import argparse
import json
import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.distributed.checkpoint as dcp
from transformers import LlamaConfig, LlamaForCausalLM

# Size of the Llama-3 tokenizer vocabulary (matches the fallback tokenizer config
# this script writes, whose special tokens live at ids 128000/128001).
_LLAMA3_VOCAB_SIZE = 128256

# Dimension keys shared with extract_metadata_info().
_CONFIG_DIM_KEYS = ('vocab_size', 'hidden_size', 'intermediate_size', 'n_layers', 'n_heads', 'n_kv_heads')


def _infer_dims_from_state_dict(tt_state_dict: Dict[str, Any], head_dim: int = 128) -> Dict[str, Any]:
    """
    Infer model dimensions from the tensor shapes of a TorchTitan state dict.

    Returns a dict with the same keys as ``extract_metadata_info``.

    Raises:
        ValueError: If ``tok_embeddings.weight`` is not present.
    """
    embed_weight = tt_state_dict.get("tok_embeddings.weight")
    if embed_weight is None:
        raise ValueError("Could not find tok_embeddings.weight in state dict")

    vocab_size, hidden_size = embed_weight.shape

    # Number of layers = highest layer index that has a q projection, plus one.
    n_layers = 0
    for key in tt_state_dict:
        if key.startswith("layers.") and ".attention.wq.weight" in key:
            n_layers = max(n_layers, int(key.split(".")[1]) + 1)

    n_heads = hidden_size // head_dim

    # Grouped-query attention: the k projection has n_kv_heads * head_dim rows.
    wk_weight = tt_state_dict.get("layers.0.attention.wk.weight")
    n_kv_heads = wk_weight.shape[0] // head_dim if wk_weight is not None else n_heads

    w1_weight = tt_state_dict.get("layers.0.feed_forward.w1.weight")
    if w1_weight is not None:
        intermediate_size = w1_weight.shape[0]
    else:
        # Standard Llama ratio, rounded up to a multiple of 256.
        intermediate_size = int(hidden_size * 8 / 3)
        intermediate_size = ((intermediate_size + 255) // 256) * 256

    return {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'intermediate_size': intermediate_size,
        'n_layers': n_layers,
        'n_heads': n_heads,
        'n_kv_heads': n_kv_heads,
    }


def _llama_config_from_dims(dims: Dict[str, Any]) -> LlamaConfig:
    """Build a LlamaConfig from model dimensions plus this script's fixed defaults."""
    # With the Llama-3 vocabulary, <|begin_of_text|> / <|end_of_text|> are ids
    # 128000 / 128001; ids 1 / 2 are not those special tokens there, and they end up
    # in the saved generation config. Other vocabularies keep the previous 1 / 2.
    if dims['vocab_size'] == _LLAMA3_VOCAB_SIZE:
        bos_token_id, eos_token_id = 128000, 128001
    else:
        bos_token_id, eos_token_id = 1, 2

    return LlamaConfig(
        vocab_size=dims['vocab_size'],
        hidden_size=dims['hidden_size'],
        intermediate_size=dims['intermediate_size'],
        num_hidden_layers=dims['n_layers'],
        num_attention_heads=dims['n_heads'],
        num_key_value_heads=dims['n_kv_heads'],
        hidden_act="silu",
        max_position_embeddings=131072,  # Can be adjusted based on your training
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        tie_word_embeddings=False,
        rope_theta=500000.0,  # Adjust based on your model
        rope_scaling=None,
        attention_bias=False,
    )


def create_hf_config_from_torchtitan(
    tt_state_dict: Dict[str, Any],
    extracted_config: Optional[Dict[str, Any]] = None,
) -> LlamaConfig:
    """
    Create a HuggingFace LlamaConfig for a TorchTitan model.

    Uses ``extracted_config`` (as returned by ``extract_metadata_info``) when it
    supplies every dimension. Otherwise the dimensions are inferred from the
    tensor shapes in ``tt_state_dict``, assuming a head dimension of 128.

    Args:
        tt_state_dict: TorchTitan parameter name -> tensor mapping.
        extracted_config: Optional pre-extracted dimensions with keys vocab_size,
            hidden_size, intermediate_size, n_layers, n_heads, n_kv_heads.

    Returns:
        The LlamaConfig to build the HuggingFace model with.

    Raises:
        ValueError: If inference is needed and ``tok_embeddings.weight`` is missing.
    """
    if extracted_config and all(extracted_config.get(k) is not None for k in _CONFIG_DIM_KEYS):
        dims = extracted_config
        title = "Using extracted config:"
    else:
        if extracted_config:
            # Previously the None entries were passed straight into LlamaConfig.
            print("Warning: extracted config is incomplete; inferring dimensions from state dict")
        dims = _infer_dims_from_state_dict(tt_state_dict)
        title = "Inferred config:"

    config = _llama_config_from_dims(dims)

    print(title)
    print(f"  vocab_size: {dims['vocab_size']}")
    print(f"  hidden_size: {dims['hidden_size']}")
    print(f"  num_hidden_layers: {dims['n_layers']}")
    print(f"  num_attention_heads: {dims['n_heads']}")
    print(f"  num_key_value_heads: {dims['n_kv_heads']}")
    print(f"  intermediate_size: {dims['intermediate_size']}")

    return config

# Top-level TorchTitan parameter names and their HuggingFace equivalents.
_TT_TO_HF_TOP_LEVEL = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
}

# Per-layer parameter suffixes (the part after "layers.{i}.") and their HF equivalents.
_TT_TO_HF_LAYER_SUFFIX = {
    "attention.wq.weight": "self_attn.q_proj.weight",
    "attention.wk.weight": "self_attn.k_proj.weight",
    "attention.wv.weight": "self_attn.v_proj.weight",
    "attention.wo.weight": "self_attn.o_proj.weight",
    "feed_forward.w1.weight": "mlp.gate_proj.weight",
    "feed_forward.w2.weight": "mlp.down_proj.weight",
    "feed_forward.w3.weight": "mlp.up_proj.weight",
    "attention_norm.weight": "input_layernorm.weight",
    "ffn_norm.weight": "post_attention_layernorm.weight",
}


def _permute_for_hf_rope(weight: torch.Tensor, n_heads: int) -> torch.Tensor:
    """
    Reorder q/k projection rows from interleaved-pair RoPE layout to HF's half-split layout.

    Within each head, rows (x0, x1, x2, x3, ...) whose RoPE pairs are (x0, x1),
    (x2, x3), ... become (x0, x2, ..., x1, x3, ...), so that HuggingFace's
    ``rotate_half`` pairs the same channels. This is the ``permute`` step of
    HuggingFace's Llama conversion script.

    Raises:
        ValueError: If the rows cannot be split into ``n_heads`` even-sized heads.
    """
    rows, cols = weight.shape
    if n_heads <= 0 or rows % (2 * n_heads) != 0:
        raise ValueError(
            f"Cannot split {rows} projection rows into {n_heads} heads of even size"
        )
    return (
        weight.reshape(n_heads, rows // n_heads // 2, 2, cols)
        .transpose(1, 2)
        .reshape(rows, cols)
    )


def convert_torchtitan_to_hf_state_dict(
    tt_state_dict: Dict[str, Any],
    n_heads: Optional[int] = None,
    n_kv_heads: Optional[int] = None,
    head_dim: int = 128,
) -> Dict[str, torch.Tensor]:
    """
    Convert TorchTitan state dict to HuggingFace format.

    TorchTitan naming -> HuggingFace naming:
    - tok_embeddings.weight -> model.embed_tokens.weight
    - layers.{i}.attention.wq.weight -> model.layers.{i}.self_attn.q_proj.weight (RoPE-permuted)
    - layers.{i}.attention.wk.weight -> model.layers.{i}.self_attn.k_proj.weight (RoPE-permuted)
    - layers.{i}.attention.wv.weight -> model.layers.{i}.self_attn.v_proj.weight
    - layers.{i}.attention.wo.weight -> model.layers.{i}.self_attn.o_proj.weight
    - layers.{i}.feed_forward.w1.weight -> model.layers.{i}.mlp.gate_proj.weight
    - layers.{i}.feed_forward.w2.weight -> model.layers.{i}.mlp.down_proj.weight
    - layers.{i}.feed_forward.w3.weight -> model.layers.{i}.mlp.up_proj.weight
    - layers.{i}.attention_norm.weight -> model.layers.{i}.input_layernorm.weight
    - layers.{i}.ffn_norm.weight -> model.layers.{i}.post_attention_layernorm.weight
    - norm.weight -> model.norm.weight
    - output.weight -> lm_head.weight

    Beyond renaming, the q/k projection rows are permuted. TorchTitan rotates
    interleaved channel pairs, while HuggingFace Llama rotates the two halves of
    each head. Without the permutation the model loads but produces wrong outputs.

    Args:
        tt_state_dict: TorchTitan parameter name -> tensor mapping. ``freqs_cis``
            is dropped; unknown names are reported and skipped.
        n_heads: Number of query heads. If None, it is inferred as
            ``wq.shape[0] // head_dim``.
        n_kv_heads: Number of key/value heads. If None, it is inferred as
            ``wk.shape[0] // head_dim``.
        head_dim: Per-head dimension, used only for the inference above. Pass
            explicit head counts for models whose head dim is not 128.

    Returns:
        HuggingFace parameter name -> tensor mapping, in input order.

    Raises:
        ValueError: If a q/k projection cannot be split into the given or inferred heads.
    """
    hf_state_dict = {}

    for key, value in tt_state_dict.items():
        if key == "freqs_cis":
            # Skip this buffer as HF computes it dynamically
            continue

        if key in _TT_TO_HF_TOP_LEVEL:
            hf_state_dict[_TT_TO_HF_TOP_LEVEL[key]] = value
            continue

        if not key.startswith("layers."):
            print(f"Warning: Unknown parameter {key}")
            continue

        layer_idx, _, suffix = key[len("layers."):].partition(".")
        hf_suffix = _TT_TO_HF_LAYER_SUFFIX.get(suffix)
        if hf_suffix is None:
            print(f"Warning: Unknown layer parameter {key}")
            continue

        if suffix == "attention.wq.weight":
            heads = n_heads if n_heads is not None else value.shape[0] // head_dim
            value = _permute_for_hf_rope(value, heads)
        elif suffix == "attention.wk.weight":
            heads = n_kv_heads if n_kv_heads is not None else value.shape[0] // head_dim
            value = _permute_for_hf_rope(value, heads)

        hf_state_dict[f"model.layers.{layer_idx}.{hf_suffix}"] = value

    return hf_state_dict

def extract_metadata_info(checkpoint_path: str, head_dim: int = 128):
    """
    Extract model dimensions from a DCP checkpoint's ``.metadata`` file.

    Only the tensor shapes recorded in the metadata are read; no weights are loaded.

    Args:
        checkpoint_path: Directory containing the DCP checkpoint.
        head_dim: Per-head dimension used to turn projection widths into head
            counts (default 128, the value previously hard-coded here).

    Returns:
        Dict with keys vocab_size, hidden_size, n_layers, intermediate_size,
        n_heads and n_kv_heads. Entries whose source parameter was not found
        stay None (n_layers stays 0).

    Raises:
        ValueError: If the metadata file is missing or has no ``state_dict_metadata``.
    """
    import pickle

    metadata_path = os.path.join(checkpoint_path, ".metadata")
    if not os.path.exists(metadata_path):
        raise ValueError("No metadata file found!")

    # NOTE(security): pickle.load can execute arbitrary code embedded in the file.
    # Only run this on checkpoints from a trusted source.
    with open(metadata_path, 'rb') as f:
        metadata = pickle.load(f)

    if not hasattr(metadata, 'state_dict_metadata'):
        raise ValueError("No state_dict_metadata found")

    config = {
        'vocab_size': None,
        'hidden_size': None,
        'n_layers': 0,
        'intermediate_size': None,
        'n_heads': None,
        'n_kv_heads': None,
    }
    # Record the k-projection width and convert it to a head count after the loop.
    # Converting inside the loop depended on tok_embeddings.weight having been seen
    # first; otherwise GQA was silently lost (n_kv_heads fell back to n_heads).
    kv_hidden_size = None

    for param_name, param_metadata in metadata.state_dict_metadata.items():
        if not hasattr(param_metadata, 'size'):
            # Entries without a shape (presumably non-tensor items) carry no dims.
            continue
        shape = tuple(param_metadata.size)

        if param_name == "tok_embeddings.weight":
            config['vocab_size'], config['hidden_size'] = shape[0], shape[1]
        elif param_name == "layers.0.attention.wk.weight":
            # [n_kv_heads * head_dim, hidden_size]
            kv_hidden_size = shape[0]
        elif param_name == "layers.0.feed_forward.w1.weight":
            config['intermediate_size'] = shape[0]
        elif param_name.startswith("layers.") and ".attention.wq.weight" in param_name:
            layer_num = int(param_name.split(".")[1])
            config['n_layers'] = max(config['n_layers'], layer_num + 1)

    if config['hidden_size']:
        config['n_heads'] = config['hidden_size'] // head_dim
        if kv_hidden_size is not None:
            config['n_kv_heads'] = kv_hidden_size // head_dim
        else:
            config['n_kv_heads'] = config['n_heads']  # No GQA

    return config

def create_state_dict_template(config):
    """
    Allocate an uninitialized state dict whose tensors have the checkpoint's shapes.

    ``config`` is the dimension dict produced by ``extract_metadata_info``
    (vocab_size, hidden_size, intermediate_size, n_layers, n_heads, n_kv_heads).
    The tensors come from ``torch.empty``, so their contents are garbage until a
    loader writes into them.

    ``freqs_cis`` is intentionally left out because it is not persisted in the
    checkpoint.
    """
    dim = config['hidden_size']
    vocab = config['vocab_size']

    template = {"tok_embeddings.weight": torch.empty(vocab, dim)}

    for layer in range(config['n_layers']):
        prefix = f"layers.{layer}."

        template[prefix + "attention.wq.weight"] = torch.empty(dim, dim)
        template[prefix + "attention.wo.weight"] = torch.empty(dim, dim)

        # Grouped-query attention: k/v use n_kv_heads heads of width dim // n_heads.
        kv_dim = config['n_kv_heads'] * (dim // config['n_heads'])
        template[prefix + "attention.wk.weight"] = torch.empty(kv_dim, dim)
        template[prefix + "attention.wv.weight"] = torch.empty(kv_dim, dim)

        ffn_dim = config['intermediate_size']
        template[prefix + "feed_forward.w1.weight"] = torch.empty(ffn_dim, dim)
        template[prefix + "feed_forward.w2.weight"] = torch.empty(dim, ffn_dim)
        template[prefix + "feed_forward.w3.weight"] = torch.empty(ffn_dim, dim)

        template[prefix + "attention_norm.weight"] = torch.empty(dim)
        template[prefix + "ffn_norm.weight"] = torch.empty(dim)

    template["norm.weight"] = torch.empty(dim)
    template["output.weight"] = torch.empty(vocab, dim)

    return template

def load_dcp_checkpoint(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    """
    Load the model weights of a TorchTitan DCP checkpoint into a flat state dict.

    The model dimensions are read from the checkpoint's ``.metadata`` file. A
    template of correctly shaped empty tensors is then allocated and passed to
    ``dcp.load``, which reads the checkpoint into it.

    Args:
        checkpoint_path: Directory containing the DCP checkpoint.

    Returns:
        TorchTitan parameter name -> tensor mapping (the filled template).

    Raises:
        ValueError: If the metadata is missing, or does not yield every dimension
            needed to build the template.
    """
    print(f"Loading DCP checkpoint from: {checkpoint_path}")

    # Step 1: Extract model configuration from metadata
    print("Extracting model configuration from metadata...")
    config = extract_metadata_info(checkpoint_path)

    print("Model configuration:")
    for name in ('vocab_size', 'hidden_size', 'n_layers', 'n_heads', 'n_kv_heads', 'intermediate_size'):
        print(f"  {name}: {config[name]}")

    # Fail early with a readable message instead of a TypeError from torch.empty(None, ...).
    missing = [name for name, value in config.items() if value is None]
    if missing:
        raise ValueError(
            f"Could not determine {', '.join(missing)} from checkpoint metadata; "
            "expected TorchTitan parameter names such as 'tok_embeddings.weight', "
            "'layers.0.attention.wk.weight' and 'layers.0.feed_forward.w1.weight'"
        )

    # Step 2: Create state dict template with correct shapes
    print("Creating state dict template...")
    state_dict = create_state_dict_template(config)

    # Step 3: Load the checkpoint into the template
    print("Loading checkpoint...")
    begin = time.perf_counter()
    dcp.load(state_dict, checkpoint_id=checkpoint_path)
    print(f"Checkpoint loaded in {time.perf_counter() - begin:.2f} seconds")

    # The template tensors are allocated up front, so filtering on numel() > 0
    # (the old "verification") could not tell whether anything was loaded.
    # Report what was requested instead.
    print(f"Loaded {len(state_dict)} parameters")
    print("Sample parameters:", list(state_dict)[:5])

    return state_dict

def add_llama3_tokenizer_files(output_path: Path):
    """
    Save a Llama-3 tokenizer into the converted model directory.

    TorchTitan trains with tiktoken, but HuggingFace expects a
    PreTrainedTokenizerFast. So the tokenizer of a published Llama-3-8B repo is
    downloaded and saved next to the weights. The official meta-llama repo is
    tried first, then the NousResearch mirror. If both fail, or anything else
    goes wrong, only a minimal tokenizer config is written. This step is best
    effort and never raises for ordinary errors.
    """
    candidate_repos = ("meta-llama/Meta-Llama-3-8B", "NousResearch/Meta-Llama-3-8B")

    try:
        from transformers import AutoTokenizer

        print("Downloading Llama-3 tokenizer from HuggingFace...")

        tokenizer = None
        for attempt, repo_id in enumerate(candidate_repos):
            try:
                tokenizer = AutoTokenizer.from_pretrained(repo_id)
                break
            except Exception as load_err:
                print(f"Could not load from {repo_id}: {load_err}")
                if attempt + 1 < len(candidate_repos):
                    print("Trying alternative...")

        if tokenizer is None:
            # Every candidate failed; fall back to a config-only tokenizer setup.
            print("Creating minimal tokenizer config...")
            create_minimal_tokenizer_config(output_path)
            return

        tokenizer.save_pretrained(output_path)
        print("✓ Llama-3 tokenizer files added successfully")

    except ImportError as import_err:
        print(f"Warning: Could not import transformers: {import_err}")
        create_minimal_tokenizer_config(output_path)
    except Exception as other_err:
        print(f"Warning: Could not add tokenizer files: {other_err}")
        create_minimal_tokenizer_config(output_path)


def create_minimal_tokenizer_config(output_path: Path):
    """
    Write a minimal Llama-3 ``tokenizer_config.json`` into ``output_path``.

    This is the fallback used when the real tokenizer cannot be downloaded. Only
    the config file is written, so the actual tokenizer files (e.g.
    ``tokenizer.json``) must still be supplied before the tokenizer can be loaded.

    Args:
        output_path: Existing directory to write ``tokenizer_config.json`` into.
    """
    print("Creating minimal tokenizer configuration...")

    # Llama-3 ships a fast (tokenizer.json-based) tokenizer, not the SentencePiece
    # LlamaTokenizer. SentencePiece-only keys ("legacy", "sp_model_kwargs") are
    # therefore omitted.
    tokenizer_config = {
        "add_bos_token": True,
        "add_eos_token": False,
        "added_tokens_decoder": {
            "128000": {
                "content": "<|begin_of_text|>",
                "lstrip": False,
                "normalized": False,
                "rstrip": False,
                "single_word": False,
                "special": True
            },
            "128001": {
                "content": "<|end_of_text|>",
                "lstrip": False,
                "normalized": False,
                "rstrip": False,
                "single_word": False,
                "special": True
            }
        },
        "additional_special_tokens": [],
        "bos_token": "<|begin_of_text|>",
        "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}\n{% elif message['role'] == 'system' %}\n{{ '<|start_header_id|>system<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}\n{% endif %}\n{% if loop.last and message['role'] != 'assistant' %}\n{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}\n{% endif %}\n{% endfor %}",
        "clean_up_tokenization_spaces": False,
        "eos_token": "<|end_of_text|>",
        "model_max_length": 131072,
        "pad_token": None,
        "spaces_between_special_tokens": False,
        "tokenizer_class": "PreTrainedTokenizerFast",
        "unk_token": None,
        "use_default_system_prompt": False,
        "vocab_size": 128256
    }

    # Save tokenizer config
    with open(output_path / "tokenizer_config.json", "w", encoding="utf-8") as f:
        json.dump(tokenizer_config, f, indent=2)

    print("⚠️  Minimal tokenizer config created.")
    print("   Note: You'll need to provide the actual tokenizer files")
    print("   (tokenizer.json, vocab files) to use the tokenizer properly.")
    print("   Consider downloading them from: meta-llama/Meta-Llama-3-8B")


def convert_dcp_to_hf(
    dcp_checkpoint_path: str,
    output_path: str,
    config_file: Optional[str] = None,
    push_to_hub: bool = False,
    repo_name: Optional[str] = None
):
    """
    Convert a TorchTitan DCP checkpoint directory into a HuggingFace model directory.

    Args:
        dcp_checkpoint_path: Directory containing the DCP checkpoint.
        output_path: Directory to save the model, config and tokenizer files to;
            created if needed.
        config_file: Optional path to a HuggingFace ``config.json``. When empty or
            None, the config is derived from the checkpoint metadata.
        push_to_hub: If True and ``repo_name`` is set, push the model and config
            to the Hub.
        repo_name: Hub repository name used together with ``push_to_hub``.

    Raises:
        ValueError: If the checkpoint path, or an explicitly given config file,
            does not exist.
    """
    checkpoint_path = Path(dcp_checkpoint_path)
    output_path = Path(output_path)

    if not checkpoint_path.exists():
        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")
    # A requested config that cannot be found used to be silently ignored, which
    # produced a model with inferred settings. Fail before the expensive load.
    if config_file and not Path(config_file).exists():
        raise ValueError(f"Config file {config_file} does not exist")

    # Load the DCP checkpoint and extract config
    print("Step 1: Loading TorchTitan DCP checkpoint...")

    # First extract the config from metadata for better loading
    extracted_config = extract_metadata_info(str(checkpoint_path))
    tt_state_dict = load_dcp_checkpoint(str(checkpoint_path))

    print(f"Loaded {len(tt_state_dict)} parameters")
    print("Sample parameters:", list(tt_state_dict.keys())[:5])

    # Create HuggingFace config
    print("\nStep 2: Creating HuggingFace config...")
    if config_file:
        print(f"Loading config from {config_file}")
        with open(config_file, 'r', encoding="utf-8") as f:
            config_dict = json.load(f)
        config = LlamaConfig.from_dict(config_dict)
    else:
        print("Using extracted config from checkpoint metadata...")
        config = create_hf_config_from_torchtitan(tt_state_dict, extracted_config)

    # Convert state dict
    print("\nStep 3: Converting state dict to HuggingFace format...")
    hf_state_dict = convert_torchtitan_to_hf_state_dict(tt_state_dict)

    print(f"Converted to {len(hf_state_dict)} HF parameters")
    print("Sample HF parameters:", list(hf_state_dict.keys())[:5])

    # Create HuggingFace model and load weights
    print("\nStep 4: Creating HuggingFace model...")
    model = LlamaForCausalLM(config)

    # strict=False: parameters missing from hf_state_dict keep the values the model
    # was constructed with, so the missing-key warning below deserves attention.
    missing_keys, unexpected_keys = model.load_state_dict(hf_state_dict, strict=False)

    if missing_keys:
        print(f"Warning: Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"Warning: Unexpected keys: {unexpected_keys}")

    # Save the model
    print(f"\nStep 5: Saving HuggingFace model to {output_path}...")
    output_path.mkdir(parents=True, exist_ok=True)

    # Save model and config
    model.save_pretrained(output_path)
    config.save_pretrained(output_path)

    # Add tokenizer files for Llama-3
    print("Adding Llama-3 tokenizer files...")
    add_llama3_tokenizer_files(output_path)

    print(f"Model saved successfully to {output_path}")

    # Optionally push to hub.
    # NOTE(review): only the model and config are pushed; tokenizer files saved
    # above stay local.
    if push_to_hub and repo_name:
        print(f"\nStep 6: Pushing to HuggingFace Hub as {repo_name}...")
        model.push_to_hub(repo_name)
        config.push_to_hub(repo_name)
        print("Successfully pushed to Hub!")

def main():
    """Parse the command line and run the TorchTitan DCP -> HuggingFace conversion."""
    parser = argparse.ArgumentParser(
        description="Convert TorchTitan DCP checkpoint to HuggingFace format"
    )

    # Positional arguments: where to read the checkpoint and where to write the model.
    parser.add_argument(
        "checkpoint_path",
        type=str,
        help="Path to the TorchTitan DCP checkpoint directory (e.g., /path/to/step-10000)",
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="Output path for the HuggingFace model",
    )

    # Optional behavior switches.
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Optional: Path to HuggingFace config.json file. If not provided, config will be inferred.",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the converted model to HuggingFace Hub",
    )
    parser.add_argument(
        "--repo-name",
        type=str,
        help="Repository name for HuggingFace Hub (required if --push-to-hub is used)",
    )

    cli = parser.parse_args()

    # A Hub push needs a destination repository.
    missing_repo = cli.push_to_hub and not cli.repo_name
    if missing_repo:
        parser.error("--repo-name is required when --push-to-hub is used")

    convert_dcp_to_hf(
        dcp_checkpoint_path=cli.checkpoint_path,
        output_path=cli.output_path,
        config_file=cli.config,
        push_to_hub=cli.push_to_hub,
        repo_name=cli.repo_name,
    )


if __name__ == "__main__":
    main()