import torch

from experiments.practical_memorization_dynamics.experiment import (
    _get_insertion_pos,
    _insert_random_string,
)


def test_get_insertion_pos_bos_true():
    """Check `_get_insertion_pos` for a 3-token string with a BOS token.

    With a context length of 10 the test expects an insertion margin of 6,
    a first insertion token of 1, and the random ids returned unchanged.
    """
    margin, first_token, returned_ids = _get_insertion_pos(
        torch.tensor([1, 2, 3]),
        True,  # uses_bos_token
        10,  # context_length
    )

    assert margin == 6
    assert first_token == 1
    assert (returned_ids == torch.tensor([1, 2, 3])).all()


def test_get_insertion_pos_bos_false():
    """Check `_get_insertion_pos` for a 3-token string without a BOS token.

    With a context length of 10 the test expects an insertion margin of 7,
    a first insertion token of 0, and the random ids returned unchanged.
    """
    margin, first_token, returned_ids = _get_insertion_pos(
        torch.tensor([1, 2, 3]),
        False,  # uses_bos_token
        10,  # context_length
    )

    assert margin == 7
    assert first_token == 0
    assert (returned_ids == torch.tensor([1, 2, 3])).all()


def test_get_insertion_pos_equal_length_bos_true():
    """Check `_get_insertion_pos` when the string fills the whole context.

    The random string has as many tokens as the context length (10) and a
    BOS token is used. The test expects a margin of 0, a first insertion
    token of 1, and the returned ids to be the first nine input tokens.
    """
    full_length_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    expected_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])

    margin, first_token, returned_ids = _get_insertion_pos(
        full_length_ids,
        True,  # uses_bos_token
        10,  # context_length
    )

    assert margin == 0
    assert first_token == 1
    assert (returned_ids == expected_ids).all()


def test_get_insertion_pos_equal_length_bos_false():
    """Check `_get_insertion_pos` when the string fills the whole context.

    The random string has as many tokens as the context length (10) and no
    BOS token is used. The test expects a margin of 0, a first insertion
    token of 0, and all ten ids returned unchanged.
    """
    full_length_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    expected_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

    margin, first_token, returned_ids = _get_insertion_pos(
        full_length_ids,
        False,  # uses_bos_token
        10,  # context_length
    )

    assert margin == 0
    assert first_token == 0
    assert (returned_ids == expected_ids).all()


class MockRandomGenerator:
    """Deterministic stand-in for a random generator used by the tests."""

    def integers(self, low, high):
        """Return the midpoint between `low` and `high` as an int.

        The midpoint is computed with true division and then truncated by
        `int()`, so e.g. `integers(0, 3)` yields 1.
        """
        half_span = (high - low) / 2
        midpoint = low + half_span
        return int(midpoint)


def test_insert_random_string_margin_gt_first_token():
    """Check `_insert_random_string` when the margin exceeds the first token.

    A two-token random string is inserted with `insertion_margin=5` and
    `first_insertion_token=1`, using the deterministic mock generator. The
    expected batch has tokens 99, 98 at positions 3-4 of the second example
    and leaves the other examples untouched.
    """
    mock_rng = MockRandomGenerator()
    rand_input_ids = torch.tensor([99, 98])
    batch_size = 2
    insertion_margin = 5
    first_insertion_token = 1

    examples = {
        "input_ids": [
            torch.tensor([1, 2, 3, 4, 5, 6]),
            torch.tensor([7, 8, 9, 10, 11, 12]),
            torch.tensor([13, 14, 15, 16, 17, 18]),
        ]
    }

    expected_output = {
        "input_ids": [
            torch.tensor([1, 2, 3, 4, 5, 6]),
            torch.tensor([7, 8, 9, 99, 98, 12]),
            torch.tensor([13, 14, 15, 16, 17, 18]),
        ]
    }

    result = _insert_random_string(
        examples,
        rng=mock_rng,
        rand_input_ids=rand_input_ids,
        batch_size=batch_size,
        insertion_margin=insertion_margin,
        first_insertion_token=first_insertion_token,
    )

    # Guard against zip silently truncating if an example is dropped or added.
    assert len(result["input_ids"]) == len(expected_output["input_ids"])

    # The comparison must be asserted: a bare `torch.all(...)` expression is
    # evaluated and discarded, so the test could never fail.
    for i, (actual, expected) in enumerate(
        zip(result["input_ids"], expected_output["input_ids"])
    ):
        assert torch.all(actual == expected), f"mismatch in example {i}"


def test_insert_random_string_margin_lt_first_token():
    """Check `_insert_random_string` when the margin does not exceed the first token.

    A six-token random string (as long as each example) is inserted with
    `insertion_margin=0` and `first_insertion_token=0`, using the
    deterministic mock generator. The expected batch has the second example
    fully replaced by the random ids and the other examples untouched.
    """
    mock_rng = MockRandomGenerator()
    rand_input_ids = torch.tensor([99, 98, 97, 96, 95, 94])
    batch_size = 3
    insertion_margin = 0
    first_insertion_token = 0

    examples = {
        "input_ids": [
            torch.tensor([1, 2, 3, 4, 5, 6]),
            torch.tensor([7, 8, 9, 10, 11, 12]),
            torch.tensor([13, 14, 15, 16, 17, 18]),
        ]
    }

    expected_output = {
        "input_ids": [
            torch.tensor([1, 2, 3, 4, 5, 6]),
            torch.tensor([99, 98, 97, 96, 95, 94]),
            torch.tensor([13, 14, 15, 16, 17, 18]),
        ]
    }

    result = _insert_random_string(
        examples,
        rng=mock_rng,
        rand_input_ids=rand_input_ids,
        batch_size=batch_size,
        insertion_margin=insertion_margin,
        first_insertion_token=first_insertion_token,
    )

    # Guard against zip silently truncating if an example is dropped or added.
    assert len(result["input_ids"]) == len(expected_output["input_ids"])

    # The comparison must be asserted: a bare `torch.all(...)` expression is
    # evaluated and discarded, so the test could never fail.
    for i, (actual, expected) in enumerate(
        zip(result["input_ids"], expected_output["input_ids"])
    ):
        assert torch.all(actual == expected), f"mismatch in example {i}"
