#!/usr/bin/env python3
"""
Utility script to print model parameter names for debugging LoRA compatibility issues.
"""

import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


def _normalize_peft_name(name):
    """
    Map a PEFT/LoRA-wrapped parameter name back onto the base model's naming.

    PEFT exposes the wrapped model under a ``base_model.model.`` prefix. LoRA
    layers keep the original weights in a ``base_layer`` submodule. For example,
    ``model.layers.0.self_attn.q_proj.weight`` is reported as
    ``base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight``.
    Names without these markers are returned unchanged.
    """
    prefix = "base_model.model."
    if name.startswith(prefix):
        name = name[len(prefix):]
    return name.replace(".base_layer.", ".")


def _lora_specific_names(base_names, peft_names):
    """
    Return the PEFT parameter names that have no counterpart in the base model.

    Args:
        base_names (iterable of str): Parameter names of the unwrapped model.
        peft_names (iterable of str): Parameter names after applying adapters.

    Returns:
        list of str: The PEFT names (as reported by the wrapped model) whose
        normalized form is absent from ``base_names``, sorted for stable output.
    """
    known = set(base_names)
    return sorted(
        name for name in peft_names if _normalize_peft_name(name) not in known
    )


def print_model_parameters(model_path, lora_path=None):
    """
    Print all parameter names in a model, optionally with LoRA adapters.

    The base model is loaded with ``trust_remote_code=True`` in bfloat16 and
    every parameter name and shape is printed. If ``lora_path`` exists on
    disk, the adapters are applied via ``PeftModel.from_pretrained``. The
    wrapped model's parameters are then printed, followed by only those
    parameters the adapters introduced.

    Args:
        model_path (str): Path to the base model
        lora_path (str, optional): Path to LoRA adapters

    Returns:
        The loaded model, wrapped in a ``PeftModel`` when adapters were applied.
    """
    print(f"Loading model from: {model_path}")

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    print("\n=== Base Model Parameters ===")
    # Snapshot the base names before PEFT wraps (and modifies) the model.
    param_names = []
    for name, param in model.named_parameters():
        param_names.append(name)
        print(f"{name}: {param.shape}")

    if not lora_path:
        print("\nNo LoRA path provided")
        return model
    if not os.path.exists(lora_path):
        print(f"\nLoRA path not found: {lora_path}")
        return model

    print(f"\n=== Loading LoRA Adapters from: {lora_path} ===")
    model = PeftModel.from_pretrained(model, lora_path)

    print("\n=== Model with LoRA Parameters ===")
    peft_param_names = []
    for name, param in model.named_parameters():
        peft_param_names.append(name)
        print(f"{name}: {param.shape}")

    # A raw set difference would list every parameter, because PEFT renames
    # all of them; compare normalized names instead.
    print("\n=== LoRA Specific Parameters ===")
    for name in _lora_specific_names(param_names, peft_param_names):
        print(name)

    return model


def print_embedding_layer_details(model):
    """
    Print detailed information about embedding layers.

    Every submodule whose qualified name contains "embed" (case-insensitive)
    is listed with its type. Its weight shape, ``num_embeddings`` and
    ``embedding_dim`` are also printed when the module has them.

    Args:
        model: A ``torch.nn.Module`` whose submodules are inspected.
    """
    print("\n=== Embedding Layer Details ===")
    for name, module in model.named_modules():
        # "embedding" contains "embed", so a single substring test suffices.
        if "embed" not in name.lower():
            continue
        print(f"Module: {name} - Type: {type(module)}")
        # ``weight`` can be registered as None (e.g. a LayerNorm built with
        # elementwise_affine=False), which has no ``.shape``.
        weight = getattr(module, "weight", None)
        if weight is not None:
            print(f"  Weight shape: {weight.shape}")
        if hasattr(module, "num_embeddings"):
            print(f"  Num embeddings: {module.num_embeddings}")
        if hasattr(module, "embedding_dim"):
            print(f"  Embedding dim: {module.embedding_dim}")


def compare_with_vllm_expectations(model, max_shown=3):
    """
    Compare model parameter names with typical vLLM expectations.

    For each expected name pattern, print up to ``max_shown`` parameter names
    that contain it, plus a count of the rest. Print a warning when no
    parameter name contains the pattern.

    Args:
        model: A ``torch.nn.Module`` whose parameter names are checked.
        max_shown (int): Matches printed per pattern before summarizing the
            remainder. Defaults to 3.
    """
    print("\n=== vLLM Compatibility Check ===")
    vllm_expected_patterns = [
        "model.embed_tokens.weight",
        "model.layers.",
        "model.norm.weight",
        "lm_head.weight"
    ]

    # remove_duplicate=False keeps a shared parameter under every name it is
    # registered as. With the default, a tied lm_head.weight (shared with the
    # input embedding) would be dropped and falsely reported missing.
    model_params = [
        name for name, _ in model.named_parameters(remove_duplicate=False)
    ]

    for pattern in vllm_expected_patterns:
        # Substring match; an exact name match is a special case of this.
        matches = [p for p in model_params if pattern in p]
        if not matches:
            print(f"Warning: No matches found for expected pattern '{pattern}'")
            continue
        print(f"Found expected pattern '{pattern}':")
        for match in matches[:max_shown]:
            print(f"  - {match}")
        if len(matches) > max_shown:
            print(f"  ... and {len(matches) - max_shown} more")


def main(argv=None):
    """
    Parse command-line arguments and run the requested inspections.

    Args:
        argv (list of str, optional): Arguments to parse instead of
            ``sys.argv[1:]``; useful when calling ``main`` from code.

    Raises:
        SystemExit: With status 2 (via argparse) for invalid arguments or a
            missing model path. With status 1 if loading or inspecting the
            model fails.
    """
    import sys
    import traceback

    parser = argparse.ArgumentParser(description="Print model parameter names for debugging")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the base model")
    parser.add_argument("--lora_path", type=str, help="Path to LoRA adapters (optional)")
    parser.add_argument("--check_vllm", action="store_true", help="Check vLLM compatibility")
    parser.add_argument("--detail_embed", action="store_true", help="Print detailed embedding info")

    args = parser.parse_args(argv)

    if not os.path.exists(args.model_path):
        # Reports on stderr with usage and exits with status 2.
        parser.error(f"Model path does not exist: {args.model_path}")

    try:
        model = print_model_parameters(args.model_path, args.lora_path)

        if args.detail_embed:
            print_embedding_layer_details(model)

        if args.check_vllm:
            compare_with_vllm_expectations(model)

    except Exception as e:
        # Top-level CLI boundary: report on stderr and signal failure to the
        # shell instead of exiting 0.
        print(f"Error: {e}", file=sys.stderr)
        traceback.print_exc()
        raise SystemExit(1)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()