#!/usr/bin/env python
"""
Debug script to understand Llama-3 GQA impact on quantization
"""

import torch
from transformers import AutoModelForCausalLM, AutoConfig
import sys

def _is_power_of_two(dim):
    """Return True when *dim* is a positive power of two."""
    # The bit trick alone reports 0 as a power of two, so require dim > 0.
    return dim > 0 and (dim & (dim - 1)) == 0


def analyze_model_structure(model_name):
    """Print the attention/MLP dimensions of a model and flag non-power-of-2 sizes.

    Loads the config and the model for ``model_name`` (float16, placed on
    CPU), then prints head counts, the projection shapes of the first decoder
    layer, and whether each relevant dimension is a power of two.

    Args:
        model_name: Model identifier or path passed to ``from_pretrained``.

    Returns:
        The config object loaded for ``model_name``.
    """
    print(f"\nAnalyzing {model_name}")
    print("=" * 60)

    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu"
    )

    # Only the first decoder layer is inspected; its projection shapes are
    # what gets reported below.
    layer = model.model.layers[0]
    attn = layer.self_attn
    mlp = layer.mlp

    # Prefer an explicit head_dim when the config carries one; otherwise fall
    # back to the hidden_size / num_heads derivation.
    head_dim = (
        getattr(config, "head_dim", None)
        or config.hidden_size // config.num_attention_heads
    )

    print(f"Hidden size: {config.hidden_size}")
    print(f"Num attention heads: {config.num_attention_heads}")
    print(f"Num KV heads: {getattr(config, 'num_key_value_heads', config.num_attention_heads)}")
    print(f"Head dim: {head_dim}")

    # Check projection dimensions
    print("\nProjection dimensions:")
    print(f"  Q proj: {attn.q_proj.weight.shape}")
    print(f"  K proj: {attn.k_proj.weight.shape}")
    print(f"  V proj: {attn.v_proj.weight.shape}")
    print(f"  O proj: {attn.o_proj.weight.shape}")

    # Check MLP dimensions
    print("\nMLP dimensions:")
    print(f"  Gate proj: {mlp.gate_proj.weight.shape}")
    print(f"  Up proj: {mlp.up_proj.weight.shape}")
    print(f"  Down proj: {mlp.down_proj.weight.shape}")

    # Check if dimensions are power of 2
    dims_to_check = [
        ("Hidden", config.hidden_size),
        ("Q/O", attn.q_proj.weight.shape[0]),
        ("K/V", attn.k_proj.weight.shape[0]),
        ("MLP intermediate", config.intermediate_size),
    ]

    print("\nDimension analysis for butterfly:")
    for name, dim in dims_to_check:
        verdict = '✓ (power of 2)' if _is_power_of_two(dim) else '✗ (NOT power of 2)'
        print(f"  {name}: {dim} {verdict}")

    return config

if __name__ == "__main__":
    # Analyze both models
    config_7b = analyze_model_structure("meta-llama/Llama-2-7b-hf")
    config_8b = analyze_model_structure("meta-llama/Meta-Llama-3-8B")

    # Derive every figure used in the summary from the loaded configs so the
    # text below cannot drift from the numbers printed by the analysis above.
    q_heads_7b = config_7b.num_attention_heads
    kv_heads_7b = getattr(config_7b, 'num_key_value_heads', q_heads_7b)
    q_heads_8b = config_8b.num_attention_heads
    kv_heads_8b = getattr(config_8b, 'num_key_value_heads', q_heads_8b)

    hidden_8b = config_8b.hidden_size
    head_dim_8b = getattr(config_8b, 'head_dim', None) or hidden_8b // q_heads_8b
    kv_dim_8b = kv_heads_8b * head_dim_8b
    mlp_8b = config_8b.intermediate_size
    mlp_is_pow2 = mlp_8b > 0 and (mlp_8b & (mlp_8b - 1)) == 0

    print("\n" + "=" * 60)
    print("KEY DIFFERENCES:")
    print("=" * 60)

    print("\n1. GQA (Grouped Query Attention):")
    print(f"   Llama-2-7B: {q_heads_7b} Q heads, {kv_heads_7b} KV heads")
    print(f"   Llama-3-8B: {q_heads_8b} Q heads, {kv_heads_8b} KV heads")
    if q_heads_8b == kv_heads_8b:
        print("   → Llama-3 uses standard multi-head attention (no grouping)")
    elif q_heads_8b % kv_heads_8b == 0:
        print(f"   → Llama-3 uses {q_heads_8b // kv_heads_8b}:1 grouped attention!")
    else:
        print(f"   → Llama-3 uses {q_heads_8b}:{kv_heads_8b} grouped attention!")

    print("\n2. MLP size:")
    print(f"   Llama-2-7B: {config_7b.intermediate_size}")
    print(f"   Llama-3-8B: {config_8b.intermediate_size}")

    print("\n3. Impact on rotation:")
    if kv_dim_8b != hidden_8b:
        print(f"   - K/V projections in Llama-3 are only {kv_dim_8b}-dim (not {hidden_8b})")
        print("   - This affects how rotations should be applied")
    else:
        print(f"   - K/V projections in Llama-3 are full width ({kv_dim_8b}-dim)")
    if mlp_is_pow2:
        print(f"   - MLP size {mlp_8b} is a power of 2 (fine for butterfly)")
    else:
        print(f"   - MLP size {mlp_8b} is not power of 2 (problematic for butterfly)")