#!/usr/bin/env python3
"""
Extract model information from DCP metadata to understand the model structure
"""

import torch
import torch.distributed.checkpoint as dcp
import pickle
import sys
import os


def extract_metadata_info(checkpoint_path: str):
    """Read the DCP ``.metadata`` file and derive a model configuration.

    Args:
        checkpoint_path: Directory expected to contain a ``.metadata`` file.

    Returns:
        The dict produced by ``analyze_model_structure`` when the unpickled
        object exposes ``state_dict_metadata``; ``None`` when the file is
        missing, cannot be unpickled, or lacks that attribute.
    """
    metadata_file = os.path.join(checkpoint_path, ".metadata")

    # Guard: nothing to do without a metadata file.
    if not os.path.exists(metadata_file):
        print("No metadata file found!")
        return None

    # NOTE: unpickling can execute arbitrary code; only point this at
    # checkpoints that come from a trusted source.
    try:
        with open(metadata_file, "rb") as handle:
            metadata = pickle.load(handle)
        print("Successfully loaded metadata")
    except Exception as err:
        print(f"Failed to load metadata: {err}")
        return None

    print(f"Metadata type: {type(metadata)}")

    # Guard: the object must carry the per-parameter metadata mapping.
    if not hasattr(metadata, "state_dict_metadata"):
        print("No state_dict_metadata found")
        return None

    entries = metadata.state_dict_metadata
    print(f"Found {len(entries)} parameters in metadata")

    # Infer the model configuration from the recorded parameter shapes.
    return analyze_model_structure(entries)

def analyze_model_structure(state_dict_metadata, head_dim: int = 128):
    """Infer a model configuration from per-parameter shape metadata.

    Head counts are not stored in the metadata, so they are derived from the
    projection widths using an assumed per-head dimension.

    Args:
        state_dict_metadata: Mapping of parameter name to a metadata object.
            Entries without a ``size`` attribute are ignored.
        head_dim: Assumed per-head dimension used to derive ``n_heads`` and
            ``n_kv_heads``. Defaults to 128.

    Returns:
        A dict with keys ``vocab_size``, ``hidden_size``, ``n_layers``,
        ``intermediate_size``, ``n_heads``, ``n_kv_heads`` and
        ``parameter_shapes`` (name -> shape tuple). Values that could not be
        inferred stay ``None`` (``n_layers`` stays 0).

    Raises:
        ValueError: If ``head_dim`` is not a positive integer.
    """
    if head_dim <= 0:
        raise ValueError(f"head_dim must be positive, got {head_dim}")

    config = {
        'vocab_size': None,
        'hidden_size': None,
        'n_layers': 0,
        'intermediate_size': None,
        'n_heads': None,
        'n_kv_heads': None,
        'parameter_shapes': {}
    }

    # Collected during the scan and resolved afterwards, so the result does
    # not depend on the iteration order of the metadata mapping.
    embed_hidden = None
    wq_hidden = None
    kv_hidden_size = None

    print("\nAnalyzing parameter shapes:")

    for param_name, param_metadata in state_dict_metadata.items():
        if not hasattr(param_metadata, 'size'):
            continue

        shape = tuple(param_metadata.size)
        config['parameter_shapes'][param_name] = shape
        print(f"  {param_name}: {shape}")

        # Extract configuration from specific parameters
        if param_name == "tok_embeddings.weight":
            config['vocab_size'] = shape[0]
            embed_hidden = shape[1]
        elif param_name == "layers.0.attention.wq.weight":
            # Fallback source for hidden_size when embeddings are absent.
            wq_hidden = shape[1]
        elif param_name == "layers.0.attention.wk.weight":
            # Key projection rows = n_kv_heads * head_dim.
            kv_hidden_size = shape[0]
        elif param_name == "layers.0.feed_forward.w1.weight":
            config['intermediate_size'] = shape[0]

        # Count layers independently of the chain above so that layer 0 is
        # included (a single-layer model must report n_layers == 1).
        if param_name.startswith("layers.") and ".attention.wq.weight" in param_name:
            layer_index = param_name.split(".")[1]
            if layer_index.isdecimal():
                config['n_layers'] = max(config['n_layers'], int(layer_index) + 1)

    # The embedding width takes precedence; the query projection is a fallback.
    config['hidden_size'] = embed_hidden if embed_hidden is not None else wq_hidden

    # Derive head counts with one consistent head_dim.
    if config['hidden_size']:
        config['n_heads'] = config['hidden_size'] // head_dim

        if kv_hidden_size is not None:
            config['n_kv_heads'] = kv_hidden_size // head_dim
        else:
            config['n_kv_heads'] = config['n_heads']  # No GQA

    print("\nExtracted configuration:")
    for key, value in config.items():
        if key != 'parameter_shapes':
            print(f"  {key}: {value}")

    return config

def create_state_dict_template(config, max_seq_len: int = 2048):
    """Create a state dict of uninitialized tensors shaped per ``config``.

    Only tensors whose dimensions are known are created, so an incomplete
    configuration yields a partial (possibly empty) template instead of
    raising.

    Args:
        config: Mapping with the keys ``vocab_size``, ``hidden_size``,
            ``n_layers``, ``intermediate_size``, ``n_heads`` and
            ``n_kv_heads``. Missing or ``None`` values are treated as unknown.
        max_seq_len: Number of rows allocated for ``freqs_cis``. Defaults to
            2048.

    Returns:
        Dict mapping parameter names to ``torch.empty`` tensors.
    """
    vocab_size = config.get('vocab_size')
    hidden_size = config.get('hidden_size')
    intermediate_size = config.get('intermediate_size')
    n_layers = config.get('n_layers') or 0
    n_heads = config.get('n_heads')
    n_kv_heads = config.get('n_kv_heads')
    if n_kv_heads is None:
        n_kv_heads = n_heads  # No GQA

    # Per-head width; unknown when either factor is missing or zero, which
    # would otherwise trigger a TypeError / ZeroDivisionError below.
    head_dim = hidden_size // n_heads if hidden_size and n_heads else None

    state_dict = {}

    # Token embeddings
    if vocab_size and hidden_size:
        state_dict["tok_embeddings.weight"] = torch.empty(vocab_size, hidden_size)

    # Model layers
    for i in range(n_layers):
        # Attention weights
        if hidden_size:
            state_dict[f"layers.{i}.attention.wq.weight"] = torch.empty(hidden_size, hidden_size)
            state_dict[f"layers.{i}.attention.wo.weight"] = torch.empty(hidden_size, hidden_size)

            # For GQA, k/v projections have different sizes
            if head_dim:
                kv_hidden = n_kv_heads * head_dim
                state_dict[f"layers.{i}.attention.wk.weight"] = torch.empty(kv_hidden, hidden_size)
                state_dict[f"layers.{i}.attention.wv.weight"] = torch.empty(kv_hidden, hidden_size)

        # FFN weights
        if intermediate_size and hidden_size:
            state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.empty(intermediate_size, hidden_size)
            state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.empty(hidden_size, intermediate_size)
            state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.empty(intermediate_size, hidden_size)

        # Layer norms
        if hidden_size:
            state_dict[f"layers.{i}.attention_norm.weight"] = torch.empty(hidden_size)
            state_dict[f"layers.{i}.ffn_norm.weight"] = torch.empty(hidden_size)

    # Output layers
    if hidden_size:
        state_dict["norm.weight"] = torch.empty(hidden_size)
    if vocab_size and hidden_size:
        state_dict["output.weight"] = torch.empty(vocab_size, hidden_size)

    # freqs_cis (optional): only created when the head dimension is known.
    if head_dim:
        state_dict["freqs_cis"] = torch.empty(max_seq_len, head_dim // 2, dtype=torch.complex64)

    return state_dict


def load_checkpoint_with_correct_shapes(checkpoint_path: str):
    """Load a DCP checkpoint into a template built from its own metadata.

    The steps are: derive the configuration via ``extract_metadata_info``,
    build the template via ``create_state_dict_template``, then call
    ``dcp.load`` on that template.

    Args:
        checkpoint_path: Checkpoint directory, also passed to ``dcp.load`` as
            ``checkpoint_id``.

    Returns:
        The state dict passed to ``dcp.load`` on success; ``None`` when the
        configuration could not be extracted or loading raised.
    """
    print("Step 1: Extracting model configuration from metadata...")
    config = extract_metadata_info(checkpoint_path)
    if not config:
        print("Failed to extract configuration")
        return None

    print("\nStep 2: Creating state dict template with correct shapes...")
    state_dict = create_state_dict_template(config)
    print(f"Created template with {len(state_dict)} parameters")

    print("\nStep 3: Loading checkpoint...")
    try:
        dcp.load(state_dict, checkpoint_id=checkpoint_path)
        print("Successfully loaded checkpoint!")

        # Keep only the entries that hold at least one element.
        non_empty = {}
        for name, tensor in state_dict.items():
            if tensor.numel() > 0:
                non_empty[name] = tensor
        print(f"Loaded {len(non_empty)} parameters")

        # Print the first five entries as a sample.
        for name in list(non_empty)[:5]:
            print(f"  {name}: {non_empty[name].shape}")

        return state_dict
    except Exception as err:
        print(f"Failed to load: {err}")
        return None


if __name__ == "__main__":
    # Checkpoint directory: first command-line argument if given, else "/data".
    checkpoint_path = sys.argv[1] if len(sys.argv) > 1 else "/data"

    result = load_checkpoint_with_correct_shapes(checkpoint_path)

    # A falsy result (None or an empty state dict) is reported as failure.
    print(
        "\n✓ Success! Checkpoint loaded successfully"
        if result
        else "\n✗ Failed to load checkpoint"
    )