from typing import Callable

import numpy as np
import torch
from transformers import PreTrainedTokenizer

from data.synthetic_strings.random import (
    RandomStringConfig,
    generate_random_strings,
)
from lib_llm.eval.memorization.prefix_mappings.eval import PrefixEvalConfig


def setup_replacements(
    config: PrefixEvalConfig,
    replacement_strategy: str,
    tokenizer: PreTrainedTokenizer,
    tokenizer_type: str,
    replacement_length: int,
) -> Callable[[torch.Tensor, int], torch.Tensor]:
    """Pre-generate replacement token ids and return a slicing accessor.

    The strategy name picks two things:
    - the prefix picks the mode: ``"rand_*"`` or ``"const_*"``. The
      ``constant`` flag is forwarded to the random-string generator.
    - the suffix picks the alphabet: ``"*_id"`` uses ``"latin"`` and
      ``"*_ood"`` uses ``"non_latin"``.

    One batch of token ids is generated up front. Generation is seeded
    deterministically from ``config.seed``. The returned callable ignores
    its tensor argument and returns the first ``target_length`` columns of
    that batch.

    Args:
        config: Evaluation config. This function reads ``seed``,
            ``relative_context_size`` and ``num_samples_per_prefix``.
        replacement_strategy: One of ``"rand_id"``, ``"const_id"``,
            ``"rand_ood"`` or ``"const_ood"``.
        tokenizer: Tokenizer forwarded to ``generate_random_strings``.
        tokenizer_type: Tokenizer identifier forwarded to
            ``RandomStringConfig``.
        replacement_length: Base number of tokens per sample. When
            ``config.relative_context_size > 1``, each sample is lengthened
            by ``ceil((relative_context_size - 1) * replacement_length)``.

    Returns:
        A function ``(ignored_tensor, target_length) -> tensor`` that
        returns ``replacements[:, :target_length]``.

    Raises:
        ValueError: If ``replacement_strategy`` is not a known strategy.
    """
    valid_strategies = ("rand_id", "const_id", "rand_ood", "const_ood")
    if replacement_strategy not in valid_strategies:
        raise ValueError(
            f"Unknown replacement strategy: {replacement_strategy}"
        )

    # Seeded so the same config always produces the same replacements.
    generator = np.random.default_rng(config.seed)
    # "_id" strategies use the Latin alphabet; the rest ("_ood") do not.
    alphabet = "latin" if replacement_strategy.endswith("_id") else "non_latin"
    is_constant = replacement_strategy.startswith("const")

    samples_per_prefix = config.num_samples_per_prefix
    tokens_per_sample = replacement_length
    if config.relative_context_size > 1:
        # Add the extra tokens needed for the larger relative context. The
        # extra length is based on the ORIGINAL replacement_length.
        tokens_per_sample += int(
            np.ceil((config.relative_context_size - 1) * replacement_length)
        )

    string_config = RandomStringConfig(
        seed=generator.integers(0, 2**32 - 1),
        alphabet=alphabet,
        # One partition per sample. The token budget is split across them.
        num_partitions=samples_per_prefix,
        constant=is_constant,
        num_tokens=tokens_per_sample * samples_per_prefix,
        tokenizer_type=tokenizer_type,
    )
    generated = generate_random_strings(string_config, tokenizer)
    # NOTE(review): the `[:, ...]` slicing below implies token_ids is at
    # least 2-D. It is presumably (samples_per_prefix, tokens_per_sample);
    # confirm against generate_random_strings. torch.from_numpy shares
    # memory with the numpy array.
    replacements = torch.from_numpy(generated.token_ids)

    def get_replacements(_: torch.Tensor, target_length: int) -> torch.Tensor:
        # Returns a view of the shared tensor, so callers should not modify
        # it in place. If target_length exceeds the generated length, the
        # slice is silently shorter.
        return replacements[:, :target_length]

    return get_replacements
