from collections.abc import Iterator

import torch
from transformers import AutoTokenizer, BatchEncoding
from transformers.utils import ModelOutput

from lib_llm.inference._batch_mapping import (
    map_to_model_batches,
    remap_to_input_batches,
)
from lib_llm.inference.preprocess import tokenize


def _check_model_batches(
    model_batch_iter: Iterator,
    expected_model_batch_sizes: list[int],
    expected_idx_mappings: list[list[int]],
    expected_sequence_length: int,
) -> None:
    for expected_size, expected_mappings in zip(
        expected_model_batch_sizes, expected_idx_mappings
    ):
        model_batches, idx_mappings = next(model_batch_iter)
        assert model_batches.input_ids.shape == (
            expected_size,
            expected_sequence_length,
        )
        assert all(idx < expected_size for idx in idx_mappings)
        assert (
            idx_mappings == expected_mappings
        ), f"Expected idx mappings {expected_mappings} but got {idx_mappings}"
    try:
        next(model_batch_iter)
    except StopIteration:
        pass
    else:
        raise AssertionError("Expected StopIteration exception")


def test_map_to_model_batches_without_duplicates():
    """Unique inputs are packed into model batches without any merging.

    128 distinct sentences are tokenized as 4 input batches of 32 and packed
    into model batches of at most 100 rows: one full batch of 100 followed by
    a remainder batch of 28.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Encode 4 batches of size 32
    texts = [f"{i}: This is a test sentence." for i in range(128)]
    tokenized_batches = [
        tokenize(tokenizer, texts[i : i + 32]) for i in range(0, len(texts), 32)
    ]

    model_batch_size = 100
    model_batch_iter = map_to_model_batches(tokenized_batches, model_batch_size)
    _check_model_batches(
        model_batch_iter,
        [100, 28],
        [
            [
                *range(32),
                -1,
                *range(32, 64),
                -1,
                *range(64, 96),
                -1,
                *range(96, 100),
            ],
            [*range(28), -1],
        ],
        # No trim_last_column argument is passed here, and the expected
        # length is one column shorter than the tokenized input; presumably
        # trimming is the default — confirm against map_to_model_batches.
        len(tokenized_batches[0].input_ids[0]) - 1,
    )


def test_map_to_model_batches_with_duplicates():
    """Runs of identical inputs within an input batch share one model row.

    Sixteen copies of one sentence open the first input batch and sixteen
    copies of another sentence close the last one. Each run is expected to
    map onto a single model-batch index, while the remaining unique rows
    are packed as usual into model batches of at most 50 rows.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Encode 160 texts as 5 input batches of size 32. The first batch starts
    # with 16 identical sentences and the last batch ends with 16 identical
    # sentences; each of those runs should be merged into one model input.
    duplicated_texts = [
        *["This is a test"] * 16,
        *[f"{i}: This is a test sentence." for i in range(128)],
        *["SANONYMOUS a test"] * 16,
    ]
    tokenized_batches = [
        tokenize(tokenizer, duplicated_texts[i : i + 32])
        for i in range(0, len(duplicated_texts), 32)
    ]

    model_batch_size = 50
    model_batch_iter = map_to_model_batches(tokenized_batches, model_batch_size)
    _check_model_batches(
        model_batch_iter,
        [50, 50, 30],
        [
            # The 16 leading duplicates all map to model row 0.
            [*([0] * 16), *range(1, 17), -1, *range(17, 49), -1, 49],
            [*range(31), -1, *range(31, 50)],
            # The 16 trailing duplicates all map to model row 29.
            [*range(13), -1, *range(13, 29), *([29] * 16), -1],
        ],
        len(tokenized_batches[0].input_ids[0]) - 1,
    )


def test_iter_larger_input_batches_than_model_batches():
    """Input batches larger than the model batch size are split across batches.

    Two input batches of 64 unique sentences are packed into model batches of
    at most 50 rows, so the first input batch spills into the second model
    batch and the second input batch spills into the third.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Encode 2 batches of size 64
    texts = [f"{i}: This is a test sentence." for i in range(128)]
    tokenized_batches = [
        tokenize(tokenizer, texts[i : i + 64]) for i in range(0, len(texts), 64)
    ]

    model_batch_size = 50
    model_batch_iter = map_to_model_batches(tokenized_batches, model_batch_size)
    _check_model_batches(
        model_batch_iter,
        [50, 50, 28],
        [
            [*range(50)],
            [*range(14), -1, *range(14, 50)],
            [*range(28), -1],
        ],
        len(tokenized_batches[0].input_ids[0]) - 1,
    )


def test_map_to_model_batches_with_near_duplicates_w_trimming():
    """Rows that differ only in their final token collapse once it is trimmed.

    Thirty-two copies of one sentence are tokenized as 4 input batches of 8.
    Each row then receives a unique final token. With
    ``trim_last_column=True`` that column is dropped, so every row is
    expected to map onto the single row of one model batch.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # The same sentence everywhere; only the overwritten last token differs.
    texts = ["This is a test sentence."] * 32
    tokenized_batches = [
        tokenize(tokenizer, texts[start : start + 8])
        for start in range(0, len(texts), 8)
    ]
    all_rows = [row for encoding in tokenized_batches for row in encoding.input_ids]
    for unique_token, row in enumerate(all_rows):
        row[-1] = unique_token
    num_tokens = len(all_rows[-1])

    model_batch_iter = map_to_model_batches(
        tokenized_batches,
        20,
        trim_last_column=True,
    )
    single_row_mapping = [*([0] * 8), -1] * 4
    _check_model_batches(
        model_batch_iter,
        [1],
        [single_row_mapping],
        num_tokens - 1,
    )


def test_iter_model_input_batches_same_size():
    """Input batches exactly the model batch size map one-to-one.

    Two input batches of 16 rows are fed with a model batch size of 16. Every
    row has a unique last token and ``trim_last_column=False`` keeps it, so
    no rows are merged and each input batch becomes its own model batch.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Encode identical texts but give every row a unique last token. Since
    # the last column is not trimmed, the rows stay distinct and are not
    # collapsed.
    texts = ["This is a test sentence."] * 32
    tokenized_batches = [
        tokenize(tokenizer, texts[i : i + 16]) for i in range(0, len(texts), 16)
    ]
    num_tokens = -1
    i = 0
    for batch in tokenized_batches:
        for input_ids in batch.input_ids:
            input_ids[-1] = i
            i += 1
            num_tokens = len(input_ids)

    model_batch_size = 16
    model_batch_iter = map_to_model_batches(
        tokenized_batches,
        model_batch_size,
        trim_last_column=False,
    )
    _check_model_batches(
        model_batch_iter,
        [16, 16],
        [
            [*range(16), -1],
            [*range(16), -1],
        ],
        num_tokens,
    )


def test_map_to_model_batches_with_near_duplicates_wo_trimming():
    """Near-duplicate rows stay distinct when the last column is kept.

    This mirrors the ``_w_trimming`` variant: 4 input batches of 8 rows that
    differ only in their last token. With ``trim_last_column=False`` no rows
    are merged, so all 32 rows are packed into model batches of 20 and 12.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Encode identical texts but give every row a unique last token. Since
    # the last column is not trimmed, the rows stay distinct and are not
    # collapsed.
    texts = ["This is a test sentence."] * 32
    tokenized_batches = [
        tokenize(tokenizer, texts[i : i + 8]) for i in range(0, len(texts), 8)
    ]
    num_tokens = -1
    i = 0
    for batch in tokenized_batches:
        for input_ids in batch.input_ids:
            input_ids[-1] = i
            i += 1
            num_tokens = len(input_ids)

    model_batch_size = 20
    model_batch_iter = map_to_model_batches(
        tokenized_batches,
        model_batch_size,
        trim_last_column=False,
    )
    _check_model_batches(
        model_batch_iter,
        [20, 12],
        [
            [*range(8), -1, *range(8, 16), -1, *range(16, 20)],
            [*range(4), -1, *range(4, 12), -1],
        ],
        num_tokens,
    )


def test_map_to_model_batches_no_duplicates():
    """A small batch of distinct rows passes through unchanged.

    The three rows share prefixes but are all distinct. With
    ``trim_last_column=False`` they are expected to form a single model batch
    whose ``input_ids`` equal the input tensor exactly.
    """
    token_ids = torch.tensor(
        [
            [0, 0, 3],
            [0, 2, 3],
            [1, 2, 3],
        ]
    )
    batch_encoding = BatchEncoding(
        {
            "input_ids": token_ids,
            "attention_mask": torch.ones(token_ids.shape),
        }
    )
    model_batch_size = 10
    model_batch_iter = map_to_model_batches(
        [batch_encoding],
        model_batch_size,
        trim_last_column=False,
    )
    _check_model_batches(
        model_batch_iter,
        [3],
        [[*range(3), -1]],
        3,
    )

    # Build a fresh iterator to also check the actual token values.
    model_batch_iter = map_to_model_batches(
        [batch_encoding],
        model_batch_size,
        trim_last_column=False,
    )
    model_batch, _ = next(model_batch_iter)
    assert torch.equal(model_batch.input_ids, token_ids)
    # Use next()'s default instead of `assert False` inside try/except:
    # that form is stripped under `python -O` (flake8-bugbear B011).
    if next(model_batch_iter, None) is not None:
        raise AssertionError("Expected StopIteration exception")


def test_remap_to_input_batches_wo_duplicates():
    """Model outputs are regrouped into the original input batches.

    35 logit rows arrive in model batches of 15, 15 and 5. Their index
    mappings contain no duplicates, so the regrouped outputs must be the
    contiguous slices ``logits[:25]`` and ``logits[25:35]``.
    """
    torch.manual_seed(224)
    input_batch_sizes = [25, 10]
    model_batch_sizes = [15, 15, 5]
    model_batch_indices = [
        [*range(15)],
        [*range(10), -1, *range(10, 15)],
        [*range(5), -1],
    ]
    logits = torch.randn(35, 5, 64)
    expected_batches = [
        logits[:25],
        logits[25:35],
    ]

    output_batch_generator = remap_to_input_batches()
    output_batches = []
    i = 0
    for batch_size, batch_indices in zip(
        model_batch_sizes, model_batch_indices
    ):
        batch_logits = logits[i : i + batch_size]
        # The remapper may emit zero or more completed input batches per
        # model batch, so collect whatever it yields.
        for output_batch in output_batch_generator(
            ModelOutput(logits=batch_logits),
            batch_indices,
        ):
            output_batches.append(output_batch)
        i += batch_size

    assert len(output_batches) == 2
    for batch_size, batch_output, expected_batch_logits in zip(
        input_batch_sizes, output_batches, expected_batches
    ):
        batch_logits = batch_output.logits
        assert batch_logits.shape == (batch_size, 5, 64)
        torch.testing.assert_close(batch_logits, expected_batch_logits)


def test_remap_to_input_batches_with_duplicates():
    """Duplicate index mappings replicate a single model output row.

    Forty logit rows arrive in model batches of 16, 16 and 8. In the later
    mappings one model row is referenced many times (15 and 22 times). The
    remapper must repeat that row to rebuild three input batches of 25.
    """
    torch.manual_seed(853)
    input_batch_sizes = [25, 25, 25]
    model_batch_sizes = [16, 16, 8]
    model_batch_indices = [
        [*range(16)],
        [*range(9), -1, *([9] * 15), *range(10, 16)],
        [*range(4), -1, *range(4, 7), *([7] * 22), -1],
    ]
    logits = torch.randn(40, 5, 64)
    repeated_row_25 = [logits[25:26]] * 15
    repeated_row_39 = [logits[39:40]] * 22
    expected_batches = [
        logits[:25],
        torch.cat([*repeated_row_25, logits[26:36]], dim=0),
        torch.cat([logits[36:39], *repeated_row_39], dim=0),
    ]

    remapper = remap_to_input_batches()
    collected = []
    # torch.split yields consecutive views of 16, 16 and 8 rows, which is the
    # same sequence of slices as walking a running offset through `logits`.
    for chunk, indices in zip(logits.split(model_batch_sizes), model_batch_indices):
        collected.extend(remapper(ModelOutput(logits=chunk), indices))

    assert len(collected) == 3
    for expected_size, produced, expected_logits in zip(
        input_batch_sizes, collected, expected_batches
    ):
        produced_logits = produced.logits
        assert produced_logits.shape == (expected_size, 5, 64)
        torch.testing.assert_close(produced_logits, expected_logits)
