"""Run inference with Llama 3.2 1B using FMA attention."""

import argparse
import pickle
from pathlib import Path
import time

import jax
import jax.numpy as jnp
from transformers import AutoTokenizer

from fma_llama.model.config import LlamaConfig
from fma_llama.model.llama import LlamaForCausalLM


def load_model(checkpoint_dir: str):
    """Restore a LlamaForCausalLM together with its config and parameters.

    Args:
        checkpoint_dir: Directory containing ``config.pkl`` and
            ``flax_params.pkl``.

    Returns:
        Tuple ``(model, params, config)``.

    Note:
        Both files are read with ``pickle``, which can execute arbitrary code;
        only load checkpoints from trusted sources.
    """
    root = Path(checkpoint_dir)

    def _unpickle(name):
        # Read one pickled artifact from the checkpoint directory.
        with (root / name).open('rb') as handle:
            return pickle.load(handle)

    config = _unpickle("config.pkl")
    params = _unpickle("flax_params.pkl")
    return LlamaForCausalLM(config), params, config


def generate(
    model,
    params,
    input_ids: jnp.ndarray,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int = 50,
    seed: int = 0,
):
    """Generate text using the model.

    Each step re-runs the model over the whole sequence (no KV cache), takes
    the logits of the last position and appends one token per batch row.

    Args:
        model: Flax model; ``model.apply(params, ids)`` must return logits of
            shape (batch_size, seq_len, vocab_size).
        params: Model parameters.
        input_ids: Input token IDs of shape (batch_size, seq_len).
        max_new_tokens: Number of tokens to append.
        temperature: Sampling temperature. Values <= 0 select greedy decoding.
        top_k: Top-k sampling parameter. Values <= 0 select greedy decoding.
            Values larger than the vocabulary are clamped to its size.
        seed: Seed for the sampling PRNG key (default keeps the previous,
            fixed behavior).

    Returns:
        Token IDs of shape (batch_size, seq_len + max_new_tokens).
    """
    rng = jax.random.PRNGKey(seed)
    # temperature <= 0 would divide logits by zero (or flip their sign).
    greedy = temperature <= 0 or top_k <= 0

    generated_ids = input_ids

    for _ in range(max_new_tokens):
        logits = model.apply(params, generated_ids)
        next_token_logits = logits[:, -1, :]  # (batch, vocab)

        if greedy:
            next_token = jnp.argmax(next_token_logits, axis=-1)  # (batch,)
        else:
            # jax.lax.top_k raises if k exceeds the vocabulary size.
            k = min(top_k, next_token_logits.shape[-1])
            top_k_logits, top_k_indices = jax.lax.top_k(
                next_token_logits / temperature, k
            )
            rng, sample_rng = jax.random.split(rng)
            # categorical accepts unnormalized logits; no softmax/log needed.
            choice = jax.random.categorical(sample_rng, top_k_logits, axis=-1)
            # Gather per row so every batch element uses its own top-k indices.
            next_token = jnp.take_along_axis(
                top_k_indices, choice[:, None], axis=-1
            )[:, 0]  # (batch,)

        # Append as a new column: (batch,) -> (batch, 1).
        next_token = next_token.astype(generated_ids.dtype)[:, None]
        generated_ids = jnp.concatenate([generated_ids, next_token], axis=1)

    return generated_ids


def main():
    """Parse CLI arguments, load checkpoint and tokenizer, and print a continuation."""
    parser = argparse.ArgumentParser(description='Run inference with Llama 3.2 1B')
    parser.add_argument(
        '--checkpoint_dir',
        type=str,
        default='checkpoints/llama-3.2-1b-flax',
        help='Checkpoint directory',
    )
    parser.add_argument(
        '--prompt',
        type=str,
        default='Once upon a time',
        help='Input prompt',
    )
    parser.add_argument(
        '--max_new_tokens',
        type=int,
        default=50,
        help='Maximum number of tokens to generate',
    )
    parser.add_argument(
        '--temperature',
        type=float,
        default=1.0,
        help='Sampling temperature',
    )
    parser.add_argument(
        '--top_k',
        type=int,
        default=50,
        help='Top-k sampling parameter',
    )
    parser.add_argument(
        '--use_fma',
        action='store_true',
        help='Use FMA attention (default: False for initial testing)',
    )
    parser.add_argument(
        '--tokenizer',
        type=str,
        default='meta-llama/Llama-3.2-1B',
        help='Hugging Face tokenizer name or local path',
    )

    args = parser.parse_args()

    print(f"Loading model from {args.checkpoint_dir}")
    model, params, config = load_model(args.checkpoint_dir)

    # Set the FMA attention flag, then rebuild the module from the final
    # config. Otherwise correctness would depend on the already-built model
    # observing a later mutation of the shared config object.
    config.use_fma_attention = args.use_fma
    model = LlamaForCausalLM(config)
    print(f"Using FMA attention: {config.use_fma_attention}")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

    # Tokenize input
    print(f"\nPrompt: {args.prompt}")
    input_ids = tokenizer.encode(args.prompt, return_tensors='jax')
    print(f"Input tokens: {input_ids.shape}")

    # Generate (timing includes any first-call compilation/tracing).
    print("\nGenerating...")
    start_time = time.perf_counter()
    generated_ids = generate(
        model,
        params,
        input_ids,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
    )
    # JAX dispatches work asynchronously; wait for the result before
    # stopping the clock so the measurement covers the actual computation.
    generated_ids.block_until_ready()
    generation_time = time.perf_counter() - start_time

    # Decode output
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    print(f"\nGenerated text ({generation_time:.2f}s):")
    print("-" * 80)
    print(generated_text)
    print("-" * 80)

    # Count the tokens actually appended rather than the requested amount.
    num_new_tokens = generated_ids.shape[1] - input_ids.shape[1]
    tokens_per_second = num_new_tokens / generation_time
    print(f"\nPerformance: {tokens_per_second:.2f} tokens/second")


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
