from typing import Dict

import torch
from transformers import BatchEncoding, PreTrainedTokenizer

from experiments.random_facts.experiment import (
    RandomFactsConfig,
    _insert_names,
    _split_on_placeholders,
)


class DummyTokenizer(PreTrainedTokenizer):
    """A dummy tokenizer for testing that produces constant output.

    Every character of the first input string counts as one token, so a batch
    of strings encodes to all-ones ``input_ids`` / ``attention_mask`` tensors
    of shape ``(len(text), len(text[0]))``. When a BOS token has been
    registered (e.g. via ``add_special_tokens``), its id is prepended to every
    row and the attention mask grows by one column accordingly.
    """

    def __call__(self, text: list[str], **kwargs) -> BatchEncoding:
        """Encode ``text`` into constant tensors.

        Args:
            text: Batch of strings. Only the first string's length sizes the
                output, so all strings are assumed to share that length.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            A ``BatchEncoding`` holding ``input_ids`` and ``attention_mask``
            tensors of dtype ``torch.long``.
        """
        num_sequences = len(text)
        # An empty batch yields zero-width tensors instead of IndexError.
        num_tokens = len(text[0]) if num_sequences else 0
        input_ids = torch.ones(
            (num_sequences, num_tokens), dtype=torch.long
        )
        attention_mask = torch.ones(
            (num_sequences, num_tokens), dtype=torch.long
        )
        if self.bos_token is not None:
            # One BOS column per row, so the prefix concatenates correctly
            # for any batch size (a fixed (1, 1) tensor only fits batch 1).
            bos_column = torch.full(
                (num_sequences, 1),
                self.convert_tokens_to_ids(self.bos_token),
                dtype=torch.long,
            )
            input_ids = torch.cat([bos_column, input_ids], dim=1)
            attention_mask = torch.cat(
                [torch.ones_like(bos_column), attention_mask], dim=1
            )
        return BatchEncoding(
            {"input_ids": input_ids, "attention_mask": attention_mask}
        )

    def _convert_id_to_token(self, index: int) -> str:
        """Map every token id to the single token ``"a"``."""
        return "a"

    def get_vocab(self) -> Dict[str, int]:
        """Return a vocabulary mapping ``"a"``..``"z"`` to ids 0..25."""
        return {char: i for i, char in enumerate("abcdefghijklmnopqrstuvwxyz")}


def test_split_on_placeholders():
    """Mid-sentence and trailing placeholders become standalone segments.

    The space before each placeholder is kept at the start of that segment.
    """
    template = "The quick <x> jumps over the lazy <y>."
    expected = [
        "The quick",
        " <x>",
        " jumps over the lazy",
        " <y>",
        ".",
    ]
    assert _split_on_placeholders(template) == expected


def test_split_with_beginning_placeholder():
    """A placeholder at position 0 is emitted with no leading space."""
    template = "<x> recently met <y> at the park."
    expected = ["<x>", " recently met", " <y>", " at the park."]
    assert _split_on_placeholders(template) == expected


def test_split_obj_before_sub():
    """Splitting does not depend on whether <y> appears before <x>."""
    template = "Then, <y> noticed <x> on the other side."
    expected = [
        "Then,",
        " <y>",
        " noticed",
        " <x>",
        " on the other side.",
    ]
    assert _split_on_placeholders(template) == expected


def test_insert_names():
    """Names fill <x>/<y> and the name mask flags exactly their positions.

    DummyTokenizer emits one token per character, so the test expects token
    positions to line up with character offsets in the filled-in sentence.
    """
    tokenizer = DummyTokenizer()
    result = _insert_names(
        "<x> and <y> are friends.",
        "Alice",
        "Bob",
        tokenizer,
        "pythia",
    )
    sentence, input_ids, attention_mask, name_mask = result

    expected_len = len(sentence)
    assert sentence == "Alice and Bob are friends."
    for tensor in (input_ids, attention_mask, name_mask):
        assert tensor.shape == torch.Size((expected_len,))

    # Layout: "Alice" | " and" | " Bob" | " are friends."
    spans = [(0, 5, 1), (5, 9, 0), (9, 13, 1), (13, expected_len, 0)]
    for start, stop, flag in spans:
        assert torch.all(name_mask[start:stop] == flag)


def test_insert_names_with_bos_token():
    """With a BOS token, every expected position shifts right by one.

    The test expects the BOS slot itself to be excluded from the name mask.
    """
    tokenizer = DummyTokenizer()
    tokenizer.add_special_tokens({"bos_token": "<box>"})
    result = _insert_names(
        "<x> and <y> are friends.",
        "Alice",
        "Bob",
        tokenizer,
        "llama2",
    )
    sentence, input_ids, attention_mask, name_mask = result

    offset = 1  # single leading BOS position
    expected_len = len(sentence) + offset
    assert sentence == "Alice and Bob are friends."
    for tensor in (input_ids, attention_mask, name_mask):
        assert tensor.shape == torch.Size((expected_len,))

    assert name_mask[0] == 0
    # Layout after BOS: "Alice" | " and" | " Bob" | " are friends."
    spans = [(0, 5, 1), (5, 9, 0), (9, 13, 1), (13, len(sentence), 0)]
    for start, stop, flag in spans:
        assert torch.all(name_mask[start + offset : stop + offset] == flag)
