import logging
from functools import partial

from datasets import DatasetDict
from transformers import BatchEncoding, PreTrainedTokenizerBase


logger = logging.getLogger(__name__)


def preprocess_dataset(
    dataset: DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
    # max_length: int,
    seed: int | None = None,
) -> DatasetDict:
    """Tokenize every split of ``dataset`` so it is ready for training.

    Each split is passed through ``tokenizer`` in batches (via
    ``_preprocess_batch``, which reads the ``"text"`` column). All of that
    split's original columns are then dropped, so only the tokenizer outputs
    remain. The result is optionally shuffled and is formatted to return
    torch tensors.

    Args:
        dataset (DatasetDict): Raw dataset; every split is expected to have a
            ``"text"`` column.
        tokenizer (PreTrainedTokenizerBase): Tokenizer tied to the model.
        seed (int | None): Seed used to shuffle every split. When ``None``,
            the dataset is not shuffled.

    Returns:
        DatasetDict: The tokenized dataset with torch formatting enabled.
    """
    preprocessing_function = partial(_preprocess_batch, tokenizer=tokenizer)

    # Drop each split's *own* source columns rather than the train split's.
    # Splits may have differing schemas (or there may be no "train" split at
    # all), and ``remove_columns`` raises on columns a split does not have.
    dataset = DatasetDict(
        {
            split_name: split.map(
                preprocessing_function,
                batched=True,
                remove_columns=split.column_names,
            )
            for split_name, split in dataset.items()
        }
    )
    logger.info("Processed dataset has %s rows", dataset.num_rows)

    # TODO: truncation is enabled in ``_preprocess_batch`` without an explicit
    # max_length here, so over-long records are silently truncated (which may
    # cut off a trailing end keyword). A filter dropping records whose
    # ``input_ids`` reach max_length was planned; re-enable it once
    # max_length is threaded through this function.

    if seed is not None:
        logger.info("Shuffling dataset")
        dataset = dataset.shuffle(seed=seed)
    dataset.set_format("torch")

    logger.info("Done preprocessing")

    return dataset


def _preprocess_batch(
    batch: dict[str, list],
    tokenizer: PreTrainedTokenizerBase,
    # max_length: int,
) -> BatchEncoding:
    return tokenizer(
        batch["text"],
        # max_length=max_length,
        truncation=True,
    )


# from lib_dl.data.loading import (
#     DataLoader,
#     DataLoaderConfig,
#     load,
# )
# from lib_dl.data.lib.dataset_accessor import (
#     LightDatasetAccessor,
#     DatasetStage,
# )

# @dataclass
# class LMDataConfig:
#     loader: DataLoaderConfig
#     max_sequence_length: int
#     seed: int

# class LMDataModule(LightDatasetAccessor):
#     """An adapter to convert Huggingface LM datasets to Pytorch Lightning
#     datamodules."""

#     def __init__(
#         self,
#         data: DatasetDict,
#         tokenizer: AutoTokenizer,
#         config: LMDataConfig,
#     ) -> None:
#         super().__init__()
#         self.data = data
#         self.tokenizer = tokenizer
#         self.config = config

#     def setup(self, stage: DatasetStage | None = None) -> None:
#         logger.info("Preprocessing dataset")
#         for stage in LightDatasetAccessor.get_relevant_stages(stage):
#             if stage == "val":
#                 lookup_stage = "validation"
#             else:
#                 lookup_stage = stage
#             stage_data = self.data[lookup_stage]
#             columns_to_keep = ["input_ids", "attention_mask"]
#             columns_to_remove = [
#                 col for col in stage_data.features.keys()
#                 if col not in columns_to_keep
#             ]
#             preprocessed_data = self.data[lookup_stage].map(
#                 partial(
#                     self.tokenize_function,
#                     tokenizer=self.tokenizer,
#                     max_length=self.config.max_sequence_length,
#                 ),
#                 batched=True,
#                 batch_size=self.config.loader.batch_size,
#                 remove_columns=columns_to_remove,
#             )
#             preprocessed_data.set_format(
#                 type="torch",
#                 columns=columns_to_keep,
#             )
#             preprocessed_data = preprocessed_data.filter(
#                 lambda rec: (
# len(rec["input_ids"]) < self.config.max_sequence_length
# )
#             )
#             if stage == "train":
#                 preprocessed_data = preprocessed_data.shuffle(
#                     seed=self.config.seed
#                 )
#             self.set_dataset(stage, preprocessed_data)

#     @staticmethod
#     def preprocess_batch(
#         batch: dict[str, list],
#         tokenizer: AutoTokenizer,
#         max_length: int,
#     ) -> dict:
#         return tokenizer(
#             batch["text"],
#             max_length=max_length,
#             truncation=True,
#         )

#     def train_dataloader(self) -> DataLoader:
#         dataset = self.get_dataset("train")
#         return load(
#             dataset,
#             train=True,
#             config=self.config.loader,
#         )
