from typing import cast

import numpy as np
import torch
from transformers import BatchEncoding, PreTrainedTokenizer

from utils.prefix_mappings import PrefixEvalConfig, setup_replacements


def test_setup_replacements():
    """Check the replacement batch produced for every replacement strategy.

    For each strategy, the callable returned by ``setup_replacements`` must
    yield a batch of shape ``(num_samples_per_prefix, replacement_length)``
    in which no two rows are element-wise identical.
    """
    num_samples = 4
    replacement_length = 8
    strategies = ("rand_id", "const_id", "rand_ood", "const_ood")

    # The dummy only implements the small tokenizer surface defined in this
    # file; the cast is purely for the type checker.
    tokenizer = cast(PreTrainedTokenizer, DummyTokenizer())

    for replacement_strategy in strategies:
        # A fresh config is built per strategy, each with the same seed.
        config = PrefixEvalConfig(
            seed=4583,
            num_samples_per_prefix=num_samples,
        )
        get_replacements = setup_replacements(
            config,
            replacement_strategy=replacement_strategy,
            tokenizer=tokenizer,
            tokenizer_type="pythia",
            replacement_length=replacement_length,
        )

        # An empty (0, 0) tensor is passed as the first argument.
        # NOTE(review): presumably a placeholder the strategies ignore --
        # confirm against setup_replacements.
        placeholder = torch.empty((0, 0))
        replacements = get_replacements(placeholder, replacement_length)
        assert replacements.shape == (num_samples, replacement_length)

        # Every pair of rows must differ in at least one position.
        for row_index, row in enumerate(replacements):
            later_rows = replacements[row_index + 1 :]
            for later_row in later_rows:
                assert torch.any(row != later_row)


class DummyTokenizer:
    """Character-level stand-in for a tokenizer, used by the tests in this file.

    Each character is mapped to its Unicode code point (``ord``) and each id
    is mapped back with ``chr``.
    """

    # No ids are flagged as special tokens.
    all_special_ids = []
    # Vocabulary size reported to callers.
    vocab_size = 256

    def __call__(
        self,
        input_strings: list[str],
        return_tensors: str,
    ) -> BatchEncoding:
        """Encode each string as a row of code points.

        Args:
            input_strings: Strings to encode, one row per string.
                NOTE(review): presumably all of equal length so the rows
                form a rectangular array -- confirm against callers.
            return_tensors: Accepted but not used; the ids are always
                wrapped in a NumPy array.

        Returns:
            A ``BatchEncoding`` whose data holds the array under
            ``"input_ids"``.
        """
        code_points = [list(map(ord, text)) for text in input_strings]
        return BatchEncoding(data={"input_ids": np.array(code_points)})

    def convert_ids_to_tokens(self, token_ids: list[int]) -> list[str]:
        """Map each id to the single-character string ``chr(id)``."""
        return list(map(chr, token_ids))
