"""Evaluate Llama 3.2 1B with FMA attention on PG-19 dataset."""

import argparse
import pickle
from pathlib import Path
import time
from typing import Dict

import jax
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
from functools import partial

from fma_llama.model.config import LlamaConfig
from fma_llama.model.llama import LlamaForCausalLM


def load_model(checkpoint_dir: str):
    """Restore a Flax Llama model together with its parameters and config.

    The directory is expected to contain ``config.pkl`` (the model config) and
    ``flax_params.pkl`` (the parameter pytree), both written with ``pickle``.

    Args:
        checkpoint_dir: Path to the checkpoint directory.

    Returns:
        Tuple ``(model, params, config)`` where ``model`` is a
        ``LlamaForCausalLM`` built from the unpickled config.
    """
    root = Path(checkpoint_dir)

    def _unpickle(file_name):
        # NOTE: pickle can execute arbitrary code while loading; only use this
        # on checkpoints from a trusted source.
        with open(root / file_name, 'rb') as handle:
            return pickle.load(handle)

    config = _unpickle("config.pkl")
    params = _unpickle("flax_params.pkl")
    return LlamaForCausalLM(config), params, config


def compute_perplexity(
    model,
    params,
    input_ids: jnp.ndarray,
    labels: jnp.ndarray,
) -> jnp.ndarray:
    """Compute the perplexity of a batch under next-token prediction.

    Args:
        model: Flax model whose ``apply(params, input_ids)`` is expected to
            return logits of shape (batch_size, seq_len, vocab_size).
        params: Model parameters.
        input_ids: Input token IDs of shape (batch_size, seq_len).
        labels: Target token IDs of shape (batch_size, seq_len); they must be
            integer IDs in ``[0, vocab_size)``.

    Returns:
        Scalar (0-d) array: ``exp`` of the mean next-token negative
        log-likelihood over every predicted position in the batch.
    """
    logits = model.apply(params, input_ids)

    # Position t predicts token t+1: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    # Gather each target's log-prob directly instead of multiplying by a one-hot
    # mask. The mask would materialise a (batch, seq_len-1, vocab) array (about
    # 1 GB for a 128k vocab at seq_len 2048), and it would turn any -inf logit
    # into NaN via -inf * 0.
    token_log_probs = jnp.take_along_axis(
        log_probs, shift_labels[..., None], axis=-1
    )[..., 0]

    loss = -jnp.mean(token_log_probs)
    return jnp.exp(loss)


def prepare_pg19_data(
    tokenizer,
    max_length: int = 2048,
    num_samples: int = 100,
):
    """Tokenize PG-19 test books into fixed-length evaluation chunks.

    Args:
        tokenizer: HuggingFace tokenizer; ``encode`` is called with
            ``add_special_tokens=False``.
        max_length: Length of every returned chunk; must be positive.
        num_samples: Number of PG-19 *books* (not chunks) read from the start of
            the test split. A single book usually yields many chunks, so the
            returned list is typically much longer than ``num_samples``.

    Returns:
        List of token-id lists, each exactly ``max_length`` long. The trailing
        partial chunk of each book is dropped.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")

    print("Loading PG-19 dataset...")
    dataset = load_dataset('pg19', split='test', trust_remote_code=True)

    # Only the first num_samples books are consumed; give tqdm that count so
    # the progress bar does not report the size of the whole split.
    num_books = max(0, min(num_samples, len(dataset)))

    samples = []
    for i, example in enumerate(tqdm(dataset, total=num_books, desc="Tokenizing")):
        if i >= num_samples:
            break

        tokens = tokenizer.encode(example['text'], add_special_tokens=False)

        # Keep only full-length chunks (start offsets j with j + max_length <=
        # len(tokens)) so every sample has the same shape.
        samples.extend(
            tokens[j:j + max_length]
            for j in range(0, len(tokens) - max_length + 1, max_length)
        )

    return samples


def evaluate_attention_quality(
    model,
    params,
    samples,
    use_fma: bool = True,
) -> Dict[str, float]:
    """Run the model over every sample and report perplexity and throughput.

    Args:
        model: Flax model passed to ``compute_perplexity``.
        params: Model parameters.
        samples: Sequence of token-id lists. Each one is evaluated as a batch
            of size 1, with the sequence serving as its own labels.
        use_fma: Only used to label the progress bar. It does NOT switch the
            attention implementation; configure the model before calling.

    Returns:
        Dictionary with:
          - 'perplexity': arithmetic mean of per-sample perplexities (the
            original metric, kept for backward compatibility).
          - 'corpus_perplexity': geometric mean of per-sample perplexities,
            which equals exp(mean token NLL) when all samples have the same
            length (as produced by ``prepare_pg19_data``).
          - 'avg_time': mean wall-clock seconds per sample, excluding JIT
            compilation.
          - 'tokens_per_second': total input tokens / total timed seconds.
          - 'warmup_time': seconds spent in the untimed warm-up call
            (compilation plus one forward pass).
        For empty ``samples`` the perplexities and 'avg_time' are NaN and
        'tokens_per_second' is 0.
    """
    if not samples:
        nan = float('nan')
        return {
            'perplexity': nan,
            'avg_time': nan,
            'tokens_per_second': 0,
            'corpus_perplexity': nan,
            'warmup_time': 0.0,
        }

    # Pass params as a traced argument rather than binding them with
    # functools.partial. Arrays captured in a jitted function's closure are
    # baked into the compiled program as constants, which for a 1B-parameter
    # model inflates compile time and memory.
    jit_compute_perplexity = jax.jit(
        lambda p, ids, lbl: compute_perplexity(model, p, ids, lbl)
    )

    # Warm-up: the first call traces and compiles. Keep it out of the timed
    # loop so avg_time and tokens_per_second measure steady-state inference.
    # A sample of a different length would trigger a recompile inside the loop.
    warmup_ids = jnp.array([samples[0]])
    warmup_start = time.perf_counter()
    float(jit_compute_perplexity(params, warmup_ids, warmup_ids))
    warmup_time = time.perf_counter() - warmup_start

    perplexities = []
    times = []
    for sample in tqdm(samples, desc=f"Evaluating (FMA={use_fma})"):
        input_ids = jnp.array([sample])

        start_time = time.perf_counter()
        # float() pulls the result to the host, blocking until the
        # asynchronously dispatched computation has finished.
        perplexity = float(jit_compute_perplexity(params, input_ids, input_ids))
        elapsed = time.perf_counter() - start_time

        perplexities.append(perplexity)
        times.append(elapsed)

    total_time = float(np.sum(times))
    total_tokens = sum(len(sample) for sample in samples)

    return {
        'perplexity': float(np.mean(perplexities)),
        'avg_time': float(np.mean(times)),
        'tokens_per_second': total_tokens / total_time if total_time > 0 else 0,
        'corpus_perplexity': float(np.exp(np.mean(np.log(perplexities)))),
        'warmup_time': warmup_time,
    }


def _print_banner(title: str) -> None:
    """Print a section title framed by 80-character rules of '='."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _print_metrics(header: str, metrics: Dict[str, float]) -> None:
    """Print the core metrics returned by ``evaluate_attention_quality``."""
    print(f"\n{header}")
    print(f"  Perplexity: {metrics['perplexity']:.4f}")
    print(f"  Avg time per sample: {metrics['avg_time']:.4f}s")
    print(f"  Tokens/second: {metrics['tokens_per_second']:.2f}")


def _print_comparison(fma_metrics: Dict[str, float], standard_metrics: Dict[str, float]) -> None:
    """Print the perplexity delta and speedup of FMA relative to standard attention.

    Not called yet; it is kept for when FMA attention evaluation is wired into ``main``.
    """
    _print_banner("Comparison")
    ppl_diff = fma_metrics['perplexity'] - standard_metrics['perplexity']
    speedup = standard_metrics['avg_time'] / fma_metrics['avg_time']

    print(f"Perplexity difference: {ppl_diff:+.4f} ({ppl_diff/standard_metrics['perplexity']*100:+.2f}%)")
    print(f"Speedup: {speedup:.2f}x")


def main():
    """CLI entry point: evaluate the checkpoint on PG-19 and print metrics."""
    parser = argparse.ArgumentParser(description='Evaluate Llama 3.2 1B on PG-19')
    parser.add_argument(
        '--checkpoint_dir',
        type=str,
        default='checkpoints/llama-3.2-1b-flax',
        help='Checkpoint directory',
    )
    parser.add_argument(
        '--max_length',
        type=int,
        default=2048,
        help='Maximum sequence length',
    )
    parser.add_argument(
        '--num_samples',
        type=int,
        default=100,
        help='Number of PG-19 test books to tokenize (each yields many max_length chunks)',
    )
    parser.add_argument(
        '--compare_standard',
        action='store_true',
        help='Also evaluate with standard attention for comparison',
    )

    args = parser.parse_args()

    print(f"Loading model from {args.checkpoint_dir}")
    model, params, config = load_model(args.checkpoint_dir)

    # Load tokenizer (use Llama tokenizer for PG-19)
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B')

    samples = prepare_pg19_data(
        tokenizer,
        max_length=args.max_length,
        num_samples=args.num_samples,
    )
    print(f"Prepared {len(samples)} samples of length {args.max_length}")
    if not samples:
        print("No full-length samples were produced; nothing to evaluate.")
        return

    # Evaluate with standard attention (FMA not yet implemented).
    _print_banner("Evaluating with standard attention")
    # NOTE(review): the model was already constructed from `config` in
    # load_model; whether mutating this flag affects it depends on how
    # LlamaForCausalLM reads its config. Confirm.
    config.use_fma_attention = False
    metrics = evaluate_attention_quality(model, params, samples, use_fma=False)
    _print_metrics("Results:", metrics)

    if args.compare_standard:
        # TODO: once FMA attention is implemented, evaluate with
        # config.use_fma_attention = True and call
        # _print_comparison(fma_metrics, metrics).
        print("\nNote: --compare_standard has no effect yet; FMA attention is not implemented.")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
