from typing import Callable, Dict

import pandas as pd
import torch
from transformers import (
    BatchEncoding,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.utils import ModelOutput

from lib_llm.eval.memorization.prefix_mappings.eval import (
    PredictionConfig,
    PrefixEvalConfig,
    _aggregate_predictions,
    _compute_plurality_prediction,
    _eval_token,
    _get_prefix_lengths,
    _get_probing_token_indices,
    _is_prediction_correct,
    eval_prefix_lengths,
)


class DummyModel(PreTrainedModel):
    """A dummy model for testing that produces constant output.

    Every forward pass returns logits that are zero everywhere except for a
    value of 1.0 at index ``prediction``, so the largest logit at every
    position is always ``prediction``. The model also checks that it is
    called with exactly the expected sequence of ``input_ids`` tensors.

    Args:
        config: Configuration forwarded to ``PreTrainedModel``.
        prediction: Token id that receives the only non-zero logit.
        expected_input_id_calls: The ``input_ids`` tensors that successive
            ``forward`` calls must receive, in call order.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prediction: int,
        expected_input_id_calls: list[torch.Tensor],
    ):
        super().__init__(config)
        self.prediction = prediction
        # Kept as an iterator so tests can check afterwards that every
        # expected call has been consumed.
        self.expected_input_id_iter = iter(expected_input_id_calls)

    def forward(self, input_ids: torch.Tensor, attention_mask) -> ModelOutput:
        """Validate ``input_ids`` and return the constant logits.

        Args:
            input_ids: Token ids; must equal the next expected tensor.
            attention_mask: Accepted for interface compatibility; unused.

        Returns:
            A ``ModelOutput`` whose ``logits`` have shape
            ``(*input_ids.shape, prediction + 1)``.

        Raises:
            AssertionError: If the model is called more often than expected
                or with ``input_ids`` that differ from the expected tensor.
        """
        # Use a default so that an unexpected extra call fails with a clear
        # message instead of leaking a bare StopIteration to the caller.
        expected_input_ids = next(self.expected_input_id_iter, None)
        assert (
            expected_input_ids is not None
        ), f"Unexpected extra forward call with input_ids={input_ids}"
        assert torch.equal(
            input_ids, expected_input_ids
        ), f"Expected input_ids {expected_input_ids}, got {input_ids}"

        # The vocabulary is just large enough to contain `prediction`.
        logits = torch.zeros((*input_ids.shape, self.prediction + 1))
        logits[:, :, self.prediction] = 1.0
        return ModelOutput(logits=logits)

    def to(self, device):
        """Return the model unchanged; ``device`` is ignored."""
        return self


class DummyTokenizer(PreTrainedTokenizer):
    """Stand-in tokenizer for tests whose encodings consist of ones only."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register a padding token on top of the base tokenizer setup.
        self.add_special_tokens({"pad_token": "<pad>"})

    def __call__(self, text: list[str], **kwargs):
        """Encode ``text`` as all-ones ids together with an all-ones mask.

        The result has one row per entry of ``text`` and one column per
        character of the first entry; the characters themselves and any
        keyword arguments are ignored.
        """
        shape = (len(text), len(text[0]))
        return BatchEncoding(
            {
                "input_ids": torch.ones(shape, dtype=torch.long),
                "attention_mask": torch.ones(shape, dtype=torch.long),
            }
        )

    def _convert_id_to_token(self, index: int) -> str:
        """Map every id to the token ``"a"``."""
        return "a"

    def get_vocab(self) -> Dict[str, int]:
        """Map each lowercase ASCII letter to its alphabet position."""
        letters = "abcdefghijklmnopqrstuvwxyz"
        return dict(zip(letters, range(len(letters))))


def test_get_probing_token_indices():
    """Probing indices have the requested count and lie in [1, length)."""
    seq_len, num_samples = 10, 5
    indices = _get_probing_token_indices(seq_len, num_samples, 42)
    assert len(indices) == num_samples
    # Index 0 is excluded by the lower bound, the sequence length by the
    # upper bound.
    for idx in indices:
        assert 1 <= idx < seq_len


def test_get_prefix_lengths():
    """Expect the too-long length 10 to be replaced by the token index 5."""
    result = _get_prefix_lengths(5, [1, 2, 3, 4, 10])
    assert result == [1, 2, 3, 4, 5]


def test_get_prefix_lengths_2():
    """Expect lengths 4 and 8 to collapse into one entry for token index 3."""
    candidate_lengths = [1, 2, 4, 8]
    assert _get_prefix_lengths(3, candidate_lengths) == [1, 2, 3]


def test_is_prediction_correct():
    """Check the per-row correctness flags for a small batch of logits.

    Row 0 has its largest logit at index 3, which matches its target; row 1
    has its largest logit at index 0, which does not match its target 3.
    """
    logits = torch.tensor(
        [[0.1, 0.2, 0.3, 0.4], [0.6, 0.1, 0.2, 0.0]],
    )
    target_token_ids = torch.tensor([3, 3])
    result = _is_prediction_correct(logits, target_token_ids)
    assert isinstance(result, torch.Tensor)
    assert len(result) == len(logits)
    # Pin the actual values: merely checking that each item is a boolean
    # would still pass if every prediction were misjudged.
    assert result.tolist() == [True, False]


def test_compute_plurality_prediction():
    """Index 3 is the top logit in two of the three rows and must win."""
    rows = [
        [0.1, 0.2, 0.3, 0.4],
        [0.6, 0.1, 0.2, 0.0],
        [0.3, 0.2, 0.1, 0.4],
    ]
    winner = _compute_plurality_prediction(torch.tensor(rows))
    assert winner == torch.tensor(3)


def test_aggregate_predictions():
    """Aggregate five logit rows with two batch starts and check the result.

    Rows 0-1 both peak at the target index 3, while only one of rows 2-4
    does; the expected frame has one row per batch start.
    """
    peaks_at_3 = [0.1, 0.2, 0.3, 0.4]
    peaks_at_0 = [0.6, 0.1, 0.2, 0.0]
    logits = torch.tensor(
        [
            peaks_at_3,
            [0.0, 0.1, 0.2, 0.6],
            peaks_at_3,
            peaks_at_0,
            peaks_at_0,
        ],
    )
    group_starts = [0, 2]

    result = _aggregate_predictions(logits, group_starts, torch.tensor(3))

    assert isinstance(result, pd.DataFrame)
    assert len(result) == len(group_starts)
    correct_fraction = result["correct_samples"]
    plurality_hit = result["plurality_correct"]
    assert (correct_fraction == [1.0, 1 / 3]).all()
    assert (plurality_hit == [True, False]).all()


def test_eval_token():
    """Run ``_eval_token`` for one target token against the dummy model.

    The dummy model always predicts the target token, so every prefix length
    is expected to be fully correct. The model must receive exactly one
    forward call containing the three expected input rows.
    """
    sequence_token_ids = torch.tensor([1, 2, 3, 4, 5])
    target_token_idx = 3
    target_token_id = sequence_token_ids[target_token_idx]
    num_samples_per_prefix = 2

    def get_replacements(
        token_ids: torch.Tensor, target_length: int
    ) -> torch.Tensor:
        # Replacement tokens are all zeros, which makes the expected model
        # inputs easy to spell out below.
        assert token_ids.ndim == 1
        assert 0 < target_length <= target_token_idx
        return torch.zeros((num_samples_per_prefix, target_length))

    expected_calls = [
        torch.tensor([[0, 0, 3], [0, 2, 3], [1, 2, 3]]),
    ]
    model = DummyModel(
        PretrainedConfig(name_or_path="dummy_model"),
        prediction=int(target_token_id.item()),
        expected_input_id_calls=expected_calls,
    )

    result = _eval_token(
        model,
        DummyTokenizer(),
        PredictionConfig(trim_last_token=False, batch_size=4),
        sequence_token_ids,
        get_replacements,
        target_token_idx,
        target_token_id,
        relative_size=1.0,
        prefix_lengths=[1, 2, 4, 8],
    )

    expected = pd.DataFrame(
        {
            "correct_samples": [1.0] * 3,
            "plurality_correct": [True] * 3,
        },
        index=pd.Index([1, 2, 3]),
    )
    assert result.equals(expected)
    # Every expected forward call must have been consumed by now.
    assert (
        next(model.expected_input_id_iter, None) is None
    ), "Expected StopIteration"


# TODO: the inference code currently stops adding elements to the batch when
# it's full, even if the new elements are the same as the already added ones.
# Optimize that and then re-enable this test.
# def test_eval_token_two_batches():
#     sequence_token_ids = torch.tensor([1, 2, 3, 4, 5])
#     target_token_idx = 3
#     target_token_id = sequence_token_ids[target_token_idx]
#     num_samples_per_prefix = 2
#     prefix_lengths = [1, 2, 4, 8]

#     model = DummyModel(
#         PretrainedConfig(name_or_path="dummy_model"),
#         prediction=int(target_token_id.item()),
#         expected_input_id_calls=[
#             torch.tensor(
#                 [
#                     [0, 0, 3],
#                     [0, 2, 3],
#                 ]
#             ),
#             torch.tensor(
#                 [
#                     [1, 2, 3],
#                 ]
#             ),
#         ],
#     )
#     tokenizer = DummyTokenizer()
#     prediction_config = PredictionConfig(
#         trim_last_token=False,
#         batch_size=2,  # The three inputs need to be split into two batches
#     )

#     def get_replacements(
#         token_ids: torch.Tensor, target_length: int
#     ) -> torch.Tensor:
#         assert token_ids.ndim == 1
#         assert 0 < target_length <= target_token_idx
#         return torch.zeros((num_samples_per_prefix, target_length))

#     result = _eval_token(
#         model,
#         tokenizer,
#         prediction_config,
#         sequence_token_ids,
#         get_replacements,
#         target_token_idx,
#         target_token_id,
#         relative_size=1.0,
#         prefix_lengths=prefix_lengths,
#     )
#     expected = pd.DataFrame(
#         {
#             "correct_samples": [1.0, 1.0, 1.0],
#             "plurality_correct": [True, True, True],
#         },
#         index=pd.Index([1, 2, 3]),
#     )
#     assert result.equals(expected)
#     try:
#         next(model.expected_input_id_iter)
#         assert False, "Expected StopIteration"
#     except StopIteration:
#         pass


def test_eval_prefix_lengths():
    """Run ``eval_prefix_lengths`` end to end against the dummy model.

    The dummy model always predicts token id 4, so only token index 3 (whose
    id is 4) is expected to be counted as correct. The model must receive
    exactly the four listed forward calls, in order.
    """
    sequence_token_ids = torch.tensor([1, 2, 3, 4, 5])
    num_samples_per_prefix = 2

    def get_replacements(
        token_ids: torch.Tensor, target_length: int
    ) -> torch.Tensor:
        # Replacement tokens are all zeros, matching the zero entries in
        # the expected model inputs below.
        assert token_ids.ndim == 1
        assert 0 < target_length <= 5
        return torch.zeros((num_samples_per_prefix, target_length))

    config = PrefixEvalConfig(
        seed=522,
        prefix_lengths=[1, 2, 4, 8],
        num_samples_per_prefix=num_samples_per_prefix,
        max_token_samples=4,
    )
    expected_calls = [
        torch.tensor([[1]]),
        torch.tensor([[0, 2], [1, 2]]),
        torch.tensor([[0, 0, 3], [0, 2, 3], [1, 2, 3]]),
        torch.tensor([[0, 0, 0, 4], [0, 0, 3, 4], [1, 2, 3, 4]]),
    ]
    model = DummyModel(
        PretrainedConfig(name_or_path="dummy_model"),
        prediction=4,
        expected_input_id_calls=expected_calls,
    )
    model.name_or_path = "pythia-70m"

    result = eval_prefix_lengths(
        config,
        model,
        DummyTokenizer(),
        sequence_token_ids,
        get_replacements,
    )

    # Three rows per token index; only the rows for token index 3 are hits.
    expected_index = pd.MultiIndex.from_tuples(
        [
            (1, 1),
            (2, 1), (2, 2),
            (3, 1), (3, 2), (3, 3),
            (4, 1), (4, 2), (4, 4),
        ],
        names=["token_idx", "prefix_length"],
    )
    expected = pd.DataFrame(
        {
            "correct_samples": [0.0] * 3 + [1.0] * 3 + [0.0] * 3,
            "plurality_correct": [False] * 3 + [True] * 3 + [False] * 3,
        },
        index=expected_index,
    )
    assert result.equals(expected)
    # Every expected forward call must have been consumed by now.
    assert (
        next(model.expected_input_id_iter, None) is None
    ), "Expected StopIteration"
