import string
from dataclasses import dataclass
from typing import Optional, TypedDict, TypeVar

import numpy as np
import torch
from datasets import Dataset, DatasetDict
from transformers import BatchEncoding, PreTrainedTokenizer


# Fixed pool of RNG seeds. `SyntheticStringConfig.seed_value` indexes into this
# list with `seed_id`, so the position of each entry matters: reordering or
# editing entries changes which seed a given `seed_id` resolves to.
# NOTE(review): the last two entries (indices 10 and 11) are both 2854, so
# those two seed ids resolve to the same seed -- confirm whether this
# duplicate is intentional.
SEEDS = [
    5932,
    2575,
    5740,
    6927,
    2080,
    1232,
    9094,
    9971,
    3955,
    6801,
    2854,
    2854,
]


@dataclass(kw_only=True)
class SyntheticStringConfig:
    tokenizer_type: str
    seed_id: int | None = None
    seed: int | None = None
    alphabet: str = "latin"
    alphabet_size: int = 26

    def __post_init__(self) -> None:
        if self.seed_id is None and self.seed is None:
            raise ValueError("Either seed_id or seed must be given")
        if self.seed_id is not None and self.seed is not None:
            raise ValueError("Only one of seed_id or seed must be given")

    @property
    def seed_value(self) -> int:
        if self.seed is not None:
            return self.seed
        elif self.seed_id is not None:
            return SEEDS[self.seed_id]
        else:
            raise ValueError("Either seed_id or seed must be given")

    @property
    def alphabet_abbrv(self) -> str:
        if self.alphabet == "latin":
            return "lat"
        elif self.alphabet == "non_latin":
            return "nlat"
        elif self.alphabet == "numeric":
            return "num"
        else:
            raise ValueError(f"Invalid alphabet: {self.alphabet}")

    @property
    def alphabet_id(self) -> str:
        return f"a-{self.alphabet_abbrv}-{self.alphabet_size}"

    @property
    def name(self) -> str:
        raise NotImplementedError
        # return f"sid-{self.seed_id}_{self.alphabet_id}"

    @property
    def sid(self) -> str:
        return f"sid-{self.seed_id}"


@dataclass(kw_only=True)
class SyntheticStringData:
    """Sampled strings in token-id and token-string form.

    The ``raw_*`` fields hold the strings without a start-of-string token.
    The ``token_ids``, ``tokens`` and ``attention_mask`` properties return
    the same data with the start-of-string token prepended to every string
    when one is configured.
    """

    config: SyntheticStringConfig
    raw_token_ids: np.ndarray  # n x m
    raw_tokens: list[list[str]]
    alphabet_token_ids: np.ndarray  # a
    alphabet_tokens: list[str]
    raw_attention_mask: np.ndarray | None = None  # n x m
    start_of_string_token_id: int | None = None
    start_of_string_token: str | None = None

    @property
    def token_ids(self) -> np.ndarray:
        """Token ids, with a leading start-of-string column if configured."""
        bos_id = self.start_of_string_token_id
        if bos_id is None:
            return self.raw_token_ids
        # One start-of-string id per string, as an extra first column.
        bos_column = np.full(
            (len(self.raw_token_ids), 1),
            bos_id,
            dtype=np.int32,
        )
        return np.concatenate((bos_column, self.raw_token_ids), axis=1)

    @property
    def tokens(self) -> list[list[str]]:
        """Token strings, with a leading start-of-string token if configured."""
        bos_token = self.start_of_string_token
        if bos_token is None:
            return self.raw_tokens
        return [[bos_token] + row for row in self.raw_tokens]

    @property
    def attention_mask(self) -> np.ndarray:
        """Attention mask matching ``token_ids``.

        Without a raw mask every position of ``token_ids`` is attended to.
        Otherwise the raw mask is used, extended by a column of ones for
        the start-of-string token when one is configured.
        """
        if self.raw_attention_mask is None:
            return np.ones_like(self.token_ids)
        mask = self.raw_attention_mask
        if self.start_of_string_token_id is None:
            return mask
        bos_mask = np.ones(
            (len(mask), 1),
            dtype=np.int32,
        )
        return np.concatenate((bos_mask, mask), axis=1)

    def dataset(self) -> DatasetDict:
        """Wrap the data in a torch-formatted ``DatasetDict``.

        The "train" and "test" splits are the same ``Dataset`` object.
        """
        columns = dict(self._encoding())
        columns["text"] = ["".join(row) for row in self.tokens]
        columns["tokens"] = self.tokens
        split = Dataset.from_dict(columns)
        splits = DatasetDict({"train": split, "test": split})
        splits.set_format("torch")
        return splits

    def batch_encoding(self) -> BatchEncoding:
        """Return the input ids and attention mask as a ``BatchEncoding``."""
        return BatchEncoding(self._encoding())

    def _encoding(self) -> dict[str, torch.Tensor]:
        """Convert token ids and attention mask to int64 torch tensors."""
        input_ids = torch.from_numpy(self.token_ids).long()
        attention_mask = torch.from_numpy(self.attention_mask).long()
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    class TokenIdsAsStrings(TypedDict):
        """Return type of ``ids_to_tokens``."""

        raw_tokens: list[list[str]]
        alphabet_tokens: list[str]
        start_of_string_token: str | None

    @staticmethod
    def ids_to_tokens(
        tokenizer: PreTrainedTokenizer,
        token_ids: np.ndarray,
        alphabet_token_ids: np.ndarray,
        start_of_string_token_id: int | None = None,
    ) -> TokenIdsAsStrings:
        """Map token ids to token strings using ``tokenizer``.

        Converts every row of ``token_ids``, the alphabet ids, and (if given)
        the start-of-string id via ``tokenizer.convert_ids_to_tokens``.
        """
        to_tokens = tokenizer.convert_ids_to_tokens
        per_string_tokens = [to_tokens(row.tolist()) for row in token_ids]
        alphabet_tokens = to_tokens(alphabet_token_ids.tolist())
        if start_of_string_token_id is None:
            bos_token = None
        else:
            bos_token = to_tokens([start_of_string_token_id])[0]
        return {
            "raw_tokens": per_string_tokens,
            "alphabet_tokens": alphabet_tokens,
            "start_of_string_token": bos_token,
        }


@dataclass
class SampleResult:
    """Container for the values returned by `sample_tokens`."""

    # Sampled token ids, one row per string (n strings x m tokens).
    token_ids: np.ndarray  # n x m
    # Token ids that make up the alphabet the strings were sampled from.
    alphabet_token_ids: np.ndarray  # a
    # Id of the start-of-string token, or None if there is none.
    start_of_string_token_id: int | None


def sample_tokens(
    config: SyntheticStringConfig,
    tokenizer: PreTrainedTokenizer,
    string_length: int,
    num_strings: int,
    first_char_prob: Optional[float] = None,
    constant: bool = False,
) -> SampleResult:
    """Sample ``num_strings`` strings of ``string_length`` alphabet tokens.

    The alphabet token ids are derived from ``tokenizer`` according to
    ``config``; sampling uses a generator seeded with ``config.seed_value``
    (the same seed value is also passed on as the alphabet seed).

    Args:
        config: Alphabet, tokenizer type and seed settings.
        tokenizer: Tokenizer used to obtain the alphabet token ids.
        string_length: Number of tokens per string.
        num_strings: Number of strings to sample.
        first_char_prob: Probability of the first alphabet token; the other
            tokens share the remaining mass uniformly. ``None`` means a
            uniform distribution over the whole alphabet.
        constant: If true, draw a single token per string and repeat it
            ``string_length`` times instead of sampling every position.

    Returns:
        A ``SampleResult`` with the ``(num_strings, string_length)`` token id
        array, the alphabet token ids and the start-of-string token id.
    """
    alphabet_token_ids, start_of_string_token_id = _get_alphabet_token_ids(
        tokenizer=tokenizer,
        tokenizer_type=config.tokenizer_type,
        alphabet=config.alphabet,
        alphabet_size=config.alphabet_size,
        alphabet_seed=config.seed_value,
    )

    rng = np.random.default_rng(config.seed_value)
    char_probabilities = get_character_probabilities(
        num_alphabet_chars=len(alphabet_token_ids),
        first_char_prob=first_char_prob,
    )

    if constant:
        # One draw per string, shape (num_strings, 1).
        per_string_token = rng.choice(
            alphabet_token_ids,
            size=(num_strings, 1),
            p=char_probabilities,
        )
        # Repeat each draw along the row. Unlike stacking per-row arrays,
        # this also yields a well-formed (0, string_length) array when
        # num_strings is 0 instead of raising on an empty stack.
        string_token_ids = np.repeat(
            per_string_token, string_length, axis=1
        ).astype(np.int32, copy=False)
    else:
        string_token_ids = rng.choice(
            alphabet_token_ids,
            size=(num_strings, string_length),
            p=char_probabilities,
        )

    return SampleResult(
        string_token_ids,
        alphabet_token_ids,
        start_of_string_token_id,
    )


def _get_alphabet_token_ids(
    tokenizer: PreTrainedTokenizer,
    tokenizer_type: str,
    alphabet: str,
    alphabet_size: int,
    alphabet_seed: int,
) -> tuple[np.ndarray, int | None]:
    """Select the token ids that form the sampling alphabet.

    For "latin" and "numeric" the candidates are the token ids of the
    lowercase ASCII letters / the digits. For "non_latin" the candidates are
    all non-special vocabulary ids that are *not* lowercase-letter tokens,
    shuffled with a generator seeded by ``alphabet_seed``.

    Args:
        tokenizer: Tokenizer used to encode the alphabet characters.
        tokenizer_type: Tokenizer family name, forwarded to
            ``alphabet_encoding``.
        alphabet: One of "latin", "non_latin" or "numeric".
        alphabet_size: Number of candidate token ids to keep.
        alphabet_seed: Seed for shuffling the "non_latin" candidates.

    Returns:
        The first ``alphabet_size`` candidate token ids and the
        start-of-string token id (``None`` if the tokenizer adds none).

    Raises:
        ValueError: If ``alphabet`` is unknown or ``alphabet_size`` exceeds
            the number of candidate token ids.
    """
    if alphabet == "latin" or alphabet == "non_latin":
        alphabet_characters = string.ascii_lowercase
    elif alphabet == "numeric":
        alphabet_characters = string.digits
    else:
        raise ValueError(f"Invalid alphabet: {alphabet}")
    alphabet_token_ids, start_of_string_token_id = alphabet_encoding(
        tokenizer=tokenizer,
        tokenizer_type=tokenizer_type,
        alphabet_characters=alphabet_characters,
    )
    if alphabet.startswith("non_"):
        # "Invert" the set of token ids: keep every non-special vocabulary
        # id that is not one of the encoded alphabet characters.
        special_tokens_set = set(tokenizer.all_special_ids)
        all_token_ids_set = (
            set(range(tokenizer.vocab_size)) - special_tokens_set
        )
        # NOTE(review): the pre-shuffle order comes from set iteration;
        # sorting first would make it explicit but would change which
        # tokens get selected, so it is left as is.
        alphabet_token_ids = np.array(
            list(all_token_ids_set - set(alphabet_token_ids)),
            dtype=np.int32,
        )
        rng = np.random.default_rng(alphabet_seed)
        rng.shuffle(alphabet_token_ids)

    if alphabet_size > len(alphabet_token_ids):
        # Report the number of candidates actually compared against; for
        # inverted alphabets this differs from len(alphabet_characters).
        raise ValueError(
            f"Alphabet size {alphabet_size} is larger than the number "
            f"of tokens in the alphabet {len(alphabet_token_ids)}"
        )
    return alphabet_token_ids[:alphabet_size], start_of_string_token_id


def alphabet_encoding(
    tokenizer: PreTrainedTokenizer,
    tokenizer_type: str,
    alphabet_characters: str,
) -> tuple[np.ndarray, int | None]:
    """Encode each alphabet character as a single token id.

    Every character is tokenized as its own sequence. Tokenizer types for
    which ``uses_bos_token`` is true are expected to produce two ids per
    character (start-of-string token + character token); all others are
    expected to produce exactly one.

    Args:
        tokenizer: Tokenizer to encode the characters with.
        tokenizer_type: Tokenizer family name, see ``uses_bos_token``.
        alphabet_characters: The characters to encode, one token each.

    Returns:
        An int32 array with one token id per character, and the
        start-of-string token id (``None`` if the tokenizer adds none).

    Raises:
        AssertionError: If the tokenizer output does not have the expected
            number of sequences or ids per sequence.
        ValueError: If ``tokenizer_type`` is not recognised.
    """
    encoding = tokenizer(
        list(alphabet_characters),
        return_tensors="np",
    )
    input_ids = encoding.input_ids
    # Report len(input_ids), the quantity actually checked, rather than the
    # length of the encoding object itself.
    assert len(input_ids) == len(alphabet_characters), (
        f"Encoding length {len(input_ids)} does not match "
        f"alphabet length {len(alphabet_characters)}"
    )
    if uses_bos_token(tokenizer_type):
        # Llama2 and OPT insert special start of string tokens before every
        # sequence that we need to account for
        assert all(len(tids) == 2 for tids in input_ids), (
            "Expected a start of string token followed by exactly one "
            "token per alphabet character"
        )
        return (
            np.array([tids[1] for tids in input_ids], dtype=np.int32),
            int(input_ids[0][0]),
        )
    else:
        assert all(len(tids) == 1 for tids in input_ids), (
            "Expected exactly one token per alphabet character"
        )
        return np.array([tids[0] for tids in input_ids], dtype=np.int32), None


def get_character_probabilities(
    num_alphabet_chars: int,
    first_char_prob: Optional[float],
) -> list[float]:
    """Build the sampling distribution over the alphabet characters.

    Args:
        num_alphabet_chars: Number of characters in the alphabet.
        first_char_prob: Probability of the first character. The remaining
            characters share the leftover probability mass uniformly.
            ``None`` selects a uniform distribution over all characters.

    Returns:
        A list of ``num_alphabet_chars`` probabilities.

    Raises:
        ValueError: If ``num_alphabet_chars`` is not positive, if
            ``first_char_prob`` lies outside ``[0, 1]``, or if the alphabet
            has a single character and ``first_char_prob`` is not 1 (no
            other character could take the remaining mass).
    """
    if num_alphabet_chars < 1:
        raise ValueError(
            f"num_alphabet_chars must be positive, got {num_alphabet_chars}"
        )
    if first_char_prob is None:
        # Use a uniform distribution.
        return [1 / num_alphabet_chars] * num_alphabet_chars

    # Explicit check instead of `assert`, which is stripped under `-O`.
    if not 0 <= first_char_prob <= 1:
        raise ValueError(
            f"first_char_prob must be in [0, 1], got {first_char_prob}"
        )
    if num_alphabet_chars == 1:
        # Avoid dividing the remaining mass by zero other characters.
        if first_char_prob != 1:
            raise ValueError(
                "first_char_prob must be 1 for a single-character "
                f"alphabet, got {first_char_prob}"
            )
        return [1.0]

    # Sample the first alphabet character with probability first_char_prob
    # and spread the remaining mass uniformly over the other characters.
    num_remaining_chars = num_alphabet_chars - 1
    remaining_char_prob = (1 - first_char_prob) / num_remaining_chars
    return [first_char_prob] + [remaining_char_prob] * num_remaining_chars


def uses_bos_token(
    tokenizer_type: str,
) -> bool:
    """Tell whether the given tokenizer type is treated as adding a BOS token.

    Returns ``False`` for "pythia", "gpt2" and "phi", ``True`` for "llama2"
    and "opt", and raises ``ValueError`` for any other tokenizer type.
    """
    known_types = (
        (("pythia", "gpt2", "phi"), False),
        (("llama2", "opt"), True),
    )
    for type_names, prepends_bos in known_types:
        if tokenizer_type in type_names:
            return prepends_bos
    raise ValueError(f"Invalid tokenizer type: {tokenizer_type}")
