from typing import cast

import numpy as np
from transformers import BatchEncoding, PreTrainedTokenizer

from data.synthetic_strings.deterministic_rules import (
    DeterministicRuleStringConfig,
    DeterministicRuleStringData,
    _sample_strings,
    generate_deterministic_rule_strings,
    sample_rules,
)


def test_sample_rules():
    """Check the shape and token vocabulary of a sampled rule table.

    The table must hold ``alphabet_size ** premise_length`` entries, and every
    token appearing in a premise or a conclusion must come from the sampled
    alphabet.
    """
    dummy = cast(PreTrainedTokenizer, DummyTokenizer())
    cfg = DeterministicRuleStringConfig(
        seed_id=0,
        string_length=16,
        num_strings=2,
        premise_length=3,
        tokenizer_type="pythia",
        alphabet_size=7,
    )

    sampled = sample_rules(cfg, dummy)
    allowed = set(sampled.alphabet_token_ids)
    expected_num_rules = cfg.alphabet_size**cfg.premise_length

    assert len(sampled.rules) == expected_num_rules

    # Flatten every premise tuple into a single stream of tokens.
    premise_tokens = (tok for premise in sampled.rules.keys() for tok in premise)
    assert all(tok in allowed for tok in premise_tokens)

    assert all(c in allowed for c in sampled.rules.values())


def test_sample_strings():
    """Check that sampled strings match the config and obey the rule table.

    Each string must have the configured length and use only alphabet tokens.
    Every token after the first ``premise_length`` must equal the rule
    conclusion for the ``premise_length`` tokens immediately preceding it.
    """
    dummy = cast(PreTrainedTokenizer, DummyTokenizer())
    cfg = DeterministicRuleStringConfig(
        seed_id=0,
        string_length=32,
        num_strings=4,
        premise_length=2,
        tokenizer_type="pythia",
        alphabet_size=4,
    )
    sampled_rules = sample_rules(cfg, dummy)
    samples = _sample_strings(cfg, sampled_rules)
    allowed = set(sampled_rules.alphabet_token_ids)
    width = cfg.premise_length

    assert len(samples) == cfg.num_strings

    for sample in samples:
        assert len(sample) == cfg.string_length

    for sample in samples:
        for tok in sample:
            assert tok in allowed

    # Slide a window of `width` tokens and compare the next token with the
    # rule's conclusion for that window.
    for sample in samples:
        for pos in range(width, len(sample)):
            window = tuple(sample[pos - width : pos])
            assert sample[pos] == sampled_rules.rules[window]


def test_generate_deterministic_rule_strings():
    """Check the end-to-end generator against its sampling building blocks.

    The generator's output must carry the input config. It must also agree
    with a fresh ``sample_rules`` / ``_sample_strings`` run on the same config.
    This relies on sampling being deterministic for a fixed config.
    """
    tokenizer_type = "pythia"
    tokenizer = cast(PreTrainedTokenizer, DummyTokenizer())
    config = DeterministicRuleStringConfig(
        seed_id=0,
        string_length=32,
        num_strings=4,
        premise_length=2,
        tokenizer_type=tokenizer_type,
        alphabet_size=4,
    )
    generated_strings = generate_deterministic_rule_strings(config, tokenizer)

    assert generated_strings.config == config

    # Independently re-sample with the same config; results must match.
    rules = sample_rules(config, tokenizer)
    assert len(generated_strings.rules) == len(rules.rules)
    token_ids = _sample_strings(config, rules)

    # Unlike np.all(a == b), this also checks that the shapes match, so a
    # broadcastable shape mismatch cannot pass silently. On failure it reports
    # the mismatching elements.
    np.testing.assert_array_equal(generated_strings.raw_token_ids, token_ids)


class DummyTokenizer:
    """Character-level stand-in for a tokenizer, used by the tests above.

    Each character maps to its Unicode code point (``ord``), and
    ``convert_ids_to_tokens`` maps code points back with ``chr``.
    """

    def __call__(
        self,
        input_strings: list[str],
        return_tensors: str,
    ) -> BatchEncoding:
        """Encode each string as an array row of per-character code points.

        ``return_tensors`` is accepted but not used.
        """
        code_points = [list(map(ord, text)) for text in input_strings]
        return BatchEncoding(data={"input_ids": np.array(code_points)})

    def convert_ids_to_tokens(self, token_ids: list[int]) -> list[str]:
        """Map each code point back to its one-character string."""
        return list(map(chr, token_ids))
