import string
from pathlib import Path

import numpy as np
import torch
from transformers import (
    BatchEncoding,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.utils import ModelOutput

from lib_llm.eval.metrics.token_level import sequence_token_probs
from lib_llm.inference import PredictionConfig, predict, tokenize


class DummyTokenizer(PreTrainedTokenizer):
    """Stand-in tokenizer for tests: every id and mask entry it emits is 1."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, text: list[str], **kwargs):
        """Return an all-ones encoding shaped like the input batch.

        The result has one row per string in ``text`` and as many columns as
        the first string has characters. Extra keyword arguments are ignored.
        """
        batch_shape = (len(text), len(text[0]))
        # Each key gets its own freshly allocated tensor.
        return BatchEncoding(
            {
                key: torch.ones(batch_shape, dtype=torch.long)
                for key in ("input_ids", "attention_mask")
            }
        )


class DummyModel(PreTrainedModel):
    """Stand-in model for tests whose logits are always all ones."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)

    def forward(self, input_ids, attention_mask):
        """Return constant logits with two entries per input position.

        ``attention_mask`` is accepted but not used.
        """
        logits_shape = tuple(input_ids.shape) + (2,)
        return ModelOutput(logits=torch.ones(logits_shape))


def create_saved_model() -> Path:
    """Write a dummy model and tokenizer to disk next to this file.

    Serves as an example of producing a saved model that can later be
    loaded for inference.

    Returns:
        The ``dummy_model`` directory that both objects were saved to.
    """
    target_dir = Path(__file__).parent.resolve() / "dummy_model"
    target = str(target_dir)

    model = DummyModel(PretrainedConfig(name_or_path="dummy_model"))
    tokenizer = DummyTokenizer()

    model.save_pretrained(target)
    tokenizer.save_pretrained(target)
    return target_dir


def generate_random_strings(
    num_sequences: int = 20,
    sequence_length: int = 10,
    seed: int = 520,
) -> list[str]:
    """Generate reproducible random strings of lowercase ASCII letters.

    Args:
        num_sequences: How many strings to generate.
        sequence_length: Number of characters in each string.
        seed: Seed for the NumPy random generator; the same arguments
            always yield the same strings.

    Returns:
        ``num_sequences`` strings, each ``sequence_length`` characters long.
    """
    rng = np.random.default_rng(seed=seed)
    # Draw all characters in a single call as a 2-D array (one row per
    # string) so the defaults reproduce the original fixed output.
    characters = rng.choice(
        list(string.ascii_lowercase),
        size=(num_sequences, sequence_length),
    )
    return ["".join(row) for row in characters]


def test_compute_sequence_probabilities():
    """Show how this library is used for inference and probability
    computation, and check the shape of the resulting logits.
    """
    # Build the dummy model and tokenizer directly in memory. The
    # save-and-reload round trip below is currently disabled:
    # model_save_path = create_saved_model()
    # model, tokenizer = load_model_tokenizer(model_save_path)
    dummy_config = PretrainedConfig(name_or_path="dummy_model")
    model = DummyModel(dummy_config)
    tokenizer = DummyTokenizer()
    tokenizer.padding_side = "left"

    # Turn the raw text into a BatchEncoding. Going through tokenize()
    # ensures the encoding has the expected format, in particular that
    # left-padding is used.
    texts = generate_random_strings()
    encoded_sequences = tokenize(tokenizer, texts)

    # For actual models, omit the batch size and let the inference code
    # pick one itself.
    config = PredictionConfig(trim_last_token=True, batch_size=5)

    # Run inference on the encoded sequences. predict() will rebatch and
    # compress the batches where possible for more efficient inference.
    model_output = predict(model, encoded_sequences, config)

    logits = model_output["logits"]
    num_sequences, sequence_length = encoded_sequences.input_ids.shape
    expected_shape = (num_sequences, sequence_length - 1, 2)
    assert logits.shape == expected_shape

    # Disabled: compute the log probability for each token.
    # probs = compute_token_probs(logits, encoded_sequences)
    # # From the full set of probabilities over all tokens, extract
    # # only those corresponding to the sequence tokens
    # token_probs = sequence_token_probs(probs, encoded_sequences)
    # assert len(probs.values) == num_sequences
    # assert all(tp.shape == (sequence_length - 1,) for tp in token_probs.masked)
