from pathlib import Path
from typing import Union, Optional, List, Dict
import time

import numpy as np
from torch import Tensor
import torch
from datasets import load_dataset
from transformers import PreTrainedTokenizerBase


def _chunk_token_ids(
    token_ids: np.ndarray, max_len: int, shift: bool
) -> Dict[str, Tensor]:
    """Split a flat 1-D int64 token stream into fixed-size training rows.

    Args:
        token_ids: 1-D ``np.int64`` array of concatenated token IDs.
        max_len: Number of tokens per returned row (must be >= 1).
        shift: If True, each row is cut from a window of ``max_len + 1``
            tokens; ``input_ids`` is the window without its last token and
            ``labels`` is the window without its first token. If False,
            ``input_ids`` and ``labels`` are identical copies of each
            ``max_len``-token window.

    Returns:
        A dict with ``"input_ids"`` and ``"labels"`` tensors, both of shape
        ``(n_chunks, max_len)`` and dtype int64. Tokens that do not fill a
        complete window at the end of the stream are discarded.
    """
    # With shifting, each window holds one extra token so that inputs and
    # labels (offset by one position) both end up with `max_len` tokens.
    chunk_len = max_len + 1 if shift else max_len
    n_chunks = token_ids.shape[0] // chunk_len

    # Drop the trailing remainder, then view the stream as a 2-D grid.
    chunks = token_ids[: n_chunks * chunk_len].reshape(n_chunks, chunk_len)

    # Copy so the two outputs own separate, contiguous memory (as the
    # original pre-allocated arrays did) instead of aliasing `chunks`.
    if shift:
        input_ids = chunks[:, :-1].copy()
        labels = chunks[:, 1:].copy()
    else:
        input_ids = chunks.copy()
        labels = chunks.copy()

    return {
        "input_ids": torch.from_numpy(input_ids),
        "labels": torch.from_numpy(labels),
    }


def build_dataset(
    tokenizer: PreTrainedTokenizerBase,
    data_dir: Union[str, Path] = "/home/fineweb-edu/sample/100BT",
    streaming: bool = True,
    n_workers: int = 8,
    overwrite_cache: bool = False,
    token_ids_only: bool = True,
    max_len: int = 512,
    eos_token_id: Optional[int] = None,
    shift: bool = True,
):
    """
    Load a text dataset and map it to fixed-length chunks of token IDs.

    The ``train`` split is loaded from `data_dir` with `load_dataset`. Each
    batch handed to the map function is tokenized, an EOS token ID is appended
    to every document, the documents are concatenated into one token stream,
    and the stream is split into rows of `max_len` tokens.

    Note that chunking happens independently per map batch: tokens left over
    at the end of a batch (fewer than one full chunk) are discarded, not
    carried over into the next batch.

    Args:
        tokenizer: Tokenizer called on the list of texts in each batch.
        data_dir: Path or dataset identifier passed to `load_dataset`.
        streaming: Passed to `load_dataset`. When True, `n_workers` and
            `overwrite_cache` are not used.
        n_workers: ``num_proc`` for the non-streaming `map` call.
        overwrite_cache: If True, the non-streaming `map` call is told not
            to load from its cache file.
        token_ids_only: If True, the known metadata/text columns are removed
            from the mapped dataset. If False, no columns are removed.
            NOTE(review): chunking changes the number of rows per batch, so
            keeping the original columns probably conflicts with the new
            ones - confirm before relying on ``token_ids_only=False``.
        max_len: Number of tokens per output row. Must be at least 1.
        eos_token_id: Token ID appended after each document. Defaults to
            ``tokenizer.eos_token_id``.
        shift: If True, ``labels`` are ``input_ids`` shifted left by one
            token (next-token prediction); otherwise they are identical.

    Returns:
        The dataset returned by ``raw_dataset.map(...)``, whose mapped
        examples carry ``"input_ids"`` and ``"labels"`` of `max_len` tokens.

    Raises:
        ValueError: If no EOS token ID is given and the tokenizer has none,
            or if `max_len` is smaller than 1.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if eos_token_id is None:
        raise ValueError(
            "eos_token_id was not given and tokenizer.eos_token_id is None"
        )
    if max_len < 1:
        raise ValueError(f"max_len must be at least 1, got {max_len}")

    print(f"Loading data from {data_dir}")
    raw_dataset = load_dataset(
        data_dir,
        split='train',
        streaming=streaming,
    )
    print("Dataset loaded")

    text_column_name = "text"
    col_names = [
        "text",
        "id",
        "dump",
        "url",
        "file_path",
        "language",
        "language_score",
        "token_count",
        "score",
        "int_score",
    ]

    # Upper bound on tokens kept per document; longer documents are truncated.
    max_doc_tokens = 10**6
    eos_piece = np.asarray([eos_token_id], dtype=np.int64)

    def process_fn(examples: dict) -> Dict[str, Tensor]:
        """
        A process function to use with `Dataset.map`. It tokenizes the
        texts in the batch, appends the EOS token ID to each one,
        concatenates them, and splits the result into chunks of `max_len`
        tokens (discarding the last chunk if incomplete).
        """
        texts: List[str] = examples[text_column_name]

        pieces: List[np.ndarray] = []
        if texts:
            # No `return_tensors`: documents have different lengths, and
            # building a ragged NumPy array is version-dependent. Plain
            # lists are converted one document at a time instead.
            encodings = tokenizer(
                texts, max_length=max_doc_tokens, truncation=True
            )
            for doc_ids in encodings["input_ids"]:
                pieces.append(np.asarray(doc_ids, dtype=np.int64))
                pieces.append(eos_piece)

        if pieces:
            token_ids = np.concatenate(pieces)
        else:
            # Empty batch: produce zero rows rather than failing in
            # `np.concatenate`.
            token_ids = np.zeros(0, dtype=np.int64)

        return _chunk_token_ids(token_ids, max_len, shift)

    if token_ids_only:
        # Only ask to remove columns that actually exist when the schema is
        # known; fall back to the full list when it is not (e.g. a streaming
        # dataset without feature information).
        known_columns = getattr(raw_dataset, "column_names", None)
        if known_columns is None:
            remove_columns = col_names
        else:
            remove_columns = [c for c in col_names if c in known_columns]
    else:
        remove_columns = []

    map_kwargs = {
        "batched": True,
        "remove_columns": remove_columns,
    }
    if not streaming:
        # Streaming dataset does not support multi-processing yet, so these
        # options are only passed for the in-memory / on-disk dataset.
        map_kwargs["num_proc"] = n_workers
        map_kwargs["load_from_cache_file"] = not overwrite_cache
        map_kwargs["desc"] = "Running tokenizer on dataset"

    tokenized_dataset = raw_dataset.map(process_fn, **map_kwargs)
    return tokenized_dataset