from dataclasses import dataclass
from itertools import product

import numpy as np
from transformers import PreTrainedTokenizer

from .random import RandomStringConfig, RandomStringData
from .utils import (
    SyntheticStringData,
    get_character_probabilities,
    sample_tokens,
)


@dataclass
class ConditionalRandomStringConfig(RandomStringConfig):
    """Random-string configuration whose next-token distribution is
    conditioned on the preceding n-gram.
    """

    # Number of preceding tokens (the n-gram prefix) that determine the
    # conditional probability of the next token.
    ngram_length: int = 0

    @property
    def name(self) -> str:
        """Return the parent config's name with an `_ngl-<ngram_length>`
        suffix appended.
        """
        return f"{super().name}_ngl-{self.ngram_length}"

    def set_relative_probability(self, relative_prob: float | int) -> None:
        """Set `first_char_prob` from a relative likelihood.

        `first_char_prob` is the probability of the privileged token of
        each n-gram. It is chosen so that this token is `relative_prob`
        times as likely as each of the other `alphabet_size - 1`
        characters: with `q` the probability of a non-privileged
        character, `relative_prob * q + (alphabet_size - 1) * q = 1`.
        """
        normalizer = self.alphabet_size + relative_prob - 1
        self.first_char_prob = relative_prob / normalizer


def create_ngram_conditional_dataset(
    data_config: ConditionalRandomStringConfig,
    tokenizer: PreTrainedTokenizer,
) -> RandomStringData:
    """Create a dataset where the probability of the next token is
    conditioned on the previous `ngram_length` tokens.
    We keep the unconditional probabilities of the characters the same,
    but change the conditional probability distribution.

    The string starts with `ngram_length` tokens obtained from
    `sample_tokens`. Every following token is drawn with `rng.choice`
    from the alphabet: the privileged token of the current n-gram prefix
    gets probability `first_char_prob`, every other token gets
    `remaining_char_probs`.

    Args:
        data_config: Dataset configuration. Must have
            `num_partitions == 1`, `ngram_length > 0` and a non-None
            `seed_value`.
        tokenizer: Tokenizer forwarded to `sample_tokens` and
            `SyntheticStringData.ids_to_tokens`.

    Returns:
        A `RandomStringData` whose `raw_token_ids` is a single row of
        shape `(1, string_length)`.

    Raises:
        AssertionError: If one of the configuration requirements above
            is violated.
    """
    assert data_config.num_partitions == 1
    assert data_config.ngram_length > 0
    # NOTE: `data_config.first_char_prob` is not validated here; it is
    # passed as-is to `get_character_probabilities` below.
    ngram_length = data_config.ngram_length
    alphabet_size = data_config.alphabet_size

    # Seed the string with its first `ngram_length` tokens so that the
    # first conditional draw already has a full n-gram prefix.
    string_start = sample_tokens(
        data_config,
        tokenizer,
        string_length=ngram_length,
        num_strings=1,
    )
    alphabet_token_ids = string_start.alphabet_token_ids
    token_ids = [int(tid) for tid in string_start.token_ids[0]]

    seed = data_config.seed_value
    assert seed is not None
    rng = np.random.default_rng(seed)
    # Maps each n-gram (tuple of token ids) to the *index* into
    # `alphabet_token_ids` of its privileged follow-up token.
    ngram_mapping = _sample_ngram_conditional_token_indices(
        ngram_length,
        alphabet_token_ids,
        rng,
    )
    # We use the first char probs to create the probability distribution
    # for characters based on the previous `ngram_length` characters.
    # Only entries 0 and 1 are read: entry 0 is used for the privileged
    # token and entry 1 for every other token (presumably all non-first
    # entries are equal -- confirm against get_character_probabilities).
    char_probs = get_character_probabilities(
        alphabet_size,
        data_config.first_char_prob,
    )
    first_char_prob = char_probs[0]
    remaining_char_probs = char_probs[1]
    print("first_char_prob", first_char_prob)
    print("remaining_char_probs", remaining_char_probs)

    for _ in range(data_config.num_tokens - ngram_length):
        ngram_prefix = tuple(token_ids[-ngram_length:])
        privileged_token_idx = ngram_mapping[ngram_prefix]
        # Position i of the probability vector belongs to
        # alphabet_token_ids[i], matching the order used by rng.choice.
        next_token_probs = np.full(alphabet_size, remaining_char_probs)
        next_token_probs[privileged_token_idx] = first_char_prob
        next_token = rng.choice(
            alphabet_token_ids,
            size=1,
            p=next_token_probs,
        )
        # `next_token` has shape (1,). Index it explicitly: calling
        # int() on an array with ndim > 0 is deprecated since NumPy 1.25.
        token_ids.append(int(next_token[0]))

    raw_token_ids = np.array(token_ids).reshape(1, -1)
    return RandomStringData(
        config=data_config,
        raw_token_ids=raw_token_ids,
        alphabet_token_ids=alphabet_token_ids,
        start_of_string_token_id=string_start.start_of_string_token_id,
        **SyntheticStringData.ids_to_tokens(
            tokenizer=tokenizer,
            token_ids=raw_token_ids,
            alphabet_token_ids=alphabet_token_ids,
            start_of_string_token_id=string_start.start_of_string_token_id,
        ),
    )


def _sample_ngram_conditional_token_indices(
    ngram_length: int,
    alphabet_token_ids: np.ndarray,
    rng: np.random.Generator,
) -> dict[tuple[int, ...], int]:
    """Assign every possible n-gram one privileged follow-up token index.

    The privileged index is the position (into `alphabet_token_ids`) of
    the token that should follow the n-gram with higher probability than
    the others. Indices are derived from a random permutation of all
    n-gram positions taken modulo the alphabet size, so each index is
    assigned to the same number of n-grams.

    Args:
        ngram_length: Number of tokens per n-gram; must be positive.
        alphabet_token_ids: Token ids making up the alphabet.
        rng: Generator used to draw the permutation.

    Returns:
        A dict keyed by n-gram (tuple of token ids, in
        `itertools.product` order) mapping to the privileged index.
    """
    assert ngram_length > 0
    num_symbols = len(alphabet_token_ids)
    shuffled_positions = rng.permutation(num_symbols**ngram_length)
    privileged_indices = shuffled_positions % num_symbols

    # Sanity check: the modulo of a full permutation is perfectly balanced.
    occurrences = np.unique(privileged_indices, return_counts=True)[1]
    assert np.all(
        occurrences == occurrences[0]
    ), "All target token indices should appear equally often"

    all_ngrams = product(alphabet_token_ids, repeat=ngram_length)
    return {
        tuple(int(tid) for tid in ngram): int(privileged_idx)
        for ngram, privileged_idx in zip(all_ngrams, privileged_indices)
    }
