"""Evaluate model on LongBench v2 with FMA attention comparison."""

import argparse
from pathlib import Path
import json
import pickle
from typing import Dict, List, Tuple
import math

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm

from fma_llama.model.config import LlamaConfig
from fma_llama.model.llama import LlamaForCausalLM


# ============================================================================
# COMPONENT 1: PROMPT CONSTRUCTION & TOKENIZATION
# ============================================================================

def build_prompt(
    context: str,
    question: str,
    choice_a: str,
    choice_b: str,
    choice_c: str,
    choice_d: str,
) -> str:
    """Assemble the multiple-choice prompt for one LongBench v2 item.

    Args:
        context: Document/context text, placed under the "Document:" header
        question: Question text
        choice_a/b/c/d: Text of the four options, listed as "A." through "D."

    Returns:
        Prompt string: instruction line, document, question, the four
        lettered choices and the answer-format request (no trailing newline)
    """
    # One entry per output line; empty strings produce the blank separator lines.
    lines = (
        "Read the following document and answer the question.",
        "",
        "Document:",
        f"{context}",
        "",
        f"Question: {question}",
        f"A. {choice_a}",
        f"B. {choice_b}",
        f"C. {choice_c}",
        f"D. {choice_d}",
        "",
        'Please answer with "The correct answer is (insert answer here).".',
    )
    return "\n".join(lines)


def tokenize_sample(
    sample: Dict,
    tokenizer,
    answer_choice: str,
    use_chat_format: bool = True,
) -> Tuple[List[int], int]:
    """Tokenize a LongBench v2 sample with the complete answer string appended.

    The tokenized text ends with "The correct answer is <choice>. ", so the
    answer letter is followed by the tokens for ". " and is NOT the last
    token. The returned eval position points at the answer-letter token
    itself. It is found by also tokenizing the text up to and including
    "The correct answer is" and counting how many leading tokens the two
    sequences share (this costs a second tokenization of the prompt).

    Args:
        sample: Dictionary with context, question, choices
        tokenizer: HuggingFace tokenizer
        answer_choice: The answer choice (A/B/C/D) to append
        use_chat_format: Whether to apply chat template (for instruct models)

    Returns:
        Tuple of (token_ids, eval_position)
        token_ids is the full sequence including the answer string;
        eval_position is the index of the answer-letter token in token_ids
    """
    # Build base prompt
    prompt = build_prompt(
        context=sample['context'],
        question=sample['question'],
        choice_a=sample['choice_A'],
        choice_b=sample['choice_B'],
        choice_c=sample['choice_C'],
        choice_d=sample['choice_D'],
    )

    if use_chat_format:
        # Apply chat template for instruction-tuned models
        messages = [{"role": "user", "content": prompt}]
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,  # Adds assistant header
            tokenize=False,
        )
        prefix_text = formatted_prompt + "The correct answer is"
        # Special tokens are already part of the templated text
        add_special_tokens = False
    else:
        prefix_text = prompt + "\n\nThe correct answer is"
        add_special_tokens = True

    # Append complete answer string with proper formatting
    full_text = prefix_text + f" {answer_choice}. "
    token_ids = tokenizer.encode(full_text, add_special_tokens=add_special_tokens)
    prefix_ids = tokenizer.encode(prefix_text, add_special_tokens=add_special_tokens)

    # The answer letter is the first token at which the full sequence departs
    # from the prefix-only sequence. Comparing token by token (instead of
    # using len(prefix_ids) directly) stays correct even if the tokenizer
    # appends an end-of-sequence token to the prefix.
    eval_position = 0
    for prefix_id, full_id in zip(prefix_ids, token_ids):
        if prefix_id != full_id:
            break
        eval_position += 1

    # Keep the index inside the sequence even for a degenerate tokenizer
    eval_position = min(eval_position, len(token_ids) - 1)

    return token_ids, eval_position


def test_prompt_construction():
    """Test prompt construction and tokenization.

    Downloads LongBench v2 and the Llama 3.1 tokenizer, derives the token ID
    used for each answer letter, then tokenizes the first three samples and
    checks that the ground-truth answer token sits where expected.

    The answer string is "The correct answer is X. ", so the answer letter is
    followed by the tokens for ". ". The check therefore looks at the token
    that precedes that suffix rather than at the very last token.
    """
    print("="*70)
    print("COMPONENT 1 TEST: Prompt Construction & Tokenization")
    print("="*70)

    # Load dataset
    print("\nLoading LongBench v2...")
    dataset = load_dataset('THUDM/LongBench-v2', split='train')

    # Load tokenizer
    tokenizer_name = 'meta-llama/Llama-3.1-8B-Instruct'
    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Get answer token IDs by tokenizing complete answer strings
    print("\nDetermining answer token IDs:")
    answer_tokens = {}
    # Number of tokens that follow the answer letter (the ". " suffix), per choice
    trailing_counts = {}

    # The dummy prompt and its templated form are the same for every choice,
    # so build them once instead of once per loop iteration.
    prompt = build_prompt(
        context="[context]",  # Dummy context
        question="[question]",
        choice_a="[choice_a]",
        choice_b="[choice_b]",
        choice_c="[choice_c]",
        choice_d="[choice_d]",
    )
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Tokenize without the answer to find where it starts
    # Don't include the choice letter at all
    prefix_text = formatted + "The correct answer is"  # No trailing space!
    prefix_tokens = tokenizer.encode(prefix_text, add_special_tokens=False)

    for choice in ['A', 'B', 'C', 'D']:
        # Tokenize the full sequence with answer (including period and space)
        full_text = formatted + f"The correct answer is {choice}. "
        tokens = tokenizer.encode(full_text, add_special_tokens=False)

        # The answer token (e.g., " A", " B", etc.) is at len(prefix_tokens)
        answer_tokens[choice] = tokens[len(prefix_tokens)]
        trailing_counts[choice] = len(tokens) - len(prefix_tokens) - 1

        # Also decode to verify
        decoded = tokenizer.decode([answer_tokens[choice]])
        print(f"  Answer '{choice}' -> token ID: {answer_tokens[choice]} (decodes to: {repr(decoded)})")

    print(f"\nAnswer token mapping: {answer_tokens}")

    # Test on first 3 samples
    print("\n" + "-"*70)
    for idx in range(min(3, len(dataset))):
        sample = dataset[idx]
        truth = sample['answer']

        print(f"\nSample {idx + 1}:")
        print(f"  Domain: {sample['domain']}")
        print(f"  Difficulty: {sample['difficulty']}")
        print(f"  Length category: {sample['length']}")
        print(f"  Ground truth: {truth}")

        # Tokenize with ground truth answer
        token_ids, eval_pos = tokenize_sample(
            sample, tokenizer, answer_choice=truth, use_chat_format=True
        )

        print(f"  Total tokens: {len(token_ids)}")
        print(f"  Eval position: {eval_pos}")

        # Show last part of prompt
        last_tokens = token_ids[-50:]
        decoded = tokenizer.decode(last_tokens)
        print(f"  Last 50 tokens decoded:")
        print(f"    {repr(decoded)}")

        # Locate the answer letter: it precedes the ". " suffix tokens, so
        # step back from the end by the suffix length measured above.
        answer_pos = len(token_ids) - 1 - trailing_counts.get(truth, 0)
        answer_token = token_ids[answer_pos]
        print(f"  Answer token at position {answer_pos}: {answer_token} -> '{tokenizer.decode([answer_token])}'")
        print(f"  Expected: should match answer_tokens['{truth}'] = {answer_tokens.get(truth, 'N/A')}")

        # Verify it matches
        if answer_token == answer_tokens.get(truth):
            print(f"  ✓ Match confirmed!")
        else:
            print(f"  ✗ MISMATCH - something is wrong!")

        # Show full prompt for first sample
        if idx == 0:
            print(f"\n  FULL PROMPT (Sample 1):")
            print(f"  {'-'*66}")
            full_prompt = tokenizer.decode(token_ids)
            # Truncate context for display
            if len(full_prompt) > 2000:
                print(f"  {full_prompt[:1000]}")
                print(f"  [...{len(full_prompt)-2000} chars truncated...]")
                print(f"  {full_prompt[-1000:]}")
            else:
                print(f"  {full_prompt}")
            print(f"  {'-'*66}")

    print("\n" + "="*70)
    print("Component 1 test complete!")
    print("="*70)
    print("\nVerify:")
    print("  1. Answer tokens A/B/C/D are single token IDs")
    print("  2. Prompt format looks correct")
    print("  3. Chat template is applied (should see <|start_header_id|> etc)")
    print("  4. Answer token sits just before the trailing '. ' tokens")


# ============================================================================
# COMPONENT 2: POWER-OF-2 BUCKETING
# ============================================================================

def filter_samples_by_length(
    dataset,
    tokenizer,
    min_length: int,
    max_length: int,
    max_samples: int = None,
) -> List[Dict]:
    """Filter samples to a specific power-of-2 length range.

    Each sample is rendered with the chat template plus the dummy answer
    string "The correct answer is A. " and tokenized; the resulting token
    count decides bucket membership. Iteration stops early once
    ``max_samples`` samples have been collected, so the printed overall
    distribution only covers the samples seen up to that point.

    Args:
        dataset: HuggingFace dataset
        tokenizer: Tokenizer to use
        min_length: Minimum sequence length (inclusive)
        max_length: Maximum sequence length (exclusive)
        max_samples: Maximum samples to keep (None or 0 for all)

    Returns:
        List of dicts with keys 'sample' (the dataset row), 'base_tokens'
        (tokens of the prompt up to "The correct answer is ") and
        'full_length' (token count including the dummy answer)
    """
    import statistics

    print(f"\nFiltering samples in range [{min_length}, {max_length})...")

    filtered = []
    length_distribution = []

    for sample in tqdm(dataset, desc="Filtering by length"):
        prompt = build_prompt(
            context=sample['context'],
            question=sample['question'],
            choice_a=sample['choice_A'],
            choice_b=sample['choice_B'],
            choice_c=sample['choice_C'],
            choice_d=sample['choice_D'],
        )

        messages = [{"role": "user", "content": prompt}]
        formatted = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        # Tokenize with a dummy answer to get full length
        # Use A as dummy since all answers tokenize to same number of tokens
        full_formatted = formatted + "The correct answer is A. "
        full_length = len(tokenizer.encode(full_formatted, add_special_tokens=False))

        length_distribution.append(full_length)

        # Skip out-of-range samples before doing any further tokenization
        if not (min_length <= full_length < max_length):
            continue

        # Store the prefix for later use. Only done for kept samples: the
        # prefix tokenization is as expensive as the full one.
        prefix_formatted = formatted + "The correct answer is "
        base_tokens = tokenizer.encode(prefix_formatted, add_special_tokens=False)

        filtered.append({
            'sample': sample,
            'base_tokens': base_tokens,
            'full_length': full_length,
        })

        if max_samples and len(filtered) >= max_samples:
            break

    print(f"\nBucketing Results:")
    print(f"  Total samples processed: {len(length_distribution)}")
    print(f"  Samples in range [{min_length}, {max_length}): {len(filtered)}")

    if filtered:
        lengths = [s['full_length'] for s in filtered]
        print(f"  Min length in bucket: {min(lengths):,}")
        print(f"  Max length in bucket: {max(lengths):,}")
        print(f"  Mean length in bucket: {sum(lengths)/len(lengths):,.1f}")

    # Show overall distribution. min()/max()/median are undefined for an
    # empty list, so an empty dataset is reported instead of crashing.
    if length_distribution:
        print(f"\nOverall length distribution:")
        print(f"  Min: {min(length_distribution):,}")
        print(f"  Max: {max(length_distribution):,}")
        print(f"  Median: {statistics.median(length_distribution):,.0f}")
    else:
        print(f"\nOverall length distribution: (no samples processed)")

    return filtered


def test_bucketing():
    """Smoke-test the length filter on the [2^14, 2^15) token bucket.

    Loads LongBench v2 and the Llama 3.1 tokenizer, keeps at most ten
    samples whose tokenized length falls in the bucket, and prints the
    metadata of the first five for manual inspection.
    """
    rule = "=" * 70
    print(rule)
    print("COMPONENT 2 TEST: Power-of-2 Bucketing")
    print(rule)

    print("\nLoading LongBench v2...")
    dataset = load_dataset('THUDM/LongBench-v2', split='train')

    tokenizer_name = 'meta-llama/Llama-3.1-8B-Instruct'
    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Bucket bounds: 16,384 (inclusive) to 32,768 (exclusive) tokens
    min_length, max_length = 2**14, 2**15
    max_samples = 10

    bucket = filter_samples_by_length(
        dataset,
        tokenizer,
        min_length=min_length,
        max_length=max_length,
        max_samples=max_samples,
    )

    print(f"\n{rule}")
    print(f"Filtered {len(bucket)} samples")
    print(rule)

    # Metadata fields printed verbatim from the dataset row, in display order
    row_fields = (
        ("ID", '_id'),
        ("Domain", 'domain'),
        ("Difficulty", 'difficulty'),
        ("Length category", 'length'),
    )

    print("\nSample details:")
    for number, entry in enumerate(bucket[:5], start=1):  # Only the first 5
        row = entry['sample']
        print(f"\n  Sample {number}:")
        for label, key in row_fields:
            print(f"    {label}: {row[key]}")
        print(f"    Tokenized length: {entry['full_length']:,}")
        print(f"    Ground truth: {row['answer']}")

    print(f"\n{rule}")
    print("Component 2 test complete!")
    print(rule)
    print("\nVerify:")
    print(f"  1. All samples are in range [{min_length:,}, {max_length:,})")
    print("  2. Sample metadata looks correct")
    print(f"  3. Got {len(bucket)} samples (capped at {max_samples})")


# ============================================================================
# COMPONENT 3: CYCLIC PADDING
# ============================================================================

def apply_cyclic_padding(
    full_tokens_with_answer: List[int],
    context_tokens: List[int],
    target_length: int,
    eval_position: int,
) -> Tuple[List[int], int]:
    """Extend a token sequence to ``target_length`` by repeating the context.

    The padding is appended AFTER the complete sequence, so slicing the
    result back to the original length recovers the input unchanged.

    Args:
        full_tokens_with_answer: Complete sequence including "The correct answer is X. "
        context_tokens: Just the context/document tokens (the material that is cycled)
        target_length: Target sequence length (power of 2)
        eval_position: Position of the answer token in full_tokens_with_answer

    Returns:
        Tuple of (padded_tokens, eval_position)
        - padded_tokens: Sequence of exactly target_length
        - eval_position: Unchanged when padding was added; clamped to
          target_length - 1 when the input had to be truncated

    Raises:
        ValueError: If padding is required but context_tokens is empty
    """
    shortfall = target_length - len(full_tokens_with_answer)

    # Already long enough: cut to size and keep the eval index in bounds.
    if shortfall <= 0:
        clipped_position = min(eval_position, target_length - 1)
        return full_tokens_with_answer[:target_length], clipped_position

    cycle_length = len(context_tokens)
    if cycle_length == 0:
        raise ValueError("Context tokens empty, cannot pad")

    # Fill the gap with as many whole copies of the context as fit,
    # followed by the leading part of one more copy.
    whole_copies, leftover = divmod(shortfall, cycle_length)
    filler = context_tokens * whole_copies + context_tokens[:leftover]

    padded_tokens = full_tokens_with_answer + filler

    assert len(padded_tokens) == target_length, \
        f"Length mismatch: {len(padded_tokens)} != {target_length}"

    return padded_tokens, eval_position


def extract_context_tokens(sample: Dict, tokenizer) -> List[int]:
    """Tokenize only the document text of a sample, for use as cyclic padding.

    Args:
        sample: Dataset sample; its 'context' entry is the text that is encoded
        tokenizer: Tokenizer providing ``encode``

    Returns:
        Token IDs of the context alone, encoded without special tokens
    """
    document_text = sample['context']
    return tokenizer.encode(document_text, add_special_tokens=False)


def test_cyclic_padding():
    """Test cyclic padding implementation.

    Loads LongBench v2 and the Llama 3.1 tokenizer, picks up to three samples
    in the [2^14, 2^15) token bucket, pads each to 2^15 tokens with
    ``apply_cyclic_padding`` and checks that:

    - the padded sequence has exactly the target length,
    - the ground-truth answer token is at the reported eval position,
    - cutting the padding off again yields the original token sequence.

    Raises:
        ValueError: If the answer token is not where the prefix tokenization
            says it should be
        AssertionError: If any of the structural checks above fails
    """
    print("="*70)
    print("COMPONENT 3 TEST: Cyclic Padding")
    print("="*70)

    # Load dataset
    print("\nLoading LongBench v2...")
    dataset = load_dataset('THUDM/LongBench-v2', split='train')

    # Load tokenizer
    tokenizer_name = 'meta-llama/Llama-3.1-8B-Instruct'
    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Get answer tokens. The dummy prompt and the prefix up to the answer
    # letter do not depend on the choice, so they are built once.
    dummy_prompt = build_prompt("[ctx]", "[q]", "[a]", "[b]", "[c]", "[d]")
    dummy_messages = [{"role": "user", "content": dummy_prompt}]
    dummy_formatted = tokenizer.apply_chat_template(
        dummy_messages, add_generation_prompt=True, tokenize=False
    )
    # Tokenize without the choice letter to find where answer token starts
    dummy_prefix_text = dummy_formatted + "The correct answer is"  # No trailing space!
    dummy_prefix_tokens = tokenizer.encode(dummy_prefix_text, add_special_tokens=False)

    answer_tokens = {}
    for choice in ['A', 'B', 'C', 'D']:
        full_text = dummy_formatted + f"The correct answer is {choice}. "
        tokens = tokenizer.encode(full_text, add_special_tokens=False)
        answer_tokens[choice] = tokens[len(dummy_prefix_tokens)]

    # Filter samples
    min_length = 2**14
    max_length = 2**15
    target_length = 2**15  # Pad to 32,768

    filtered = filter_samples_by_length(
        dataset, tokenizer, min_length, max_length, max_samples=3
    )

    print(f"\n{'='*70}")
    print(f"Testing cyclic padding on {len(filtered)} samples")
    print(f"Target length: {target_length:,}")
    print(f"{'='*70}")

    # Test padding on each sample
    for i, item in enumerate(filtered):
        sample = item['sample']
        original_length = item['full_length']

        print(f"\n{'─'*70}")
        print(f"Sample {i+1}:")
        print(f"  Original length: {original_length:,}")
        print(f"  Padding needed: {target_length - original_length:,}")

        # Extract context tokens
        context_tokens = extract_context_tokens(sample, tokenizer)
        print(f"  Context tokens: {len(context_tokens):,}")

        # Get full sequence with answer
        answer_choice = sample['answer']
        full_tokens, _ = tokenize_sample(sample, tokenizer, answer_choice, use_chat_format=True)

        # To find the eval position, we need to tokenize without the answer
        # and see where the answer token would be inserted
        prompt = build_prompt(
            context=sample['context'],
            question=sample['question'],
            choice_a=sample['choice_A'],
            choice_b=sample['choice_B'],
            choice_c=sample['choice_C'],
            choice_d=sample['choice_D'],
        )
        messages = [{"role": "user", "content": prompt}]
        formatted = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        # Tokenize up to but not including the answer letter
        prefix_text = formatted + "The correct answer is"
        prefix_tokens = tokenizer.encode(prefix_text, add_special_tokens=False)

        # The answer token is at position len(prefix_tokens)
        eval_pos = len(prefix_tokens)

        # Verify it's correct
        answer_token = answer_tokens[answer_choice]
        if eval_pos < len(full_tokens) and full_tokens[eval_pos] == answer_token:
            print(f"  ✓ Answer token '{answer_choice}' (ID {answer_token}) at position {eval_pos}")
        else:
            print(f"  ERROR: Position {eval_pos} has token {full_tokens[eval_pos] if eval_pos < len(full_tokens) else 'OUT_OF_BOUNDS'}")
            print(f"  Expected: {answer_token}")
            print(f"  Last 10 tokens: {full_tokens[-10:]}")
            raise ValueError("Answer token position mismatch")

        # Apply padding
        padded_tokens, eval_pos = apply_cyclic_padding(
            full_tokens_with_answer=full_tokens,
            context_tokens=context_tokens,
            target_length=target_length,
            eval_position=eval_pos,
        )

        print(f"  Padded length: {len(padded_tokens):,}")
        print(f"  Eval position: {eval_pos:,}")
        assert len(padded_tokens) == target_length

        # Verify structure
        print(f"\n  First 50 tokens (should be start of prompt):")
        first_50 = tokenizer.decode(padded_tokens[:50])
        print(f"    {repr(first_50[:200])}")

        # Show tokens around the answer position; the printed bounds are the
        # clamped ones actually used for the slice
        window_start = max(0, eval_pos - 20)
        window_end = min(len(padded_tokens), eval_pos + 20)
        print(f"\n  Tokens around answer position [{window_start}:{window_end}]:")
        around_answer = tokenizer.decode(padded_tokens[window_start:window_end])
        print(f"    {repr(around_answer)}")

        print(f"\n  Last 50 tokens (should be padding - cyclic context):")
        last_50 = tokenizer.decode(padded_tokens[-50:])
        print(f"    {repr(last_50[:200])}")

        # Show padding region. Padding starts after the COMPLETE answer
        # string, i.e. after the ". " tokens that follow the answer letter,
        # not directly after the answer token.
        padding_start = len(full_tokens)
        padding_end = len(padded_tokens)
        padding_length = padding_end - padding_start

        print(f"\n  Padding region: tokens [{padding_start:,} : {padding_end:,}]")
        print(f"  Padding length: {padding_length:,}")

        if padding_length > 0:
            # Show a snippet from padding
            snippet_start = padding_start
            snippet_end = min(snippet_start + 50, padding_end)
            padding_snippet = tokenizer.decode(padded_tokens[snippet_start:snippet_end])
            print(f"  First 50 tokens of padding:")
            print(f"    {repr(padding_snippet[:200])}")

            # Verify it's actually context content
            context_start = tokenizer.decode(context_tokens[:50])
            print(f"  Original context start (for comparison):")
            print(f"    {repr(context_start[:200])}")

        # Verify answer token is at eval position
        answer_at_pos = padded_tokens[eval_pos]
        expected_token = answer_tokens[sample['answer']]
        print(f"\n  Token at eval position {eval_pos}: {answer_at_pos} (expected: {expected_token})")
        print(f"  Decoded: '{tokenizer.decode([answer_at_pos])}'")
        assert answer_at_pos == expected_token, "Answer token mismatch!"
        print(f"  ✓ Answer token verified!")

        # Verify truncation gives original sequence
        truncated = padded_tokens[:len(full_tokens)]
        assert truncated == full_tokens, "Removing the padding did not recover the original sequence!"
        print(f"\n  Truncated length (removing padding): {len(truncated):,}")
        print(f"  Original length: {original_length:,}")
        print(f"  ✓ Truncation verified! (answer at position {eval_pos}, padding after)")

    print("\n" + "="*70)
    print("Component 3 test complete!")
    print("="*70)
    print("\nVerify:")
    print(f"  1. All padded sequences are exactly {target_length:,} tokens")
    print(f"  2. Padding consists of cyclic context repetition")
    print(f"  3. Question/choices/answer are intact before the padding")
    print(f"  4. Token at eval position matches ground truth answer")


# ============================================================================
# COMPONENT 4: MODEL LOADING & SHARDING
# ============================================================================

def shard_params(params, mesh):
    """Shard loaded parameters across the 'model' mesh axis based on their path.

    Nested dicts are walked recursively. Every array leaf (JAX or NumPy) is
    placed on the mesh with a PartitionSpec chosen from the names in its
    path; anything else (strings, None, ...) is returned untouched.

    Args:
        params: Loaded parameter tree (nested dicts of arrays)
        mesh: Device mesh; must define a 'model' axis

    Returns:
        Tree with the same dict structure whose array leaves are JAX arrays
        placed according to the chosen sharding
    """
    import numpy as np

    def get_sharding_for_param(path, array):
        """Determine sharding based on parameter path."""
        # The rules below name two dimensions, so they cannot be applied to
        # scalars or vectors; those are always replicated.
        if array.ndim < 2:
            return P()

        path_str = '/'.join(str(p) for p in path)
        is_kernel = 'kernel' in path_str

        # MLP gate/up projections: shard output dim
        if 'gate_proj' in path_str or 'up_proj' in path_str:
            if is_kernel:
                return P(None, 'model')
        # MLP down projection: shard input dim
        elif 'down_proj' in path_str:
            if is_kernel:
                return P('model', None)
        # Attention projections: shard on heads/hidden dim
        elif 'Dense' in path_str:  # Q/K/V/O projections
            if is_kernel:
                return P(None, 'model')
        # Embeddings: shard hidden dim
        elif 'Embed' in path_str or 'embedding' in path_str:
            return P(None, 'model')

        # Default: replicate. An empty spec is valid for every rank, unlike
        # P(None,), which is rejected for rank-0 arrays.
        return P()

    def shard_tree(tree, path=()):
        """Recursively shard parameter tree."""
        if isinstance(tree, dict):
            return {k: shard_tree(v, path + (k,)) for k, v in tree.items()}
        # NumPy arrays are accepted too: a parameter tree restored from a
        # pickle may hold host arrays, which would otherwise be skipped.
        if isinstance(tree, (jnp.ndarray, np.ndarray)):
            spec = get_sharding_for_param(path, tree)
            return jax.device_put(tree, NamedSharding(mesh, spec))
        return tree

    return shard_tree(params)


def load_model_and_config(checkpoint_dir: str, mesh=None):
    """Restore the pickled config and parameters and build the Flax model.

    Reads ``config.pkl`` and ``flax_params.pkl`` from the checkpoint
    directory. Both files are unpickled, which can execute arbitrary code,
    so only point this at checkpoints you trust.

    Args:
        checkpoint_dir: Path to checkpoint
        mesh: Optional device mesh; when given, the parameters are passed
            through shard_params before being returned

    Returns:
        Tuple of (model, params, config)
    """
    root = Path(checkpoint_dir)
    config_file = root / "config.pkl"
    params_file = root / "flax_params.pkl"

    print(f"Loading config from {config_file}")
    with open(config_file, 'rb') as handle:
        config = pickle.load(handle)

    print(f"Loading parameters from {params_file}")
    with open(params_file, 'rb') as handle:
        params = pickle.load(handle)

    # Summarise the architecture; attributes are read one at a time, in
    # display order.
    print("Creating model with config:")
    summary_fields = (
        ("Hidden size", "hidden_size"),
        ("Num layers", "num_hidden_layers"),
        ("Num heads", "num_attention_heads"),
        ("Vocab size", "vocab_size"),
    )
    for label, attribute in summary_fields:
        print(f"  {label}: {getattr(config, attribute)}")

    model = LlamaForCausalLM(config)

    # Without a mesh the parameters are handed back exactly as loaded.
    if mesh is None:
        return model, params, config

    print("Sharding parameters across devices...")
    return model, shard_params(params, mesh), config


def test_model_loading():
    """Smoke-test checkpoint loading and parameter sharding.

    Builds a (1, device_count) mesh with axes ('data', 'model'), loads the
    checkpoint through load_model_and_config, then prints the parameter
    tree's top-level keys, the total parameter count and the shapes of the
    first ten array leaves for manual inspection.
    """
    rule = "=" * 70
    print(rule)
    print("COMPONENT 4 TEST: Model Loading & Sharding")
    print(rule)

    # Device mesh: a single 'data' row, all devices along 'model'
    print("\nSetting up device mesh...")
    device_grid = mesh_utils.create_device_mesh((1, jax.device_count()))
    mesh = Mesh(device_grid, axis_names=('data', 'model'))
    print(f"Mesh shape: {mesh.shape}")
    print(f"Devices: {jax.devices()}")
    jax.set_mesh(mesh)

    checkpoint_dir = 'checkpoints/llama-3.2-1b-flax'
    print(f"\nLoading model from {checkpoint_dir}")

    model, params, config = load_model_and_config(checkpoint_dir, mesh=mesh)

    print(f"\n{rule}")
    print("Model loaded successfully!")
    print(rule)

    print("\nParameter structure:")
    print(f"  Type: {type(params)}")
    print(f"  Keys: {list(params.keys()) if isinstance(params, dict) else 'N/A'}")

    def count_params(node):
        """Return the number of elements in all array leaves below node."""
        if isinstance(node, dict):
            return sum(count_params(child) for child in node.values())
        if isinstance(node, jnp.ndarray):
            return node.size
        if hasattr(node, 'shape'):  # Array-like leaves that are not jnp arrays
            import numpy as np
            return np.prod(node.shape)
        return 0

    total_params = count_params(params)
    print(f"\nTotal parameters: {total_params:,} ({total_params/1e6:.1f}M)")

    def iter_shaped_leaves(node, prefix=''):
        """Yield (path, shape) for each array leaf, depth-first in dict order."""
        if isinstance(node, dict):
            for key, child in node.items():
                child_path = f"{prefix}/{key}" if prefix else key
                yield from iter_shaped_leaves(child, child_path)
        elif isinstance(node, jnp.ndarray) or hasattr(node, 'shape'):
            yield prefix, (node.shape if hasattr(node, 'shape') else 'unknown')

    print("\nSample parameter shapes (first 10):")
    # zip with range() stops the traversal after ten leaves
    for _, (leaf_path, leaf_shape) in zip(range(10), iter_shaped_leaves(params)):
        print(f"  {leaf_path}: {leaf_shape}")

    print(f"\n{rule}")
    print("Component 4 test complete!")
    print(rule)
    print("\nVerify:")
    print("  1. Model loaded without errors")
    print("  2. Parameters are sharded across devices")
    print("  3. Parameter count looks reasonable")


# ============================================================================
# COMPONENT 5: LOGIT EXTRACTION FOR MULTIPLE CHOICE
# ============================================================================

def evaluate_sample_logits(
    model,
    params,
    input_tokens: List[int],
    eval_position: int,
    answer_tokens: Dict[str, int],
    mesh,
    tokenizer=None,
) -> Dict:
    """Evaluate a single sample by extracting logits at eval position.

    Runs one forward pass over the whole sequence and reads the logits that
    predict the token at ``eval_position`` (i.e. the output at index
    ``eval_position - 1``). The prediction is the choice among A/B/C/D whose
    answer token has the highest logit there.

    Args:
        model: Flax model; ``model.apply(params, input_ids)`` must return
            logits indexed as (batch, seq_len, vocab)
        params: Model parameters
        input_tokens: Tokenized input sequence
        eval_position: Position where answer token is; must satisfy
            1 <= eval_position <= len(input_tokens)
        answer_tokens: Mapping from answer choice to token ID; must contain
            the keys 'A', 'B', 'C' and 'D'
        mesh: Device mesh; must define a 'data' axis
        tokenizer: Optional tokenizer for decoding top predictions

    Returns:
        Dictionary with keys:
        - 'logits': answer choice -> logit of its token
        - 'probabilities': softmax over the A/B/C/D logits
        - 'prediction': choice with the highest probability
        - 'top_k_predictions': top-10 vocabulary entries (empty list when no
          tokenizer is given)
        - 'argmax_token_id' / 'argmax_decoded': most likely token overall
          (decoded form is None when no tokenizer is given)

    Raises:
        ValueError: If eval_position is outside [1, len(input_tokens)]
    """
    seq_len = len(input_tokens)
    # Guard the index explicitly: position 0 would wrap around to the last
    # logit via index -1, and JAX clamps out-of-range indices instead of
    # raising, so a bad position would silently read the wrong logits.
    if not 1 <= eval_position <= seq_len:
        raise ValueError(
            f"eval_position must be in [1, {seq_len}], got {eval_position}"
        )

    # Convert to JAX array and add batch dimension
    input_ids = jnp.array([input_tokens])

    # Shard input
    input_sharding = NamedSharding(mesh, P('data', None))
    input_ids = jax.device_put(input_ids, input_sharding)

    # Forward pass
    logits = model.apply(params, input_ids)  # (batch, seq_len, vocab)

    # Extract logits to predict token at eval position
    # logits[i] predicts token at position i+1, so we need logits[eval_position - 1]
    logits_at_pos = logits[0, eval_position - 1, :]  # (vocab,)

    # Top-k predictions across the full vocabulary are only reported in
    # decoded form, so the sort is skipped when there is no tokenizer.
    top_k_info = []
    if tokenizer is not None:
        top_k = 10
        top_k_indices = jnp.argsort(logits_at_pos)[-top_k:][::-1]  # Descending order
        top_k_logits = logits_at_pos[top_k_indices]
        for idx, logit in zip(top_k_indices, top_k_logits):
            token_id = int(idx)
            top_k_info.append({
                'token_id': token_id,
                'logit': float(logit),
                'decoded': tokenizer.decode([token_id]),
            })

    # Extract logits for answer tokens
    answer_logits = {
        choice: float(logits_at_pos[token_id])
        for choice, token_id in answer_tokens.items()
    }

    # Compute probabilities over answer choices
    choices = ['A', 'B', 'C', 'D']
    answer_logits_array = jnp.array([answer_logits[c] for c in choices])
    answer_probs = jax.nn.softmax(answer_logits_array)
    answer_probs_dict = {
        choice: float(prob)
        for choice, prob in zip(choices, answer_probs)
    }

    # Prediction is argmax
    prediction = choices[int(jnp.argmax(answer_probs))]

    # Overall argmax
    argmax_token_id = int(jnp.argmax(logits_at_pos))
    argmax_decoded = tokenizer.decode([argmax_token_id]) if tokenizer else None

    return {
        'logits': answer_logits,
        'probabilities': answer_probs_dict,
        'prediction': prediction,
        'top_k_predictions': top_k_info,
        'argmax_token_id': argmax_token_id,
        'argmax_decoded': argmax_decoded,
    }


def test_logit_extraction(max_samples=5, checkpoint_dir='checkpoints/llama-3.2-1b-flax'):
    """Smoke-test multiple-choice logit extraction on a few samples.

    Builds a (1, device_count) device mesh, loads the model, selects samples
    via ``filter_samples_by_length`` (min_length=2**14, max_length=2**15),
    pads each one to 2**15 tokens with ``apply_cyclic_padding`` and scores it
    with ``evaluate_sample_logits``. A detailed report is printed per sample,
    followed by an accuracy summary.

    Args:
        max_samples: Maximum number of samples to evaluate
        checkpoint_dir: Path to model checkpoint

    Returns:
        Tuple of (accuracy, results). ``results`` holds one dict per sample
        with keys 'sample_id', 'ground_truth', 'prediction', 'correct' and
        'probabilities'. ``accuracy`` is 0 when no sample was evaluated.
    """
    from collections import Counter

    choices = ['A', 'B', 'C', 'D']
    rule = "=" * 70

    print(rule)
    print("COMPONENT 5 TEST: Logit Extraction for Multiple Choice")
    print(rule)

    # Setup mesh
    print("\nSetting up device mesh...")
    devices = mesh_utils.create_device_mesh((1, jax.device_count()))
    mesh = Mesh(devices, axis_names=('data', 'model'))
    jax.set_mesh(mesh)

    # Load model (the returned config is not needed here)
    print(f"Loading model from {checkpoint_dir}")
    model, params, _ = load_model_and_config(checkpoint_dir, mesh=mesh)

    # Load dataset and tokenizer
    print("\nLoading dataset and tokenizer...")
    dataset = load_dataset('THUDM/LongBench-v2', split='train')
    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.1-8B-Instruct')

    # Get answer tokens - ensure we use tokens with leading space.
    # The answer should have a leading space: "The correct answer is A".
    # When tokenized, "is" ends, then " A" is the next token, so the answer
    # token is the first one past the tokenized prefix.
    print("\nDetermining answer tokens:")
    # The placeholder prompt and its tokenized prefix are identical for every
    # choice, so they are built once instead of once per choice.
    placeholder_prompt = build_prompt("[ctx]", "[q]", "[a]", "[b]", "[c]", "[d]")
    placeholder_formatted = tokenizer.apply_chat_template(
        [{"role": "user", "content": placeholder_prompt}],
        add_generation_prompt=True,
        tokenize=False,
    )
    placeholder_prefix_len = len(tokenizer.encode(
        placeholder_formatted + "The correct answer is", add_special_tokens=False
    ))

    answer_tokens = {}
    for choice in choices:
        full_text = placeholder_formatted + f"The correct answer is {choice}. "
        tokens = tokenizer.encode(full_text, add_special_tokens=False)
        answer_tokens[choice] = tokens[placeholder_prefix_len]

        decoded = tokenizer.decode([answer_tokens[choice]])
        print(f"  '{choice}' -> token ID {answer_tokens[choice]}, decodes to: {repr(decoded)}")

    print(f"\nAnswer token mapping: {answer_tokens}")

    # Filter to get samples for testing
    print(f"\nFiltering samples (max {max_samples})...")
    min_length = 2**14
    max_length = 2**15
    filtered = filter_samples_by_length(dataset, tokenizer, min_length, max_length, max_samples=max_samples)

    print(f"\n{rule}")
    print(f"Testing on {len(filtered)} samples")
    print(rule)

    target_length = 2**15

    # Test on each sample
    results = []
    for sample_num, item in enumerate(filtered, 1):
        sample = item['sample']
        print(f"\n{'─'*70}")
        print(f"Sample {sample_num}:")
        print(f"  ID: {sample['_id']}")
        print(f"  Domain: {sample['domain']}")
        print(f"  Ground truth: {sample['answer']}")

        # Get full sequence (prompt + ground-truth answer) before padding
        answer_choice = sample['answer']
        full_tokens, _ = tokenize_sample(sample, tokenizer, answer_choice, use_chat_format=True)

        # Find eval position: the index of the first token after the
        # tokenized "... The correct answer is" prefix.
        # NOTE(review): tokenize_sample also returns an eval position, which
        # is discarded above; confirm the recomputed prefix length is the
        # intended one.
        prompt = build_prompt(
            context=sample['context'],
            question=sample['question'],
            choice_a=sample['choice_A'],
            choice_b=sample['choice_B'],
            choice_c=sample['choice_C'],
            choice_d=sample['choice_D'],
        )
        messages = [{"role": "user", "content": prompt}]
        formatted = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        prefix_text = formatted + "The correct answer is"
        prefix_tokens = tokenizer.encode(prefix_text, add_special_tokens=False)
        eval_pos = len(prefix_tokens)

        # Apply padding; this also returns the eval position to use with the
        # padded sequence.
        context_tokens = extract_context_tokens(sample, tokenizer)
        padded_tokens, eval_pos = apply_cyclic_padding(
            full_tokens_with_answer=full_tokens,
            context_tokens=context_tokens,
            target_length=target_length,
            eval_position=eval_pos,
        )

        print(f"  Sequence length: {len(padded_tokens):,}")
        print(f"  Eval position: {eval_pos:,}")

        # Show what token is actually at eval position
        token_at_eval = padded_tokens[eval_pos]
        decoded_at_eval = tokenizer.decode([token_at_eval])
        print(f"  Token at eval position: ID {token_at_eval}, decodes to: {repr(decoded_at_eval)}")
        print(f"  Expected for answer '{answer_choice}': ID {answer_tokens[answer_choice]}")

        if token_at_eval != answer_tokens[answer_choice]:
            print("  ⚠️  WARNING: Token at eval position doesn't match expected answer token!")

        # Run inference
        print("  Running forward pass...")
        result = evaluate_sample_logits(
            model, params, padded_tokens, eval_pos, answer_tokens, mesh, tokenizer=tokenizer
        )

        # Show top predictions across full vocab. A dedicated name is used for
        # the rank so the outer sample counter is not overwritten.
        print("\n  Top 10 predictions (full vocab):")
        for rank, pred in enumerate(result['top_k_predictions'][:10], 1):
            print(f"    {rank}. Token {pred['token_id']:6d} (logit={pred['logit']:7.2f}): {repr(pred['decoded'])}")

        # Show argmax
        print(f"\n  Argmax prediction: Token {result['argmax_token_id']} = {repr(result['argmax_decoded'])}")

        # Show answer choices
        print("\n  Answer choice logits:")
        for choice in choices:
            logit = result['logits'][choice]
            prob = result['probabilities'][choice]
            print(f"    {choice}: logit={logit:7.2f}, prob={prob:.3f}")

        is_correct = result['prediction'] == sample['answer']
        print(f"\n  Prediction (among A/B/C/D): {result['prediction']}")
        print(f"  Ground truth: {sample['answer']}")
        print(f"  Correct: {'✓' if is_correct else '✗'}")

        results.append({
            'sample_id': sample['_id'],
            'ground_truth': sample['answer'],
            'prediction': result['prediction'],
            'correct': is_correct,
            'probabilities': result['probabilities'],
        })

    # Summary
    print(f"\n{rule}")
    print("SUMMARY")
    print(rule)
    num_correct = sum(r['correct'] for r in results)
    accuracy = num_correct / len(results) if results else 0
    print(f"\nTotal samples: {len(results)}")
    print(f"Correct: {num_correct}")
    print(f"Accuracy: {accuracy:.1%}")

    # Breakdown by prediction
    predictions = Counter(r['prediction'] for r in results)
    ground_truths = Counter(r['ground_truth'] for r in results)

    print("\nPrediction distribution:")
    for choice in choices:
        print(f"  {choice}: {predictions.get(choice, 0)}")

    print("\nGround truth distribution:")
    for choice in choices:
        print(f"  {choice}: {ground_truths.get(choice, 0)}")

    print("\n" + rule)
    print("Component 5 test complete!")
    print(rule)
    print("\nVerify:")
    print("  1. Forward pass runs without errors")
    print("  2. Logits are extracted at correct position")
    print("  3. Predictions are made from logits")
    print("  4. Accuracy is reasonable (>25% = better than random)")

    return accuracy, results


# ============================================================================
# COMPONENT 6 & 7: MULTI-CONFIG EVALUATION & RESULTS
# ============================================================================

# FMA configurations to compare (from compare_fma_configs.py)
def _fma_config(name, num_clusters=None, num_retrievals=None, use_fma=True, block_size=2**13):
    """Build one configuration entry.

    Keys are inserted in a fixed order because each entry is written to the
    results JSON as-is. 'bidiagonal' and 'dipole' are False for every entry.
    """
    return {
        'name': name,
        'use_fma': use_fma,
        'block_size': block_size,
        'num_clusters': num_clusters,
        'num_retrievals': num_retrievals,
        'bidiagonal': False,
        'dipole': False,
    }


_ALL_FMA_CONFIGS = [
    _fma_config('Standard Attention', use_fma=False, block_size=None),
    _fma_config('FMA-8k-64-0', num_clusters=64),
    _fma_config('FMA-8k-64-4', num_clusters=64, num_retrievals=4),
    _fma_config('FMA-8k-128-8', num_clusters=128, num_retrievals=8),
    _fma_config('FMA-8k-256-16', num_clusters=256, num_retrievals=16),
]

# Only the second entry ('FMA-8k-64-0') is active; the other entries are
# defined above but not evaluated.
FMA_CONFIGS = _ALL_FMA_CONFIGS[1:2]


def _answer_token_ids(tokenizer):
    """Map each answer letter to the token id that follows "The correct answer is".

    A placeholder prompt is chat-formatted, then "The correct answer is X. "
    is appended for each letter; the answer token is the first token past the
    tokenized "The correct answer is" prefix.
    """
    # The formatted placeholder prompt and the prefix length do not depend on
    # the choice, so compute them once rather than once per letter.
    prompt = build_prompt("[ctx]", "[q]", "[a]", "[b]", "[c]", "[d]")
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    prefix_len = len(tokenizer.encode(formatted + "The correct answer is", add_special_tokens=False))

    answer_tokens = {}
    for choice in ['A', 'B', 'C', 'D']:
        full_text = formatted + f"The correct answer is {choice}. "
        tokens = tokenizer.encode(full_text, add_special_tokens=False)
        answer_tokens[choice] = tokens[prefix_len]
    return answer_tokens


def _score_answer_choices(logits_at_pos, answer_tokens):
    """Return (prediction, probabilities) for A/B/C/D from one logit vector.

    Probabilities are a softmax restricted to the four answer-token logits;
    the prediction is the letter with the highest probability.
    """
    choices = ['A', 'B', 'C', 'D']
    answer_logits = {
        choice: float(logits_at_pos[token_id])
        for choice, token_id in answer_tokens.items()
    }
    answer_probs = jax.nn.softmax(jnp.array([answer_logits[c] for c in choices]))
    probabilities = {choice: float(prob) for choice, prob in zip(choices, answer_probs)}
    prediction = choices[int(jnp.argmax(answer_probs))]
    return prediction, probabilities


def evaluate_longbench_v2(
    checkpoint_dir: str,
    output_file: str = None,
    octave: int = None,
    min_length: int = 2**14,
    max_length: int = 2**15,
    max_samples: int = 10,
):
    """Full evaluation on LongBench v2 with multiple FMA configurations.

    Samples are filtered and padded to ``max_length`` once, then every entry
    of the module-level ``FMA_CONFIGS`` list is evaluated on the same padded
    samples. Per-config accuracy and per-sample results are written to a JSON
    file and a summary is printed. Nothing is returned.

    Args:
        checkpoint_dir: Path to model checkpoint
        output_file: Output JSON file path (default: auto-generated)
        octave: Power of 2 for max_length (overrides min_length/max_length if provided)
                e.g., octave=15 → [16K, 32K) range
        min_length: Minimum sequence length (ignored if octave is set)
        max_length: Maximum sequence length (ignored if octave is set)
        max_samples: Maximum samples to evaluate per config
    """
    # If octave specified, compute min/max from it
    if octave is not None:
        max_length = 2**octave
        min_length = 2**(octave - 1)

    rule = "=" * 70

    print(rule)
    print("LongBench v2 Evaluation with FMA Config Comparison")
    print(rule)

    # Auto-generate output filename if not provided
    if output_file is None:
        model_name = Path(checkpoint_dir).name
        output_file = f"longbenchv2_{model_name}_fma_comparison.json"

    print(f"\nCheckpoint: {checkpoint_dir}")
    print(f"Output file: {output_file}")
    print(f"Sequence length range: [{min_length:,}, {max_length:,})")
    print(f"Max samples per config: {max_samples}")
    print(f"Target padded length: {max_length:,}")

    # Setup mesh
    print("\nSetting up device mesh...")
    devices = mesh_utils.create_device_mesh((1, jax.device_count()))
    mesh = Mesh(devices, axis_names=('data', 'model'))
    jax.set_mesh(mesh)

    # Load model (its config object is mutated below for each FMA config)
    print("\nLoading model...")
    model, params, base_config = load_model_and_config(checkpoint_dir, mesh=mesh)

    # Load dataset and tokenizer
    print("\nLoading dataset and tokenizer...")
    dataset = load_dataset('THUDM/LongBench-v2', split='train')
    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.1-8B-Instruct')

    # Get answer tokens
    answer_tokens = _answer_token_ids(tokenizer)

    # Filter samples once
    print("\nFiltering samples...")
    filtered = filter_samples_by_length(dataset, tokenizer, min_length, max_length, max_samples=max_samples)
    print(f"Evaluating {len(filtered)} samples")

    # Prepare padded samples once (they're the same for all configs)
    print("\nPreparing padded samples...")
    prepared_samples = []
    for item in tqdm(filtered, desc="Padding samples"):
        sample = item['sample']
        answer_choice = sample['answer']

        # Get full tokens and eval position (index of the first token after
        # the tokenized "... The correct answer is" prefix)
        full_tokens, _ = tokenize_sample(sample, tokenizer, answer_choice, use_chat_format=True)
        prompt = build_prompt(
            context=sample['context'],
            question=sample['question'],
            choice_a=sample['choice_A'],
            choice_b=sample['choice_B'],
            choice_c=sample['choice_C'],
            choice_d=sample['choice_D'],
        )
        messages = [{"role": "user", "content": prompt}]
        formatted = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        prefix_text = formatted + "The correct answer is"
        prefix_tokens = tokenizer.encode(prefix_text, add_special_tokens=False)
        eval_pos = len(prefix_tokens)

        # Apply padding; this also returns the eval position to use with the
        # padded sequence.
        context_tokens = extract_context_tokens(sample, tokenizer)
        padded_tokens, eval_pos = apply_cyclic_padding(
            full_tokens_with_answer=full_tokens,
            context_tokens=context_tokens,
            target_length=max_length,
            eval_position=eval_pos,
        )

        prepared_samples.append({
            'sample': sample,
            'padded_tokens': padded_tokens,
            'eval_pos': eval_pos,
        })

    # The input sharding is the same for every config and sample.
    input_sharding = NamedSharding(mesh, P('data', None))

    # Evaluate each configuration
    all_results = {}

    for config_idx, fma_config in enumerate(FMA_CONFIGS, 1):
        print(f"\n{rule}")
        print(f"CONFIG {config_idx}/{len(FMA_CONFIGS)}: {fma_config['name']}")
        print(rule)

        # Update model config in place. The FMA-specific fields are only
        # written when FMA is enabled, so they keep their previous values for
        # a non-FMA config.
        # NOTE(review): fma_config['dipole'] is never applied to base_config;
        # confirm whether a corresponding config field should be set.
        base_config.use_fma_attention = fma_config['use_fma']
        if fma_config['use_fma']:
            base_config.fma_block_size = fma_config['block_size']
            base_config.fma_num_clusters = fma_config['num_clusters']
            base_config.fma_num_retrievals = fma_config['num_retrievals']
            base_config.fma_bidiagonal = fma_config['bidiagonal']

        # JIT compile forward pass once for this config
        # All samples have the same shape (1, max_length)
        print("JIT compiling forward pass...")

        # NOTE(review): `params` is captured by closure rather than passed as
        # an argument to the jitted function; confirm this is intended.
        def forward_pass(input_ids, eval_pos):
            """Forward pass that extracts logits at eval position on-device."""
            logits = model.apply(params, input_ids)
            # logits[i] predicts token at position i+1, so we need logits[eval_pos - 1]
            return logits[0, eval_pos - 1, :]

        # A fresh jit wrapper is created per config
        jit_forward = jax.jit(forward_pass, out_shardings=NamedSharding(mesh, P(None,)))

        # Warm-up compilation with first sample (its output is discarded)
        if prepared_samples:
            first = prepared_samples[0]
            dummy_input = jax.device_put(jnp.array([first['padded_tokens']]), input_sharding)
            _ = jit_forward(dummy_input, first['eval_pos'])  # Trigger compilation
            print("JIT compilation complete!")

        # Evaluate all samples with this config
        config_results = []

        for prep in tqdm(prepared_samples, desc=f"Evaluating {fma_config['name']}"):
            # Prepare input
            input_ids = jax.device_put(jnp.array([prep['padded_tokens']]), input_sharding)

            # Forward pass (already compiled) - returns logits at eval position
            logits_at_pos = jit_forward(input_ids, prep['eval_pos'])

            # Softmax over the four answer-token logits; prediction is argmax
            prediction, answer_probs_dict = _score_answer_choices(logits_at_pos, answer_tokens)

            ground_truth = prep['sample']['answer']
            config_results.append({
                'sample_id': prep['sample']['_id'],
                'ground_truth': ground_truth,
                'prediction': prediction,
                'correct': prediction == ground_truth,
                'probabilities': answer_probs_dict,
            })

        # Compute accuracy
        num_correct = sum(r['correct'] for r in config_results)
        accuracy = num_correct / len(config_results) if config_results else 0

        print(f"\nResults for {fma_config['name']}:")
        print(f"  Accuracy: {num_correct}/{len(config_results)} = {accuracy:.1%}")

        all_results[fma_config['name']] = {
            'config': fma_config,
            'accuracy': accuracy,
            'num_correct': num_correct,
            'num_total': len(config_results),
            'per_sample_results': config_results,
        }

    # Save results
    print(f"\n{rule}")
    print("Saving results...")
    print(rule)

    output_data = {
        'checkpoint': checkpoint_dir,
        'evaluation_settings': {
            'min_length': min_length,
            'max_length': max_length,
            'target_padded_length': max_length,
            'num_samples': len(prepared_samples),
        },
        'configs': all_results,
    }

    with open(output_file, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"Results saved to: {output_file}")

    # Print summary
    print(f"\n{rule}")
    print("FINAL SUMMARY")
    print(rule)
    for config_name, result in all_results.items():
        print(f"{config_name:30s}: {result['accuracy']:.1%} ({result['num_correct']}/{result['num_total']})")

    print(f"\n{rule}")
    print("Evaluation complete!")
    print(rule)


if __name__ == '__main__':
    import sys

    # First CLI token selects the mode; everything after it belongs to that mode.
    mode = sys.argv[1] if len(sys.argv) > 1 else None
    extra = sys.argv[2:]

    if mode is None:
        # Default: run component 1 test
        test_prompt_construction()
    elif mode == '--test-bucketing':
        test_bucketing()
    elif mode == '--test-padding':
        test_cyclic_padding()
    elif mode == '--test-model':
        test_model_loading()
    elif mode == '--test-logits':
        # Optional positionals: [N] [checkpoint_dir]
        sample_count = int(extra[0]) if extra else 5
        ckpt_path = extra[1] if len(extra) > 1 else 'checkpoints/llama-3.2-1b-flax'
        test_logit_extraction(max_samples=sample_count, checkpoint_dir=ckpt_path)
    elif mode == '--evaluate':
        # Full evaluation; the remaining options are parsed with argparse
        # (imported at the top of the file).
        cli = argparse.ArgumentParser(
            prog='evaluate_longbench_v2.py --evaluate',
            description='Run full LongBench v2 evaluation with FMA config comparison',
        )
        cli.add_argument('checkpoint_dir', help='Path to model checkpoint')
        cli.add_argument(
            '--output', '-o',
            default=None,
            help='Output JSON file path (default: auto-generated)',
        )
        cli.add_argument(
            '--octave',
            type=int,
            default=15,
            help='Power of 2 for max_length (default: 15 → [16K, 32K) range). '
                 'Examples: 14=[8K,16K), 15=[16K,32K), 16=[32K,64K)',
        )
        cli.add_argument(
            '--max-samples',
            type=int,
            default=10,
            help='Maximum samples to evaluate per config (default: 10)',
        )

        opts = cli.parse_args(extra)
        evaluate_longbench_v2(
            opts.checkpoint_dir,
            output_file=opts.output,
            octave=opts.octave,
            max_samples=opts.max_samples,
        )
    else:
        # Unrecognised mode: print the usage text
        usage_lines = (
            "Unknown argument.",
            "Usage:",
            "  --test-bucketing",
            "  --test-padding",
            "  --test-model",
            "  --test-logits [N] [checkpoint_dir]",
            "    N = number of samples (default 5)",
            "    checkpoint_dir = path to checkpoint (default: checkpoints/llama-3.2-1b-flax)",
            "  --evaluate checkpoint_dir [OPTIONS]",
            "    Run full evaluation with all FMA configs",
            "    Options:",
            "      --output FILE, -o FILE    JSON output path (default: auto-generated)",
            "      --octave N                Power of 2 for sequence length (default: 15)",
            "                                octave=13 → [4K, 8K), octave=14 → [8K, 16K)",
            "                                octave=15 → [16K, 32K), octave=16 → [32K, 64K)",
            "      --max-samples N           Samples per config (default: 10)",
        )
        print("\n".join(usage_lines))
