import logging
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from transformers import PreTrainedTokenizer

from .utils import (
    SyntheticStringConfig,
    SyntheticStringData,
    get_character_probabilities,
    sample_tokens,
)


# Absolute, symlink-resolved path of the directory containing this module.
# NOTE(review): not used by the definitions in this module's visible code;
# presumably consumed by importers (e.g. to locate data files) — verify.
SCRIPT_PATH = Path(__file__).parent.resolve()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@dataclass
class RandomStringConfig(SyntheticStringConfig):
    """Configuration for strings of independently sampled random characters.

    The per-character distribution is
    ``get_character_probabilities(alphabet_size, first_char_prob)``; the
    derived properties (``guess_accuracy``, ``guess_ce_loss``) describe that
    distribution.
    """

    # Total number of tokens to generate, summed over all partitions.
    num_tokens: int
    # The number of equi-sized splits that the data should be divided into.
    # This results in multiple shorter, instead of one long sequence.
    num_partitions: int = 1
    # Preset entropy level (one of 13, 7, 4, 2; requires a 26-character
    # alphabet). Can be used to control the level of entropy of the
    # distribution. When set, it determines ``first_char_prob``.
    entropy_like: int | None = None
    # The probability of the first character in the alphabet.
    # None means uniform distribution, i.e. maximum entropy.
    first_char_prob: float | None = None
    # Forwarded to ``sample_tokens``; mutually exclusive with ``entropy_like``.
    constant: bool = False

    def __post_init__(self) -> None:
        """Validate options and resolve ``entropy_like`` to ``first_char_prob``.

        Raises:
            ValueError: If ``entropy_like`` is combined with ``constant``, is
                not a supported level, or is used with an alphabet whose size
                is not 26.
        """
        super().__post_init__()

        if self.entropy_like is not None and self.constant:
            raise ValueError(
                "Cannot set entropy_like and constant at the same time"
            )
        # entropy_like -> first_char_prob for a 26-character alphabet. The
        # values appear chosen so the distribution's entropy is about
        # log2(entropy_like) bits (cf. ``uniform_entropy``), assuming
        # ``get_character_probabilities`` spreads the remaining mass evenly.
        entropy_levels = {
            13: 0.41385763230705264,
            7: 0.6040332620868685,
            4: 0.7455315447635649,
            2: 0.8913995544185638,
        }
        if self.entropy_like is not None:
            # Explicit raises instead of ``assert`` so validation also runs
            # under ``python -O``.
            if self.alphabet_size != 26:
                raise ValueError(
                    "entropy_like requires alphabet_size == 26, got "
                    f"{self.alphabet_size}"
                )
            if self.entropy_like not in entropy_levels:
                raise ValueError(
                    f"Unsupported entropy_like={self.entropy_like}; "
                    f"expected one of {sorted(entropy_levels)}"
                )
            level_prob = entropy_levels[self.entropy_like]
            # A different explicit first_char_prob would otherwise be
            # overwritten silently.
            if (
                self.first_char_prob is not None
                and self.first_char_prob != level_prob
            ):
                logger.warning(
                    "entropy_like=%s overrides first_char_prob=%s with %s",
                    self.entropy_like,
                    self.first_char_prob,
                    level_prob,
                )
            self.first_char_prob = level_prob

    @property
    def alphabet_id(self) -> str:
        """Parent alphabet id, suffixed with ``_h-<entropy_like>`` / ``_const``."""
        alphabet_id = super().alphabet_id
        if self.entropy_like is not None:
            alphabet_id += f"_h-{self.entropy_like}"
        if self.constant:
            alphabet_id += "_const"
        return alphabet_id

    @property
    def name(self) -> str:
        """Dataset name: ``rand_<alphabet_id>_t-<num_tokens>[_p-<num_partitions>]``.

        The partition suffix is only added when ``num_partitions > 1``.
        """
        return (
            "rand_"
            + self.alphabet_id
            + f"_t-{self.num_tokens}"
            + (f"_p-{self.num_partitions}" if self.num_partitions > 1 else "")
        )

    @property
    def character_probabilities(self) -> np.ndarray:
        """Per-character sampling probabilities as a float32 array."""
        return np.array(
            get_character_probabilities(
                self.alphabet_size, self.first_char_prob
            ),
            dtype=np.float32,
        )

    @property
    def guess_accuracy(self) -> float:
        """Probability of the most likely character.

        This is the accuracy of always predicting that character.
        """
        return np.max(self.character_probabilities).item()

    @property
    def guess_ce_loss(self) -> float:
        """Entropy (in nats) of the character distribution.

        This is the expected cross-entropy loss of a predictor that outputs
        exactly ``character_probabilities``. Zero-probability characters are
        skipped (0 * log 0 := 0); otherwise ``0 * log(0)`` would yield nan.
        """
        probs = self.character_probabilities
        probs = probs[probs > 0]
        return -np.sum(probs * np.log(probs)).item()

    @property
    def uniform_entropy(self) -> float:
        """Entropy in bits of a uniform distribution over the effective alphabet.

        Returns ``log2(alphabet_size)``, or ``log2(entropy_like)`` when
        ``entropy_like`` is set. A ``first_char_prob`` given directly, without
        ``entropy_like``, is not reflected here.

        Raises:
            ValueError: If ``entropy_like`` is set but ``alphabet_size`` is
                not 26.
        """
        if self.entropy_like is None:
            return np.log2(self.alphabet_size).item()
        if self.alphabet_size != 26:
            raise ValueError(
                "entropy_like requires alphabet_size == 26, got "
                f"{self.alphabet_size}"
            )
        return np.log2(self.entropy_like).item()


@dataclass
class RandomStringData(SyntheticStringData):
    """Random-string dataset as returned by ``generate_random_strings``.

    Declares only ``config``, narrowing its type to ``RandomStringConfig``;
    every other field is inherited from ``SyntheticStringData``.
    """

    config: RandomStringConfig


def generate_random_strings(
    config: RandomStringConfig,
    tokenizer: PreTrainedTokenizer,
) -> RandomStringData:
    """Sample random strings as described by ``config`` and tokenize them.

    ``config.num_tokens`` tokens are split into ``config.num_partitions``
    strings, each of length ``num_tokens // num_partitions``.

    Args:
        config: Sampling configuration. Its ``first_char_prob`` and
            ``constant`` are forwarded to ``sample_tokens``.
        tokenizer: Tokenizer used by ``sample_tokens`` and by
            ``SyntheticStringData.ids_to_tokens``.

    Returns:
        A ``RandomStringData`` holding the config, the sampled token ids, the
        alphabet token ids, the start-of-string token id, and the fields
        produced by ``SyntheticStringData.ids_to_tokens``.

    Raises:
        ValueError: If ``num_partitions`` is not positive or does not evenly
            divide ``num_tokens``.
    """
    # Explicit checks instead of ``assert`` so they also run under -O. Zero
    # partitions would otherwise raise ZeroDivisionError, and a negative
    # count can pass the modulo test and produce a negative string length.
    if config.num_partitions < 1:
        raise ValueError(
            f"num_partitions must be >= 1, got {config.num_partitions}"
        )
    if config.num_tokens % config.num_partitions != 0:
        raise ValueError(
            f"num_tokens ({config.num_tokens}) must be divisible by "
            f"num_partitions ({config.num_partitions})"
        )
    string_length = config.num_tokens // config.num_partitions

    sample_result = sample_tokens(
        config,
        tokenizer=tokenizer,
        string_length=string_length,
        num_strings=config.num_partitions,
        first_char_prob=config.first_char_prob,
        constant=config.constant,
    )
    return RandomStringData(
        config=config,
        raw_token_ids=sample_result.token_ids,
        alphabet_token_ids=sample_result.alphabet_token_ids,
        start_of_string_token_id=sample_result.start_of_string_token_id,
        **SyntheticStringData.ids_to_tokens(
            tokenizer=tokenizer,
            token_ids=sample_result.token_ids,
            alphabet_token_ids=sample_result.alphabet_token_ids,
            start_of_string_token_id=sample_result.start_of_string_token_id,
        ),
    )
