"""Test if rope_scaling config affects perplexity."""

import pickle
import jax.numpy as jnp
import jax
from pathlib import Path
from transformers import AutoTokenizer

from fma_llama.model.llama import LlamaForCausalLM

# Location of the converted Flax checkpoint (pickled config + params).
CHECKPOINT_PATH = Path("checkpoints/llama-3.2-1b-flax")

# Hugging Face tokenizer matching the checkpoint.
TOKENIZER_NAME = 'meta-llama/Llama-3.2-1B'

# Upper bound on the evaluated context length, in tokens.
MAX_TOKENS = 2048

# rope_scaling settings to compare against rope_scaling = None.
HF_ROPE_SCALING = {
    'factor': 32.0,
    'high_freq_factor': 4.0,
    'low_freq_factor': 1.0,
    'original_max_position_embeddings': 8192,
    'rope_type': 'llama3',
}


def compute_perplexity(logits, input_ids):
    """Return the next-token perplexity of ``input_ids`` under ``logits``.

    Position ``t`` of ``logits`` is scored against token ``t + 1`` of
    ``input_ids``; the result is ``exp`` of the mean negative
    log-likelihood over all scored positions.

    Args:
        logits: Array indexed as (batch, sequence, vocab).
        input_ids: Integer array indexed as (batch, sequence).

    Returns:
        A scalar JAX array holding the perplexity.

    Raises:
        ValueError: If the sequence has fewer than two tokens, in which
            case there is no next-token target to score.
    """
    if input_ids.shape[1] < 2:
        raise ValueError(
            "Need at least 2 tokens to compute next-token perplexity, "
            f"got {input_ids.shape[1]}"
        )

    # Accumulate in float32 so a low-precision logits dtype does not
    # degrade the mean over a few thousand positions.
    shift_logits = logits[:, :-1, :].astype(jnp.float32)
    shift_labels = input_ids[:, 1:]

    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    # Gather the target log-prob directly instead of multiplying by a dense
    # one-hot mask: this avoids a (batch, seq, vocab) temporary and avoids
    # NaN from ``-inf * 0`` when a non-target logit is -inf.
    target_log_probs = jnp.take_along_axis(
        log_probs, shift_labels[..., None], axis=-1
    )
    return jnp.exp(-jnp.mean(target_log_probs))


def evaluate_perplexity(config, params, input_ids, rope_scaling):
    """Build the model with ``rope_scaling`` set and return its perplexity.

    Note: ``config.rope_scaling`` is overwritten in place.
    """
    config.rope_scaling = rope_scaling
    model = LlamaForCausalLM(config)
    # Presumably returns logits indexed as (batch, sequence, vocab).
    logits = model.apply(params, input_ids)
    return compute_perplexity(logits, input_ids)


def main():
    """Compare perplexity with rope_scaling disabled vs. the HF settings."""
    # Load model and params.
    # NOTE: pickle.load executes arbitrary code from the file; only use
    # checkpoints produced by a trusted conversion step.
    with open(CHECKPOINT_PATH / "config.pkl", 'rb') as f:
        config = pickle.load(f)
    with open(CHECKPOINT_PATH / "flax_params.pkl", 'rb') as f:
        params = pickle.load(f)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

    # Test text at 2K context: repeat a sentence and truncate to MAX_TOKENS.
    test_text = "The quick brown fox jumps over the lazy dog. " * 200  # ~2K tokens
    tokens = tokenizer.encode(test_text, add_special_tokens=False)[:MAX_TOKENS]
    input_ids = jnp.array([tokens])

    # Test with rope_scaling = None
    print("Testing with rope_scaling = None")
    perplexity_no_scaling = evaluate_perplexity(config, params, input_ids, None)
    print(f"Perplexity (no rope_scaling): {perplexity_no_scaling:.4f}")

    # Test with rope_scaling from HF config
    print("\nTesting with rope_scaling from HF")
    perplexity_with_scaling = evaluate_perplexity(
        config, params, input_ids, dict(HF_ROPE_SCALING)
    )
    print(f"Perplexity (with rope_scaling config): {perplexity_with_scaling:.4f}")
    print(f"\nDifference: {perplexity_with_scaling - perplexity_no_scaling:.4f}")


if __name__ == "__main__":
    main()
