"""Diagnostic script to check 8B config for potential issues."""

import math
import pickle
from pathlib import Path

# Load the converted config.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at checkpoints you produced yourself.
config_path = Path('checkpoints/llama-3.1-8b-instruct-flax/config.pkl')
print(f"Loading config from: {config_path}")
# Use a context manager so the file handle is closed deterministically
# (the previous open() call was never closed).
with open(config_path, 'rb') as config_file:
    config = pickle.load(config_file)

# Dump the loaded config as titled sections. Each value is read inside the
# loop, right before it is printed, so a missing attribute surfaces at the
# same point in the output as it would with one print per field.
_banner = "=" * 70
print("\n" + _banner)
print("LLAMA 3.1 8B CONFIG DIAGNOSTIC")
print(_banner)

# Fields computed from other config attributes rather than read directly.
_derived_fields = {
    "head_dim": lambda cfg: cfg.hidden_size // cfg.num_attention_heads,
}

# (section title, field names in display order)
_report_layout = (
    ("Architecture", (
        "vocab_size",
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "head_dim",
    )),
    ("RoPE Configuration", (
        "rope_theta",
        "rope_scaling",
        "max_position_embeddings",
    )),
    ("Normalization", ("rms_norm_eps",)),
    ("Weight Tying", ("tie_word_embeddings",)),
    ("Dtype Configuration", ("dtype", "param_dtype")),
)

for _section_title, _field_names in _report_layout:
    print(f"\n{_section_title}:")
    for _field_name in _field_names:
        _compute = _derived_fields.get(_field_name)
        _field_value = _compute(config) if _compute else getattr(config, _field_name)
        print(f"  {_field_name}: {_field_value}")

print("\n" + "="*70)
print("EXPECTED VALUES FOR LLAMA 3.1 8B INSTRUCT")
print("="*70)

# Reference values the converted config is checked against.
expected = {
    'vocab_size': 128256,
    'hidden_size': 4096,
    'intermediate_size': 14336,
    'num_hidden_layers': 32,
    'num_attention_heads': 32,
    'num_key_value_heads': 8,
    'rope_theta': 500000.0,
    'rms_norm_eps': 1e-5,
    'tie_word_embeddings': False,  # 8B should be False, 1B is True
}

# Sentinel so a missing attribute is reported as a mismatch instead of
# aborting the whole diagnostic with AttributeError.
_MISSING = object()


def _values_match(actual, target):
    """Return True if the config value *actual* matches the reference *target*.

    Float targets are compared with a relative tolerance: if a conversion
    step stored the value at reduced precision (e.g. as a float32 scalar),
    exact ``==`` against the float64 literal would report a false mismatch.
    A missing attribute or a non-numeric value for a float target never matches.
    """
    if actual is _MISSING:
        return False
    if isinstance(target, float) and not isinstance(actual, bool):
        try:
            return math.isclose(float(actual), target, rel_tol=1e-6)
        except (TypeError, ValueError):
            return False
    return bool(actual == target)


print("\nComparison:")
all_match = True
for key, expected_val in expected.items():
    actual_val = getattr(config, key, _MISSING)
    match = _values_match(actual_val, expected_val)
    symbol = "✓" if match else "✗"
    shown = "<missing>" if actual_val is _MISSING else actual_val
    print(f"  {symbol} {key}: {shown} {'==' if match else '!='} {expected_val}")
    if not match:
        all_match = False

print("\n" + "="*70)
if all_match:
    print("✓ All config values match expected!")
else:
    print("✗ MISMATCH DETECTED - this could cause inference issues")
print("="*70)

# Check rope scaling details
print("\nRoPE Scaling Details:")
if config.rope_scaling:
    rope_scaling = config.rope_scaling
    # Older-style configs store the scaling type under 'type' instead of 'rope_type'.
    rope_type = rope_scaling.get('rope_type') or rope_scaling.get('type')
    original_max_pos = rope_scaling.get('original_max_position_embeddings')
    print(f"  Type: {rope_type}")
    print(f"  Factor: {rope_scaling.get('factor')}")
    print(f"  Low freq factor: {rope_scaling.get('low_freq_factor')}")
    print(f"  High freq factor: {rope_scaling.get('high_freq_factor')}")
    print(f"  Original max pos: {original_max_pos}")
    if rope_type == 'llama3':
        # The reference llama3 rope implementation (HF Transformers
        # `_compute_llama3_parameters`) rescales the inverse frequencies once,
        # for every position; it is not a switch that turns on past the
        # original context window. The previous note claimed otherwise.
        print("\n  Note: llama3 RoPE scaling rescales the rotary frequencies for ALL")
        print(f"  positions - it is not gated on sequence length > {original_max_pos}.")
        print("  Your evaluation uses 16K-32K tokens, so scaling must be applied correctly")
else:
    print("  ✗ WARNING: No RoPE scaling configured!")
