# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.

"""Utilities for Retro preprocessing."""

import glob
import logging
import os
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict

import numpy as np
import torch
from torch.distributed import ProcessGroup

from megatron.core import parallel_state
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.query.multi_split_gpt_dataset import (
    MultiSplitGPTDataset,
    MultiSplitGPTDatasetConfig,
)
from megatron.core.utils import log_single_rank

logger = logging.getLogger(__name__)

try:
    from tqdm import tqdm

    HAVE_TQDM = True
except ImportError:
    HAVE_TQDM = False

try:
    import h5py

    HAVE_H5PY = True
except ImportError:
    HAVE_H5PY = False


class Block(TypedDict):
    """Specific block arg type to mute mypy."""

    range: Tuple[int, int]
    path: str


def log_retro_rank_0(message: str) -> None:
    """Log on rank 0.

    Args:
        message (str): Message to log.
    """
    log_single_rank(logger, logging.INFO, "[RETRO] " + message)


def retro_makedir(config: RetroPreprocessingConfig, path: str) -> None:
    """Make a directory, conditional on not being in validation mode.

    Args:
        config (RetroPreprocessingConfig): Retro preprocessing config.
        path (str): Path to directory.
    """
    if config.retro_task_validate is None:
        os.makedirs(path, exist_ok=True)


def extract_data_config(config: RetroPreprocessingConfig) -> MultiSplitGPTDatasetConfig:
    """Extract data config from dataset.

    Args:
        config (RetroPreprocessingConfig): Retro preprocessing config.

    Returns:
        The config object used to build the dataset.
    """
    return config.retro_gpt_chunk_datasets.train["dataset"].sample_dataset.config


def get_num_chunks_per_sample(sample_length: int, chunk_length: int) -> int:
    """Compute seq_length // chunk_length.

    Args:
        sample_length (int): Alias of `sequence_length`.
        chunk_length (int): Retro chunk length (e.g., 64).

    Returns:
        Number of chunks per sample (i.e., `sequence_length` / `chunk_length`).
    """
    assert sample_length % chunk_length == 0
    return sample_length // chunk_length


class GPTToTextDataset(torch.utils.data.Dataset):
    """Dataset to convert GPT tokens to text.

    Args:
        gpt_dataset (MultiSplitGPTDataset): GPT dataset, which outputs GPT token samples.
        gpt_tokenizer (Any): GPT tokenizer.
    """

    def __init__(self, gpt_dataset: MultiSplitGPTDataset, gpt_tokenizer: Any):
        super().__init__()

        self.gpt_dataset = gpt_dataset
        self.gpt_tokenizer = gpt_tokenizer

    def __len__(self) -> int:
        """Dataset length.

        Returns:
            Number of samples in the dataset.
        """
        return len(self.gpt_dataset)

    def __getitem__(self, idx: int) -> dict:
        """Get dataset sample.

        Args:
            idx (int): Index of sample.

        Returns:
            A dict containing attribute 'text' of type string.
        """
        gpt_token_ids = self.gpt_dataset[idx]["text"].tolist()
        text = self.gpt_tokenizer.detokenize(gpt_token_ids)
        return {"text": text}


def get_blocks(
    dirname: str, n_samples: int, block_size: int, validate: Optional[Callable] = None
) -> SimpleNamespace:
    """Divide range [0, num_samples) to sequence of block ranges.

    This is a core method within the concept of block processing. The idea
    is to divide a range (size n_samples) into a sequence of blocks. Each
    block corresponds to a file within 'dirname' with name
    '{start_idx}-{end_idx}.hdf5'. This method checks for the existence of
    these files, and returns two lists, one for existing blocks and one for
    missing blocks.

    Args:
        dirname (str): Path to directory containing block files.
        n_samples (int): Ideal number of samples.
            The total number of saved block data is <=n_samples.
        block_size (int): Max number of samples per block file (e.g., 100000).
        validate (Callable): Method for validating each block file during load.

    Returns:
        A namespace consisting of 2 lists: existing blocks, and missing blocks.
        The total number of samples between the existing and missing blocks should
        equal n_samples above.
    """

    if not HAVE_TQDM:
        raise ImportError("tqdm is required to use the RetroDataset. Please install tqdm.")

    if not HAVE_H5PY:
        raise ImportError("h5py is required to use the RetroDataset. Please install h5py.")

    assert os.path.isdir(dirname), "missing directory '%s.'" % dirname

    # Block ranges.
    block_start_idxs = list(range(0, n_samples, block_size))
    block_end_idxs = [min(n_samples, i + block_size) for i in block_start_idxs]
    block_ranges = list(zip(block_start_idxs, block_end_idxs))

    # All block files (existing + missing).
    n_digits = int(np.ceil(np.log(n_samples) / np.log(10)) + 1)

    all_blocks: List[Block] = [
        {
            "range": r,
            "path": os.path.join(
                dirname, "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r])
            ),
        }
        for r in block_ranges
    ]
    all_block_path_set = set(block["path"] for block in all_blocks)

    # Validate function.
    validate = (lambda f: None) if validate is None else validate

    # Delete corrupt files.
    if torch.distributed.get_rank() == 0:
        existing_block_paths = [
            block["path"] for block in all_blocks if os.path.exists(block["path"])
        ]
        for index, path in enumerate(tqdm(existing_block_paths, "validating block.")):
            assert path in all_block_path_set, "unexpected filename, '%s'." % path

            try:
                f = h5py.File(path, "r")
            except Exception:
                os.remove(path)
                continue

            try:
                validate(f)
            except Exception:
                os.remove(path)
            finally:
                f.close()

    # Wait for files to be deleted.
    torch.distributed.barrier()

    # Collect blocks.
    blocks = SimpleNamespace(
        existing=[b for b in all_blocks if os.path.exists(b["path"])],
        missing=[b for b in all_blocks if not os.path.exists(b["path"])],
    )

    return blocks


def get_blocks_by_rank(
    dirname: str,
    n_samples: int,
    block_size: int,
    validate: Optional[Callable] = None,
    sample: Optional[float] = None,
    process_group: Optional[ProcessGroup] = None,
) -> SimpleNamespace:
    """Divide existing and missing blocks evenly across all ranks.

    See 'get_blocks()' above for description. The returned lists of existing and
    missing blocks are split evenly across ranks via interleaving. This way,
    each rank has a roughly equal number of blocks to process for a
    downstream operation.

    Args:
        dirname (str): Path to directory containing block files.
        n_samples (int): Ideal number of samples. The total number of saved block data
            is <=n_samples.
        block_size (int): Max number of samples per block file (e.g., 100000).
        validate (Callable): Method for validating each block file during load.
        sample (Optional[float]): If provided, sample a random subset of the blocks.
            Used for validating preprocessing correctness.
        process_group (Optional[ProcessGroup]): Process group for distributed operations.
            If None, uses data parallel group.

    Returns:
        A namespace consisting of 2 lists: existing blocks, and missing blocks.
        Each of these two lists is potentially a sub-sample of the total set of
        existing and missing blocks, depending on whether sampling is used.
        Additionally, the attributes n_existing_world and n_missing_world are the
        total number of existing and missing blocks, independent of samples.
        Therefore, (n_existing_world + n_missing_world) * block_size == n_samples.
    """

    if process_group is None:
        process_group = parallel_state.get_data_parallel_group()

    # Get world blocks.
    blocks = get_blocks(dirname, n_samples, block_size, validate)

    # This rank's existing and missing files.
    rank_existing_blocks = blocks.existing[
        process_group.rank() : len(blocks.existing) : process_group.size()
    ]
    rank_missing_blocks = blocks.missing[
        process_group.rank() : len(blocks.missing) : process_group.size()
    ]

    # Extend rank's existing and missing blocks (with None) such that all ranks
    # have equal length lists. This allows for easier tracking of global progress.
    def get_world_max(n: int) -> int:
        """Get max value across ranks.

        Args:
            n (int): Value on this rank.

        Returns:
            Max value across all ranks.
        """
        n_tensor = torch.cuda.LongTensor([n])
        torch.distributed.all_reduce(n_tensor, op=torch.distributed.ReduceOp.MAX)
        return n_tensor.item()

    max_n_existing = get_world_max(len(rank_existing_blocks))
    max_n_missing = get_world_max(len(rank_missing_blocks))

    rank_existing_blocks += [None] * (max_n_existing - len(rank_existing_blocks))
    rank_missing_blocks += [None] * (max_n_missing - len(rank_missing_blocks))

    # Collect blocks.
    blocks = SimpleNamespace(
        n_existing_world=len(blocks.existing),
        n_missing_world=len(blocks.missing),
        existing=rank_existing_blocks,
        missing=rank_missing_blocks,
    )

    if sample is not None:
        # Sample existing and missing blocks evenly across all ranks. The
        # returned lists of blocks are randomly sampled (without replacement)
        # to yield `sample * len(blocks)` number of blocks.

        # Randomly sample blocks.
        def sample_blocks(_blocks: List[Optional[Dict]]) -> List[Optional[Dict]]:
            """Sample a random subset of all blocks.

            Args:
                _blocks (List[Optional[Dict]]): List of all blocks.

            Returns:
                A random subset of the blocks.
            """
            n_blocks_sample = int(np.ceil(sample * len(_blocks)))
            sampled_blocks: List[Optional[Dict]] = [b for b in _blocks if b is not None]

            np.random.seed(None)
            np.random.shuffle(sampled_blocks)

            sampled_blocks = sampled_blocks[:n_blocks_sample]
            sampled_blocks += [None] * (n_blocks_sample - len(sampled_blocks))

            return sampled_blocks

        blocks.existing = sample_blocks(blocks.existing)
        blocks.missing = sample_blocks(blocks.missing)

    return blocks


class BlockPathMap:
    """Map an index to its containing block path.

    The common use for this class is to have a directory of files containing
    blocks of processed data, of uniform block size (e.g., 100k samples per
    file). Each file must follow a naming convention of 'startIdx-endIdx.[ext]',
    where 'endIdx' minus 'startIdx' must equal the block size, with the possible
    exception of the final block. Given an input index, this class maps the
    index to the containing block file.

    Args:
        block_paths (List[str]): List of paths to saved block files.
        block_size (int): Max number of samples per block file (e.g., 100000).
    """

    @classmethod
    def from_dir(cls, dir: str, block_size: int, ext: str = "hdf5") -> Any:
        """Get list of block files, and create map.

        Args:
            dir (str): Path to directory containing saved block files.
            block_size (int): Max number of samples per block file (e.g., 100000).
            ext (str): Block file extension (e.g., 'hdf5').

        Returns:
            A mapping of sample index to block file path.
        """
        assert os.path.isdir(dir), f"directory not found, '{dir}'."
        return cls(sorted(glob.glob(dir + f"/*.{ext}")), block_size)

    def __init__(self, block_paths: List[str], block_size: int):
        self.max_idx = 0
        self.block_path_map = {}
        for block_path in block_paths:
            name = os.path.splitext(os.path.basename(block_path))[0]
            start_idx, end_idx = [int(i) for i in name.split("-")]
            self.block_path_map[start_idx] = block_path
            self.max_idx = max(self.max_idx, end_idx)
        self.block_size = block_size

    def __str__(self) -> str:
        """Stringify the mapping.

        Returns:
            A string representation of this block path map.
        """
        return "%d paths" % len(self.block_path_map)

    def __getitem__(self, idx: int) -> str:
        """Get block path from index.

        Args:
            idx (int): Index of sample.

        Returns:
            The path to the block file containing the sample index.
        """
        block_start_idx = self.block_size * (idx // self.block_size)
        block_path = self.block_path_map[block_start_idx]
        return block_path
