import pytest
import torch
from transformers import DynamicCache

from src.io_utils import load_model_and_tokenizer
from src.lm_utils import generate_ragged_batched, prepare_conversation


@pytest.fixture(scope="module")
def model_and_tokenizer():
    """Fixture providing the ``(model, tokenizer)`` pair used by the tests.

    The fixture is module-scoped so that ``load_model_and_tokenizer`` runs once
    and the result is shared, instead of reloading the checkpoint for every
    test that requests it.

    Returns:
        The ``(model, tokenizer)`` tuple returned by ``load_model_and_tokenizer``.
    """
    class ModelConfig:
        # Ad-hoc config object handed to `load_model_and_tokenizer`; the
        # attribute names are whatever that loader reads.
        id = "meta-llama/Meta-Llama-3-8B-Instruct"
        tokenizer_id = "meta-llama/Meta-Llama-3-8B-Instruct"
        short_name = "Llama"
        developer_name = "Meta"
        compile = False
        dtype = "bfloat16"
        chat_template = None
        trust_remote_code = True

    model, tokenizer = load_model_and_tokenizer(ModelConfig())
    return model, tokenizer


def longest_common_prefix_length(str1: str, str2: str) -> int:
    """Count how many leading characters ``str1`` and ``str2`` share.

    When a mismatch is found, its index and the two differing characters are
    printed (as ``repr``) before returning, to help debug failing comparisons.
    If one string is a prefix of the other, the shorter length is returned.
    """
    for position, (left, right) in enumerate(zip(str1, str2)):
        if left == right:
            continue
        # First divergence: report it and stop.
        print(position, f"{left!r}", f"{right!r}")
        return position
    # No mismatch within the overlapping region.
    return min(len(str1), len(str2))


def test_generate_ragged_batched_greedy_no_batch(model_and_tokenizer):
    """Compare greedy ``generate_ragged_batched`` output with a manual decoding loop.

    Three completions of the same single prompt are produced:

    1. a reference from a manual argmax loop that re-runs the full sequence
       through the model at every step,
    2. ``generate_ragged_batched`` with ``temperature=0.0``,
    3. ``model.generate`` with ``do_sample=False``.

    Only (2) is asserted to match the reference exactly; the agreement of (3)
    with the reference is printed for information.
    """
    model, tokenizer = model_and_tokenizer
    conversation = [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": ""},
    ]
    tokens = torch.cat(prepare_conversation(tokenizer, conversation)[0])  # (L,)
    # `torch.cat` already returns a tensor, so move it with `.to()`;
    # `torch.tensor(tensor)` would copy-construct and emit a UserWarning.
    tokens = tokens.to(model.device).unsqueeze(0)  # (1, L)
    prompt_len = tokens.size(1)

    max_new_tokens = 256

    # First, generate with a manual loop that recomputes the full sequence each step.
    generate_tokens = tokens.clone()
    with torch.no_grad():  # inference only; avoid building autograd graphs
        for _ in range(max_new_tokens):
            logits = model(generate_tokens).logits[:, -1]
            next_token = torch.argmax(logits, dim=-1)  # (1,)
            if next_token == tokenizer.eos_token_id:
                break
            generate_tokens = torch.cat([generate_tokens, next_token.unsqueeze(0)], dim=1)
    reference = tokenizer.decode(generate_tokens[0, prompt_len:])

    # Generate with generate_ragged_batched
    ragged_generated = generate_ragged_batched(
        model, tokenizer, [tokens[0]], max_new_tokens=max_new_tokens,
        num_return_sequences=1, temperature=0.0
    )[0][0]

    # Generate with model.generate
    generated = model.generate(tokens, do_sample=False, max_new_tokens=max_new_tokens)[0]
    generated = tokenizer.decode(generated[prompt_len:])

    # Similarity = shared-prefix length relative to the longer of the two strings.
    ragged_vs_reference = (
        longest_common_prefix_length(reference, ragged_generated)
        / max(len(reference), len(ragged_generated))
    )
    generate_vs_reference = (
        longest_common_prefix_length(reference, generated)
        / max(len(reference), len(generated))
    )
    print(f"Ragged vs reference: {ragged_vs_reference:.2f}")
    print(f"Generate vs reference: {generate_vs_reference:.2f}")
    assert ragged_vs_reference == 1.0


def test_generate_ragged_batched_parallel_identical(model_and_tokenizer):
    """Test that 16 identical prompts generated in parallel each match a reference.

    The reference is produced by a manual greedy (argmax) loop over a batch of
    16 copies of the prompt, re-running the full sequences through the model at
    every step. Each output of ``generate_ragged_batched`` is then required to
    match the corresponding reference string exactly.
    """
    model, tokenizer = model_and_tokenizer
    conversation = [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": ""},
    ]
    tokens = torch.cat(prepare_conversation(tokenizer, conversation)[0])  # (L,)
    # `torch.cat` already returns a tensor, so move it with `.to()`;
    # `torch.tensor(tensor)` would copy-construct and emit a UserWarning.
    tokens = tokens.to(model.device)
    # `tokens` is 1-D, so the prompt length is dimension 0 (there is no dimension 1).
    prompt_len = tokens.size(0)

    batch_size = 16
    max_new_tokens = 512

    # First, generate with a manual loop that recomputes the full sequences each step.
    generate_tokens = tokens.unsqueeze(0).expand(batch_size, -1).clone()  # (16, L)
    with torch.no_grad():  # inference only; avoid building autograd graphs
        for _ in range(max_new_tokens):
            logits = model(generate_tokens).logits[:, -1]
            next_token = torch.argmax(logits, dim=-1)  # (16,)
            if torch.all(next_token == tokenizer.eos_token_id):
                break
            # Append one new column: (16,) -> (16, 1) so it concatenates along dim=1.
            generate_tokens = torch.cat([generate_tokens, next_token.unsqueeze(1)], dim=1)
    reference_output = tokenizer.batch_decode(generate_tokens[:, prompt_len:])

    # Create batch of 16 identical inputs
    batch_tokens = [tokens] * batch_size

    # Generate with batch of 16 identical inputs.
    # NOTE(review): the `[0]` index is kept from the original test; confirm it
    # selects the list of per-prompt outputs in generate_ragged_batched's return layout.
    batch_outputs = generate_ragged_batched(
        model, tokenizer, batch_tokens, max_new_tokens=max_new_tokens,
        num_return_sequences=1, temperature=0.0
    )[0]

    # Verify all outputs match the reference
    for i, output in enumerate(batch_outputs):
        similarity = longest_common_prefix_length(reference_output[i], output) / max(len(reference_output[i]), len(output))
        print(f"Batch item {i} vs reference: {similarity:.2f}")
        assert similarity == 1.0, f"Batch item {i} differs from reference: {output[:50]}... vs {reference_output[i][:50]}..."


def test_generate_ragged_batched_parallel_identical_large_batch(model_and_tokenizer):
    """Test that 512 identical prompts generated in parallel each match a reference.

    The reference is produced by a manual greedy (argmax) loop over a batch of
    512 copies of the prompt, re-running the full sequences through the model
    at every step. Each output of ``generate_ragged_batched`` is required to
    match the corresponding reference string exactly; the agreement of
    ``model.generate`` with the reference is printed for information only.
    """
    model, tokenizer = model_and_tokenizer
    conversation = [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": ""},
    ]
    tokens = torch.cat(prepare_conversation(tokenizer, conversation)[0])  # (L,)
    # `torch.cat` already returns a tensor, so move it with `.to()`;
    # `torch.tensor(tensor)` would copy-construct and emit a UserWarning.
    tokens = tokens.to(model.device)
    # `tokens` is 1-D, so the prompt length is dimension 0 (there is no dimension 1).
    prompt_len = tokens.size(0)

    batch_size = 512
    max_new_tokens = 512

    # Batched copy of the prompt; never modified in place, so it can be reused below.
    prompt_batch = tokens.unsqueeze(0).expand(batch_size, -1).clone()  # (512, L)

    # First, generate with a manual loop that recomputes the full sequences each step.
    generate_tokens = prompt_batch
    with torch.no_grad():  # inference only; avoid building autograd graphs
        for _ in range(max_new_tokens):
            logits = model(generate_tokens).logits[:, -1]
            next_token = torch.argmax(logits, dim=-1)  # (512,)
            if torch.all(next_token == tokenizer.eos_token_id):
                break
            # Append one new column: (512,) -> (512, 1) so it concatenates along dim=1.
            generate_tokens = torch.cat([generate_tokens, next_token.unsqueeze(1)], dim=1)
    reference_output = tokenizer.batch_decode(generate_tokens[:, prompt_len:])

    # Create batch of 512 identical inputs
    batch_tokens = [tokens] * batch_size

    # Generate with batch of 512 identical inputs.
    # NOTE(review): the `[0]` index is kept from the original test; confirm it
    # selects the list of per-prompt outputs in generate_ragged_batched's return layout.
    batch_outputs = generate_ragged_batched(
        model, tokenizer, batch_tokens, max_new_tokens=max_new_tokens,
        num_return_sequences=1, temperature=0.0
    )[0]

    # Generate with model.generate
    generated = model.generate(prompt_batch, do_sample=False, max_new_tokens=max_new_tokens)
    generated = tokenizer.batch_decode(generated[:, prompt_len:])

    # Verify all outputs match the reference
    for i, output in enumerate(batch_outputs):
        similarity = longest_common_prefix_length(reference_output[i], output) / max(len(reference_output[i]), len(output))
        generated_similarity = longest_common_prefix_length(reference_output[i], generated[i]) / max(len(reference_output[i]), len(generated[i]))
        print(f"Batch item {i} vs reference: {similarity:.2f}, generated vs reference: {generated_similarity:.2f}")
        assert similarity == 1.0, f"Batch item {i} differs from reference: {output[:50]}... vs {reference_output[i][:50]}..."
