from typing import TypedDict

from transformers import PreTrainedTokenizer

from .deterministic_rules import (
    DeterministicRuleStringConfig,
    DeterministicRuleStringData,
    generate_deterministic_rule_strings,
)
from .random import (
    RandomStringConfig,
    RandomStringData,
    generate_random_strings,
)
from .utils import uses_bos_token


def get_config(
    random_config: RandomStringConfig | None,
    deterministic_rule_config: DeterministicRuleStringConfig | None,
) -> RandomStringConfig | DeterministicRuleStringConfig:
    """Return whichever of the two data configs was provided.

    Exactly one of ``random_config`` and ``deterministic_rule_config`` must
    be non-``None``.

    Args:
        random_config: Config for random string data, or ``None``.
        deterministic_rule_config: Config for deterministic-rule string data,
            or ``None``.

    Returns:
        The single non-``None`` config, returned unchanged.

    Raises:
        ValueError: If both configs are given, or if neither is given.
    """
    if random_config is not None and deterministic_rule_config is not None:
        raise ValueError("Only one data config should be specified")
    if random_config is not None:
        return random_config
    if deterministic_rule_config is not None:
        return deterministic_rule_config
    raise ValueError("No data config specified")


def load_data(
    config: RandomStringConfig | DeterministicRuleStringConfig,
    tokenizer: PreTrainedTokenizer,
) -> RandomStringData | DeterministicRuleStringData:
    """Generate string data using the generator that matches ``config``'s type.

    Args:
        config: Either a ``RandomStringConfig`` or a
            ``DeterministicRuleStringConfig``.
        tokenizer: Passed through unchanged to the selected generator.

    Returns:
        Whatever the selected generator returns for ``(config, tokenizer)``.

    Raises:
        ValueError: If ``config`` is an instance of neither config type.
    """
    # Checked in order; the first matching config type wins.
    dispatch = (
        (RandomStringConfig, generate_random_strings),
        (DeterministicRuleStringConfig, generate_deterministic_rule_strings),
    )
    for config_cls, generate in dispatch:
        if isinstance(config, config_cls):
            return generate(config, tokenizer)
    raise ValueError(f"Unknown data config type {type(config)}")


class DeunionedData(TypedDict, total=True):
    """Split of a data object into one optional slot per data type.

    ``deunion_data`` fills exactly one slot with the data and sets the other
    to ``None``.
    """

    # Set when the data is a RandomStringData, otherwise None.
    random_data: RandomStringData | None
    # Set when the data is a DeterministicRuleStringData, otherwise None.
    deterministic_rule_data: DeterministicRuleStringData | None


def deunion_data(
    data: RandomStringData | DeterministicRuleStringData,
) -> DeunionedData:
    """Place ``data`` into the slot of a ``DeunionedData`` matching its type.

    Args:
        data: Either a ``RandomStringData`` or a ``DeterministicRuleStringData``.

    Returns:
        A dict with ``data`` under the key for its type and ``None`` under
        the other key.

    Raises:
        ValueError: If ``data`` is an instance of neither data type.
    """
    # RandomStringData is checked first, so it takes precedence.
    is_random = isinstance(data, RandomStringData)
    if not is_random and not isinstance(data, DeterministicRuleStringData):
        raise ValueError(f"Unknown data type {type(data)}")
    return {
        "random_data": data if is_random else None,
        "deterministic_rule_data": None if is_random else data,
    }
