import logging
import string
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from datasets import Dataset, DatasetDict, load_dataset


# Absolute, resolved path of the directory that contains this file.
SCRIPT_PATH = Path(__file__).parent.resolve()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@dataclass
class RandomStringConfig:
    """Settings for building a dataset of fixed-length text sequences.

    Read by ``generate_random_strings`` and ``load_wiki_text`` in this module.
    """

    # Not read anywhere in this module -- presumably selects which kind of
    # data (random strings vs. WikiText) the caller builds; TODO confirm.
    data_type: str
    # Seed passed to ``np.random.default_rng`` for generation and shuffling.
    seed: int
    # Not read anywhere in this module -- presumably consumed by tokenizer
    # setup elsewhere; TODO confirm.
    characterwise_tokenization: bool
    # Number of sequences to generate, or to sample from WikiText.
    num_sequences: int
    # Length of each sequence, in characters.
    sequence_length: int
    # Name of the alphabet to draw from; ``generate_strings`` only accepts
    # "latin" (lowercase ASCII letters) and raises for anything else.
    alphabet: str = "latin"
    # How many leading characters of the alphabet are actually used.
    alphabet_size: int = 26


@dataclass
class GeneratedDataset:
    """Result of ``generate_random_strings``: the dataset plus its alphabet."""

    # Dataset dictionary; ``generate_random_strings`` fills it with a "train"
    # and a "test" split.
    data: DatasetDict
    # The characters the generated sequences were drawn from.
    alphabet_characters: str


def generate_random_strings(
    config: RandomStringConfig,
    shifted_pos_token_experiment: bool,
    dataset_1024: bool = False,
) -> GeneratedDataset:
    """Build a torch-formatted dataset of random strings described by *config*.

    The strings come from ``generate_strings``, driven by a NumPy generator
    seeded with ``config.seed``. They are stored in a single ``"text"``
    column, and the very same ``Dataset`` object is used for both the
    ``"train"`` and the ``"test"`` split.

    Args:
        config: Sequence count, sequence length, alphabet and seed.
        shifted_pos_token_experiment: Forwarded unchanged to
            ``generate_strings``.
        dataset_1024: Forwarded unchanged to ``generate_strings``.

    Returns:
        A ``GeneratedDataset`` holding the ``DatasetDict`` and the alphabet
        characters that were used.
    """
    # NOTE: a former check that num_sequences * sequence_length == 1024 is
    # disabled; arbitrary sizes are accepted here.
    texts, alphabet_characters = generate_strings(
        num_sequences=config.num_sequences,
        sequence_length=config.sequence_length,
        alphabet=config.alphabet,
        alphabet_size=config.alphabet_size,
        rng=np.random.default_rng(config.seed),
        shifted_pos_token_experiment=shifted_pos_token_experiment,
        dataset_1024=dataset_1024,
    )

    # One Dataset object backs both splits.
    split = Dataset.from_dict({"text": texts})
    splits = DatasetDict({"train": split, "test": split})
    splits.set_format("torch")

    return GeneratedDataset(data=splits, alphabet_characters=alphabet_characters)


def generate_strings(
    num_sequences: int,
    sequence_length: int,
    alphabet: str,
    alphabet_size: int,
    rng: np.random.Generator,
    shifted_pos_token_experiment: bool,
    dataset_1024: bool,
) -> tuple[list[str], str]:
    """Draw random strings over the first ``alphabet_size`` alphabet characters.

    A single run of ``num_sequences * sequence_length`` characters is sampled
    with ``rng.choice``. What is returned depends on the mode:

    * Normal mode: the run is cut into ``num_sequences`` consecutive pieces
      of ``sequence_length`` characters each.
    * ``shifted_pos_token_experiment``: the *whole* run (all
      ``num_sequences * sequence_length`` characters) is returned
      ``num_sequences`` times, i.e. every sequence is identical.

    Args:
        num_sequences: Number of strings to return.
        sequence_length: Characters per string (in normal mode).
        alphabet: Alphabet name; only ``"latin"`` (lowercase ASCII) is known.
        alphabet_size: How many leading alphabet characters may be used.
        rng: NumPy generator that supplies the randomness.
        shifted_pos_token_experiment: Selects the repeated-run mode above.
        dataset_1024: When true, overrides the sizes to one sequence of
            1024 characters.

    Returns:
        ``(sequences, alphabet_characters)`` where ``alphabet_characters`` is
        the truncated alphabet the sequences were drawn from.

    Raises:
        ValueError: If the alphabet is unknown, or ``alphabet_size`` is
            smaller than 1 or larger than the alphabet.
    """
    if dataset_1024:
        num_sequences = 1
        sequence_length = 1024

    if alphabet == "latin":
        alphabet_characters = string.ascii_lowercase
    else:
        raise ValueError(f"Unknown alphabet {alphabet}")

    # A negative size would otherwise slip through and silently slice
    # characters off the *end* of the alphabet (e.g. -1 -> 25 letters).
    if alphabet_size < 1:
        raise ValueError(
            f"Alphabet size must be at least 1, but is {alphabet_size}"
        )
    if alphabet_size > len(alphabet_characters):
        raise ValueError(
            f"Alphabet size {alphabet_size} is larger than the number "
            f"of characters in the alphabet {len(alphabet_characters)}"
        )
    alphabet_characters = alphabet_characters[:alphabet_size]

    # Same draw (shape and order) in both modes, so a given seed always
    # yields the same characters.
    flat = rng.choice(
        list(alphabet_characters),
        size=(1, sequence_length * num_sequences),
    )[0]

    if shifted_pos_token_experiment:
        # Every sequence is the complete run of characters.
        sequences = ["".join(flat)] * num_sequences
    else:
        # Cut the run into consecutive, non-overlapping pieces.
        sequences = [
            "".join(flat[i * sequence_length:(i + 1) * sequence_length])
            for i in range(num_sequences)
        ]

    # Same logger object as the module-level one (same name); emitted at
    # DEBUG level so normal runs do not dump the whole dataset to stdout.
    log = logging.getLogger(__name__)
    for seq in sequences:
        log.debug("Generated sequence: %s", seq)

    return sequences, alphabet_characters


def load_wiki_text(
    config: RandomStringConfig,
) -> DatasetDict:
    """Load WikiText-2 and return a random sample of fixed-length text chunks.

    Every split is re-chunked into pieces of ``config.sequence_length``
    characters. Then ``config.num_sequences`` chunks of the ``"train"`` split
    are picked via a permutation seeded with ``config.seed``. The selected
    chunks are used for both the ``"train"`` and the ``"test"`` split of the
    returned, torch-formatted ``DatasetDict``.

    Args:
        config: Supplies ``sequence_length``, ``num_sequences`` and ``seed``.

    Returns:
        A ``DatasetDict`` whose ``"train"`` and ``"test"`` entries are the
        same sampled dataset.
    """
    wiki = load_dataset("wikitext", "wikitext-2-v1")
    assert isinstance(wiki, DatasetDict)

    chunk_size = config.sequence_length

    def chunk_batch(batch: dict) -> dict:
        """Join one batch of texts and cut it into ``chunk_size`` pieces."""
        joined = "".join(batch["text"])
        # Round down to a whole number of chunks; the trailing remainder
        # (shorter than chunk_size) is discarded.
        usable_length = len(joined) - len(joined) % chunk_size
        pieces = [
            joined[start : start + chunk_size]
            for start in range(0, usable_length, chunk_size)
        ]
        return {"text": pieces}

    logger.info("Slicing texts...")
    # Batched mapping lets the function return a different number of rows
    # than it received.
    chunked = wiki.map(chunk_batch, batched=True)

    logger.info(f"Extracting {config.num_sequences} sequences...")
    train_chunks = chunked["train"]
    order = np.random.default_rng(config.seed).permutation(len(train_chunks))
    sample = train_chunks.select(order[: config.num_sequences])

    # The same sample backs both splits.
    result = DatasetDict({"train": sample, "test": sample})
    result.set_format("torch")
    return result
