import logging
from dataclasses import dataclass
from itertools import chain
from pathlib import Path

import numpy as np
from datasets import Dataset, DatasetDict, load_dataset
from transformers import BatchEncoding, PreTrainedTokenizer


# Logger named after this module so callers can tune its verbosity.
logger: logging.Logger = logging.getLogger(__name__)
# Absolute (resolved) directory that contains this source file.
SCRIPT_PATH: Path = Path(__file__).parent.resolve()


@dataclass
class TextGenerationDataConfig:
    """Configuration for ``load_text_generation_data``.

    Raises:
        ValueError: On construction, if ``sequence_length`` is not positive
            or ``num_sequences`` is below ``-1``.
    """

    # Dataset name passed to ``datasets.load_dataset``.
    dataset: str
    # Dataset configuration name, passed as ``name=`` to ``load_dataset``.
    variant: str
    # Seed for the random shuffle of the training split.
    seed: int
    # Number of tokens in every grouped sequence; must be positive.
    sequence_length: int
    # Number of shuffled training sequences to keep; -1 keeps all of them.
    num_sequences: int = -1

    def __post_init__(self) -> None:
        # Validate eagerly: a bad length would otherwise only surface as a
        # ZeroDivisionError or silent mis-slicing deep inside dataset.map.
        if self.sequence_length <= 0:
            raise ValueError(
                f"sequence_length must be positive, got {self.sequence_length}"
            )
        if self.num_sequences < -1:
            raise ValueError(
                "num_sequences must be -1 (keep all) or non-negative, "
                f"got {self.num_sequences}"
            )


# Worker-process count handed to every ``dataset.map(..., num_proc=...)``
# call in this module.
NUM_PROCESSES: int = 8


def load_text_generation_data(
    config: TextGenerationDataConfig,
    tokenizer: PreTrainedTokenizer,
) -> DatasetDict:
    """Load, tokenize and chunk a text dataset for causal language modeling.

    The dataset ``config.dataset`` (configuration ``config.variant``) is
    loaded with ``datasets.load_dataset``, and the ``"text"`` column of every
    split is tokenized. Within each ``map`` batch, the token streams are then
    concatenated and cut into blocks of exactly ``config.sequence_length``
    tokens. Any tail that does not fill a whole block is dropped. The train
    split is shuffled with ``config.seed`` and, unless
    ``config.num_sequences`` is ``-1``, truncated to that many blocks.

    Args:
        config: Dataset name/variant, shuffle seed, block length and the
            number of training blocks to keep (``-1`` keeps all of them).
        tokenizer: Tokenizer applied to the ``"text"`` column.

    Returns:
        A ``DatasetDict`` with ``"train"``, ``"validation"`` and ``"test"``
        splits, with output format set to ``"torch"``.

    Raises:
        ValueError: If ``sequence_length`` is not positive or
            ``num_sequences`` is below ``-1``.
        TypeError: If ``load_dataset`` does not return a ``DatasetDict``.
        KeyError: If the dataset lacks a ``train``, ``validation`` or
            ``test`` split.
    """
    sequence_length = config.sequence_length
    if sequence_length <= 0:
        raise ValueError(
            f"sequence_length must be positive, got {sequence_length}"
        )
    if config.num_sequences < -1:
        raise ValueError(
            "num_sequences must be -1 (keep all) or non-negative, "
            f"got {config.num_sequences}"
        )

    raw_datasets = load_dataset(config.dataset, name=config.variant)
    # An explicit check instead of `assert`, which is stripped under -O.
    if not isinstance(raw_datasets, DatasetDict):
        raise TypeError(
            "expected load_dataset to return a DatasetDict, got "
            f"{type(raw_datasets).__name__}"
        )

    logger.info("Tokenizing dataset...")

    def preprocess_function(examples: dict) -> BatchEncoding:
        # NOTE(review): truncation=True clips each document to the
        # tokenizer's max length *before* documents are concatenated below,
        # so tokens past that limit are discarded. Confirm this is intended
        # for datasets with long documents.
        return tokenizer(examples["text"], truncation=True)

    # Remove every raw column, not only "text": group_texts flattens each
    # remaining column as a list of token lists, so any other raw column
    # (ids, titles, ...) would make it fail.
    tokenized_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        num_proc=NUM_PROCESSES,
        remove_columns=raw_datasets["train"].column_names,
    )

    logger.info("Grouping texts...")

    def group_texts(examples: dict) -> dict:
        """Concatenate a batch of tokenized texts and cut it into blocks.

        Every tokenizer output column is flattened and sliced at the same
        offsets, so token-level columns stay aligned. The tail shorter than
        a block is dropped, so every output row has exactly
        ``sequence_length`` tokens. A batch holding fewer tokens than one
        block produces no rows.
        """
        # chain.from_iterable is linear; sum(lists, []) was quadratic.
        concatenated = {
            key: list(chain.from_iterable(values))
            for key, values in examples.items()
        }
        total_length = len(next(iter(concatenated.values()), []))
        total_length -= total_length % sequence_length
        return {
            key: [
                tokens[start : start + sequence_length]
                for start in range(0, total_length, sequence_length)
            ]
            for key, tokens in concatenated.items()
        }

    grouped_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=NUM_PROCESSES,
    )

    logger.info("Extracting %s sequences...", config.num_sequences)
    training_data = grouped_datasets["train"]
    rng = np.random.default_rng(config.seed)
    training_indices = rng.permutation(len(training_data))
    if config.num_sequences != -1:
        training_indices = training_indices[: config.num_sequences]
    shuffled_training_data = training_data.select(training_indices)

    splits = DatasetDict(
        {
            "train": shuffled_training_data,
            "validation": grouped_datasets["validation"],
            "test": grouped_datasets["test"],
        }
    )
    splits.set_format("torch")
    return splits


def add_token_string_reps(
    dataset: Dataset | DatasetDict,
    tokenizer: PreTrainedTokenizer,
) -> Dataset | DatasetDict:
    """Add human-readable string views of the ``input_ids`` column.

    Two columns are added to every row:

    * ``"tokens"``: the token string of each id, via
      ``tokenizer.convert_ids_to_tokens``;
    * ``"text"``: the ids decoded back into one string, via
      ``tokenizer.decode``.

    Args:
        dataset: A ``Dataset`` or ``DatasetDict`` with an ``input_ids``
            column.
        tokenizer: The tokenizer that produced ``input_ids``.

    Returns:
        The result of ``dataset.map`` with the two extra columns.
    """

    def reverse_tokenize(examples: dict) -> dict:
        ids_batch = examples["input_ids"]
        return {
            "tokens": [
                tokenizer.convert_ids_to_tokens(ids) for ids in ids_batch
            ],
            "text": [tokenizer.decode(ids) for ids in ids_batch],
        }

    return dataset.map(
        reverse_tokenize,
        batched=True,
        num_proc=NUM_PROCESSES,
    )
