"""Train Tokenizer from Tokenized Streaming Text Datasets.

This module provides functionality to train a tokenizer on tokenized streaming text
datasets. It processes the specified splits of a dataset, decodes the tokenized
samples with a "decode" tokenizer, retrains a fast tokenizer on the resulting text
via ``train_new_from_iterator``, and saves the trained tokenizer to a local output
directory. Nothing in this module pushes the tokenizer to the Hugging Face Hub; that
step, if needed, has to be done separately.

"""

import json
import os
import time
from argparse import ArgumentParser, Namespace
from collections.abc import Sequence
from dataclasses import asdict
from logging import INFO
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
from uuid import uuid4

import torch
from flwr.common import log
from llmfoundry import StreamingTextDataset
from llmfoundry.data.text_data import build_streams
from llmfoundry.utils.builders import (
    build_tokenizer,
)
from omegaconf import DictConfig, OmegaConf
from streaming import Stream
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast

from repo.clients.llm_config_functions import StreamDict
from repo.dataset.samples_generators import stream_and_untokenize
from repo.dataset.utils import (
    TOKENIZER_FOLDER_NAME,
    build_dataloader,
)


def elaborate_stream_yamls(
    streams_config_file: str,
    dataset_config_file: str,
    split: str,
    max_length: int,
    split_temp_dir: TemporaryDirectory,
) -> tuple[str, dict[str, Any]]:
    """Elaborate stream YAML configurations for a specific dataset split.

    This function loads the dataset and streams configuration files, takes the
    ``client_streams`` mapping of the *first* entry in the streams configuration,
    prefixes every stream's ``local``/``remote`` path with the split-level roots and
    returns the *first* resulting stream as a ``(name, dict)`` pair.

    The local root is always ``split_temp_dir.name``; the remote root is the
    ``root_remote`` key of the split configuration (empty if absent).

    Parameters
    ----------
    streams_config_file : str
        Path to the streams configuration file.
    dataset_config_file : str
        Path to the dataset configuration file. It must contain a top-level key
        named after ``split``.
    split : str
        The dataset split to process (e.g., "train" or "val").
    max_length : int
        The maximum length of the sequences. It is stored on the split
        configuration as ``max_seq_len``.
    split_temp_dir : TemporaryDirectory
        A temporary directory used as the local root for the streams.

    Returns
    -------
    tuple[str, dict[str, Any]]
        A tuple containing the stream name and the stream dictionary.

    Raises
    ------
    TypeError
        If the dataset configuration file does not contain a mapping at top level.
    ValueError
        If the dataset configuration has no entry for ``split`` or if the selected
        streams configuration defines no streams.

    Example
    -------
    >>> from tempfile import TemporaryDirectory
    >>> streams_config_file = "path/to/streams_config.yaml"
    >>> dataset_config_file = "path/to/dataset_config.yaml"
    >>> split = "train"
    >>> max_length = 512
    >>> split_temp_dir = TemporaryDirectory()
    >>> stream_name, stream_dict = elaborate_stream_yamls(
    ...     streams_config_file, dataset_config_file, split, max_length, split_temp_dir
    ... )

    """
    # Load the dataset and streams configuration files
    streams_config = OmegaConf.load(streams_config_file)
    dataset_config = OmegaConf.load(dataset_config_file)
    # ``OmegaConf.load`` can return either a DictConfig or a ListConfig; an explicit
    # check (instead of ``assert``) keeps the validation active under ``python -O``
    if not isinstance(dataset_config, DictConfig):
        msg = (
            f"Dataset config {dataset_config_file!r} must be a mapping at top level, "
            f"got {type(dataset_config).__name__}."
        )
        raise TypeError(msg)
    split_config = dataset_config.pop(split, None)  # type: ignore[call-arg]
    # Fail with a clear message instead of an AttributeError on ``None`` below
    if split_config is None:
        msg = (
            f"Dataset config {dataset_config_file!r} has no configuration for "
            f"split {split!r}."
        )
        raise ValueError(msg)
    # Set some useful defaults on the split configuration.
    # NOTE(review): apart from ``root_local`` these values are not read again in this
    # function; they are kept in case other keys interpolate from them -- confirm.
    split_config.max_seq_len = max_length
    split_config.streams = streams_config
    split_config.root_local = split_temp_dir.name
    split_config.shuffle = False
    split_config.cache_limit = "10gb"
    # Retrieve the streams for the client
    log(INFO, "Retrieving streams for the client: %s", streams_config)
    # NOTE: Taking only the first item in the ListConfig
    selected_streams = streams_config[0].client_streams
    # Set streams dictionary for the train loader
    actual_streams = {
        key: StreamDict(**value) for key, value in selected_streams.items()
    }
    # Get the root path for remote and local data (with a trailing slash so they
    # can be used directly as prefixes)
    root_remote = split_config.pop("root_remote", "")
    root_remote = root_remote + "/" if root_remote else root_remote
    root_local = split_config.pop("root_local", "")
    root_local = root_local + "/" if root_local else root_local
    # Propagate the split and the remote and local paths to each stream
    for stream in actual_streams.values():
        # Set the split, remote, and local paths
        stream.split = split or stream.split
        if root_local:
            stream.local = root_local + stream.local if stream.local else root_local
        if root_remote:
            stream.remote = (
                root_remote + stream.remote if stream.remote else root_remote
            )
        # Remove potential trailing slashes
        stream.local = stream.local.rstrip("/") if stream.local else stream.local
        stream.remote = stream.remote.rstrip("/") if stream.remote else stream.remote
    # Convert the streams to dictionaries
    streams_dict = {
        stream_name: asdict(stream) for stream_name, stream in actual_streams.items()
    }
    # An empty mapping would otherwise surface as a bare ``StopIteration``
    if not streams_dict:
        msg = (
            f"Streams config {streams_config_file!r} defines no client streams in "
            "its first entry."
        )
        raise ValueError(msg)
    # Retrieve only the first element in the dictionary
    return next(iter(streams_dict.items()))


def create_streaming_text_dataloader(
    streams: Sequence[Stream],
    decode_tokenizer: PreTrainedTokenizerBase,
    max_length: int,
    num_workers: int,
    batch_size: int = 512,
) -> DataLoader:
    """Wrap a sequence of streams into a batched DataLoader.

    A ``StreamingTextDataset`` is built on top of ``streams`` and then handed to the
    project's ``build_dataloader`` helper, which produces the DataLoader. The same
    ``batch_size`` is given to both the dataset and the DataLoader.

    Parameters
    ----------
    streams : Sequence[Stream]
        The streams the StreamingTextDataset is built upon.
    decode_tokenizer : PreTrainedTokenizerBase
        The tokenizer passed to the StreamingTextDataset.
    max_length : int
        The maximum sequence length, forwarded as ``max_seq_len``.
    num_workers : int
        The number of worker processes, forwarded to ``build_dataloader``.
    batch_size : int, optional
        The batch size for the dataset and the DataLoader. Default is 512.

    Returns
    -------
    DataLoader
        The DataLoader returned by ``build_dataloader``.

    Example
    -------
    >>> from transformers import PreTrainedTokenizerFast
    >>> decode_tokenizer = PreTrainedTokenizerFast.from_pretrained("bert-base-uncased")
    >>> stream_dict = {"remote": "s3://mybucket/dataset", "local": "/path/to/local"}
    >>> streams = build_streams({"example_stream": stream_dict})
    >>> dataloader = create_streaming_text_dataloader(
    ...     streams, decode_tokenizer, max_length=512, num_workers=4
    ... )

    """
    # Dataset first: it needs to know the batch size as well
    dataset = StreamingTextDataset(
        streams=streams,
        tokenizer=decode_tokenizer,
        max_seq_len=max_length,
        batch_size=batch_size,
    )
    # Batched loading lets the caller benefit from whatever worker/prefetch setup
    # ``build_dataloader`` configures
    loader = build_dataloader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    return loader


def parse_args() -> Namespace:
    """Parse command-line arguments for the tokenizer training script.

    This function sets up an argument parser to receive various parameters for training
    a tokenizer on tokenized streaming text datasets. It parses the command-line
    arguments, post-processes a few of them and returns them as a Namespace object.

    Post-processing performed after parsing:

    - ``tokenizer_kwargs`` and ``decode_tokenizer_kwargs`` are decoded from JSON into
      dictionaries (empty if not given) and ``model_max_length`` is set to
      ``max_length`` in both. Invalid JSON, or JSON that is not an object, is
      reported through the parser's usual usage error (exit status 2).
    - ``splits`` and ``special_tokens`` are de-duplicated while keeping the order in
      which they were given on the command line, so repeated runs with the same
      arguments see the same ordering.

    Parameters
    ----------
    None

    Returns
    -------
    Namespace
        An object containing the parsed command-line arguments:
        - tokenizer (str): Path or name of the tokenizer to use.
        - tokenizer_kwargs (dict): Additional keyword arguments for the tokenizer.
        - decode_tokenizer (str): Path or name of the decode tokenizer to use.
        - decode_tokenizer_kwargs (dict): Additional keyword arguments for the decode
            tokenizer.
        - splits (list[str]): List of dataset splits to process (e.g., "train" or
            "val").
        - streams_config_file (str): Path to the streams configuration file.
        - dataset_config_file (str): Path to the dataset configuration file.
        - max_length (int): Maximum length of the sequences. Default is 2048.
        - special_tokens (list[str] | None): List of special tokens (e.g., "<unk>"
            or "</s>"), or None if not given.
        - output_folder_suffix (str): Suffix for the output folder name.
        - vocab_size (int): Size of the vocabulary.
        - batch_size (int): Batch size for the DataLoader. Default is 512.
        - truncate_num_batches (int): Number of batches to truncate to. Default is -1.
        - num_workers (int | None): Number of worker processes to use for data
            loading. Default is None.

    Example
    -------
    >>> args = parse_args()
    >>> print(args.tokenizer)
    >>> print(args.splits)
    >>> print(args.vocab_size)

    """
    parser = ArgumentParser(
        description=("Train Tokenizer from Tokenized Streaming Text Datasets."),
    )
    parser.add_argument("--tokenizer", type=str, required=True)
    parser.add_argument("--tokenizer_kwargs", type=str, required=False)
    parser.add_argument("--decode_tokenizer", type=str, required=True)
    parser.add_argument("--decode_tokenizer_kwargs", type=str, required=False)
    parser.add_argument(
        "--splits",
        nargs="+",
        # The training entry point only accepts these two split names
        help='E.g. "train" or "val"',
        required=True,
    )
    parser.add_argument("--streams_config_file", type=str, required=True)
    parser.add_argument("--dataset_config_file", type=str, required=True)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument(
        "--special_tokens",
        nargs="+",
        default=None,
        help='E.g. "<unk>" or "</s>"',
    )
    parser.add_argument("--output_folder_suffix", type=str, required=True)
    parser.add_argument("--vocab_size", type=int, required=True)
    parser.add_argument("--batch_size", type=int, required=False, default=512)
    parser.add_argument("--truncate_num_batches", type=int, default=-1, required=False)
    parser.add_argument("--num_workers", type=int, required=False, default=None)
    parsed = parser.parse_args()

    # Decode both JSON kwargs options the same way and pin the model max length
    for option in ("tokenizer_kwargs", "decode_tokenizer_kwargs"):
        raw_value = getattr(parsed, option)
        if raw_value is None:
            kwargs = {}
        else:
            try:
                kwargs = json.loads(raw_value)
            except json.JSONDecodeError as err:
                parser.error(f"--{option} is not valid JSON: {err}")
            # Anything but a JSON object cannot be used as keyword arguments
            if not isinstance(kwargs, dict):
                parser.error(f"--{option} must be a JSON object, got: {raw_value!r}")
        kwargs["model_max_length"] = parsed.max_length
        setattr(parsed, option, kwargs)

    # De-duplicate while preserving the command-line order. A plain ``set`` has a
    # hash-dependent iteration order, which would make the order of the special
    # tokens (and of the processed splits) vary between runs.
    parsed.splits = list(dict.fromkeys(parsed.splits))
    if parsed.special_tokens is not None:
        parsed.special_tokens = list(dict.fromkeys(parsed.special_tokens))
    log(INFO, "Arguments parsed: %s", parsed)
    return parsed


def main(args: Namespace) -> None:
    """Train a tokenizer on decoded streaming datasets and save it locally.

    For every requested split this function builds a DataLoader over the first
    configured stream, retrains the tokenizer given by ``args.tokenizer`` on the text
    produced by ``stream_and_untokenize`` and saves the result with
    ``save_pretrained``. Nothing is uploaded by this function.

    The output directory is ``<PROJECT_PATH>/<TOKENIZER_FOLDER_NAME>_<suffix>``, where
    ``PROJECT_PATH`` is read from the environment. If ``PROJECT_PATH`` is unset or
    empty, the folder is created relative to the current working directory. The
    output directory does not depend on the split, so when several splits are given
    each one saves into the same directory.

    Side effects on the process environment: ``RUN_UUID`` is set to a fresh UUID and
    ``TMPDIR`` points to a temporary directory for the duration of the call; the
    previous ``TMPDIR`` value is restored (or the variable removed) on exit. All
    temporary directories are removed even if training fails.

    Parameters
    ----------
    args : Namespace
        The arguments for the function, expected to have the following attributes:
        - splits (Iterable[str]): Dataset splits to process; each must be "train" or
            "val".
        - tokenizer (str): Path or name of the tokenizer to use for training.
        - tokenizer_kwargs (dict): Additional keyword arguments for the tokenizer.
        - decode_tokenizer (str): Path or name of the decode tokenizer to use for
            decoding the dataset.
        - decode_tokenizer_kwargs (dict): Additional keyword arguments for the decode
            tokenizer.
        - streams_config_file (str): Path to the streams configuration file.
        - dataset_config_file (str): Path to the dataset configuration file.
        - max_length (int): Maximum length of the sequences.
        - vocab_size (int): Size of the vocabulary.
        - special_tokens (list[str] | None): List of special tokens (e.g., "<unk>"
            or "</s>").
        - truncate_num_batches (int): Number of batches to truncate to.
        - num_workers (int): Number of worker processes to use for data loading.
        - batch_size (int): Batch size for the DataLoader.
        - output_folder_suffix (str): Suffix for the output folder name.

    Raises
    ------
    ValueError
        If any requested split is not "train" or "val". The check runs before any
        work is started.

    Example
    -------
    >>> from argparse import Namespace
    >>> args = Namespace(
    ...     splits=["train", "val"],
    ...     tokenizer="path/to/tokenizer",
    ...     tokenizer_kwargs={},
    ...     decode_tokenizer="path/to/decode_tokenizer",
    ...     decode_tokenizer_kwargs={},
    ...     streams_config_file="path/to/streams_config.yaml",
    ...     dataset_config_file="path/to/dataset_config.yaml",
    ...     max_length=512,
    ...     vocab_size=32000,
    ...     special_tokens=["<unk>", "</s>"],
    ...     truncate_num_batches=1000,
    ...     num_workers=4,
    ...     batch_size=512,
    ...     output_folder_suffix="suffix",
    ... )
    >>> main(args)

    """
    # Validate all splits up front with a real exception: ``assert`` is stripped
    # under ``python -O`` and would only fire after earlier splits were trained
    supported_splits = ("train", "val")
    unsupported = [split for split in args.splits if split not in supported_splits]
    if unsupported:
        msg = f"Unsupported split(s) {unsupported}; expected one of {supported_splits}"
        raise ValueError(msg)
    # Set the sharing strategy for multiprocessing
    torch.multiprocessing.set_sharing_strategy("file_system")
    # Remember the caller's TMPDIR so it does not keep pointing at a deleted folder
    previous_tmpdir = os.environ.get("TMPDIR")
    # The context manager guarantees cleanup even when training raises
    with TemporaryDirectory() as general_temp_name:
        # Set the environment variables for the temporary directory and the run id
        os.environ["TMPDIR"] = general_temp_name
        os.environ["RUN_UUID"] = str(uuid4())
        try:
            for split in args.splits:
                _train_and_save_for_split(args, split)
        finally:
            if previous_tmpdir is None:
                os.environ.pop("TMPDIR", None)
            else:
                os.environ["TMPDIR"] = previous_tmpdir


def _train_and_save_for_split(args: Namespace, split: str) -> None:
    """Train the tokenizer on one split and save it to the output directory."""
    start_time = time.time()
    split_temp_dir = TemporaryDirectory()
    # Entering the object (rather than ``as name``) keeps the TemporaryDirectory
    # instance available for ``elaborate_stream_yamls`` while still cleaning up
    with split_temp_dir:
        # Build the tokenizer to use for decoding the dataset
        decode_tokenizer = build_tokenizer(
            args.decode_tokenizer,
            args.decode_tokenizer_kwargs,
        )
        # Create the DataLoader for StreamingTextDataset for training the tokenizer
        stream_name, stream_dict = elaborate_stream_yamls(
            args.streams_config_file,
            args.dataset_config_file,
            split,
            args.max_length,
            split_temp_dir,
        )
        loader = create_streaming_text_dataloader(
            streams=build_streams({stream_name: stream_dict}),
            decode_tokenizer=decode_tokenizer,
            max_length=args.max_length,
            num_workers=args.num_workers,
            batch_size=args.batch_size,
        )

        # Create a tokenizer object to be trained
        tokenizer = build_tokenizer(args.tokenizer, args.tokenizer_kwargs)
        assert isinstance(
            tokenizer,
            PreTrainedTokenizerFast,
        ), "Tokenizer must be a PreTrainedTokenizerFast for train it anew."
        log(INFO, "Training tokenizer %s", tokenizer)
        # Train the tokenizer
        tokenizer = tokenizer.train_new_from_iterator(
            text_iterator=(
                stream_and_untokenize(
                    loader,
                    decode_tokenizer,
                    truncate_num_batches=args.truncate_num_batches,
                )
            ),
            vocab_size=args.vocab_size,
            show_progress=True,
            new_special_tokens=args.special_tokens,
        )
        log(INFO, "Tokenizer has trained")

        # Save the tokenizer files under the project path. Without PROJECT_PATH the
        # original f-string would resolve to the filesystem root ("/<folder>"), so
        # fall back to a path relative to the working directory instead.
        project_path = os.environ.get("PROJECT_PATH", "")
        folder_name = f"{TOKENIZER_FOLDER_NAME}_{args.output_folder_suffix}"
        output_dir = (
            Path(f"{project_path}/{folder_name}") if project_path else Path(folder_name)
        )
        output_dir.mkdir(parents=True, exist_ok=True)
        tokenizer.save_pretrained(output_dir)
        log(INFO, "Tokenizer saved to %s", output_dir)
        log(INFO, "Time elapsed: %s", time.time() - start_time)


if __name__ == "__main__":
    # Script entry point: read the CLI options first, then run the training
    cli_args = parse_args()
    main(cli_args)
