# Build based on the original code from Lightning AI
# litgpt/packed_dataset.py

# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py


import random
import hashlib
from torch.utils.data import IterableDataset, get_worker_info
from pathlib import Path

# We will build v0 assuming that the dataset is already saved to disk
# in standard hf format. This leaves room for preproc ops as separate logic.
# The basic assumption is that only the "text" field is used.
from datasets import load_from_disk, DatasetDict, concatenate_datasets, Dataset, load_dataset
from datasets import IterableDataset as HFIterableDataset

import logging
import glob

import torch
import pyarrow.parquet as pq


# Module-level logger used by the dataset classes below for shard summaries.
logger = logging.getLogger(__name__)
# NOTE(review): this configures the root logger as a side effect of importing this
# module (it is a no-op when the root logger already has handlers) — confirm that
# library code is meant to do this rather than the application entry point.
logging.basicConfig(level=logging.INFO)


class ParquetStreamPure(IterableDataset):
    """datasets-free version of (mostly) the same thing - shuffle not across files though
    a bit ironic to keep it in this file"""

    def __init__(
        self,
        dataset_folder_path="",
        seed=12345,
        shuffle=True,
        num_processes=1,
        process_rank=0,
        torch_device=None,
        prefix="",
        verbose=False,
        shuffle_filenames=True,
        data_signature: dict[str, list[str] | str] = {"keys": ["input_ids"]},
        repetitions=None,
        return_data_id=False,
        data_id=None,
        broadcast_glob=True,
        doc_wise_pqdsp=False,
        doc_wise_pqdsp_skip_tail=True,
        doc_wise_pqdsp_sep_tok=None,
    ):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.torch_device = torch_device

        self.doc_wise_pqdsp = doc_wise_pqdsp
        self.doc_wise_pqdsp_skip_tail = doc_wise_pqdsp_skip_tail
        self.doc_wise_pqdsp_sep_tok = doc_wise_pqdsp_sep_tok
        if self.doc_wise_pqdsp:
            assert self.doc_wise_pqdsp_sep_tok is not None, "Need a separator id for doc-wise pqdsp"

        # Get file list, with distributed broadcast if needed
        if broadcast_glob and torch.distributed.is_initialized():
            if process_rank == 0:
                filenames = sorted(str(p) for p in Path(dataset_folder_path).glob(f"{prefix}*.parquet"))
            else:
                filenames: list[str] = None  # type: ignore # believe
            obj = [filenames]
            torch.distributed.broadcast_object_list(obj, 0)
            parquet_files = obj[0]
        else:
            parquet_files = sorted(str(p) for p in Path(dataset_folder_path).glob(f"{prefix}*.parquet"))
        if shuffle_filenames:
            random.Random(seed).shuffle(parquet_files)

        # Shard files for distributed training
        self.parquet_files = parquet_files[process_rank::num_processes] if num_processes > 1 else parquet_files
        self._ds_fingerprint = hashlib.shake_128(str(self.parquet_files).encode()).hexdigest(4)

        self.verbose = verbose
        if self.verbose:
            logger.info(
                f"Rank {process_rank}/{num_processes} has {len(self.parquet_files)} parquet files | identifier={self._ds_fingerprint}"
            )
            examples = pq.read_table(self.parquet_files[0], columns=["input_ids"]).slice(0, 3).to_pylist()  # Get 3 rows
            for i, example in enumerate(examples):
                logger.info(f"Example {i}: {example['input_ids'][:12]}")  # First 12 tokens of each row
        self.shuffle = shuffle
        self.seed = seed
        # Initialize default state
        self._state = {
            "rng": random.Random(seed),
            "rng_state": (-1, [-1], None),
            "buffer": [],
            "file_idx": 0,
            "row_group_idx": 0,
            "fingerprint": self._ds_fingerprint,
        }

    def __iter__(self):
        """Yield the remaining rows of this rank's files as ``torch.long`` tensors.

        Progress lives in ``self._state`` (file index, row-group index, buffered rows), so
        iteration continues from wherever the state currently points. One row group is
        buffered at a time and rows are popped from the end of the buffer.
        """
        # Keep the current file open across row groups instead of re-opening it
        # (and re-parsing its footer) for every single row group.
        open_file_idx = None
        pf = None
        while self._state["file_idx"] < len(self.parquet_files):
            if not self._state["buffer"]:
                # Refill buffer from current position
                if open_file_idx != self._state["file_idx"]:
                    open_file_idx = self._state["file_idx"]
                    pf = pq.ParquetFile(self.parquet_files[open_file_idx])
                if self._state["row_group_idx"] >= pf.num_row_groups:
                    # Current file is exhausted: move on to the first row group of the next one.
                    self._state["file_idx"] += 1
                    self._state["row_group_idx"] = 0
                    continue

                if self.verbose:
                    # proactively warn if loading the last row group of the last file
                    if (self._state["file_idx"] == len(self.parquet_files) - 1) and (
                        self._state["row_group_idx"] == pf.num_row_groups - 1
                    ):
                        print("Loading the last row group of the last file!", flush=True)

                    print(
                        f"Reading new row group | file_idx:{self._state['file_idx']}, row_group_idx:{self._state['row_group_idx']}",
                        flush=True,
                    )

                self._read_buffer(pf)
                # From here on the index names the *next* group to read; state_dict relies on this.
                self._state["row_group_idx"] += 1

            while self._state["buffer"]:
                yield torch.as_tensor(self._state["buffer"].pop(), dtype=torch.long)

    def _read_buffer(self, parquet_file):
        batch = parquet_file.read_row_group(self._state["row_group_idx"])
        self._state["buffer"] = batch.column("input_ids").to_pylist()

        if self.doc_wise_pqdsp:
            doc_wise_buffer = []
            for row in self._state["buffer"]:
                curr_sep_idx = 0
                while curr_sep_idx < len(row):
                    try:
                        next_sep_idx = row.index(self.doc_wise_pqdsp_sep_tok, curr_sep_idx + 1)
                    except ValueError:
                        # if we can't find another sep, were either the tail,
                        # or if curr is 0, we're the head==only doc, and so we dont skip that
                        if self.doc_wise_pqdsp_skip_tail and curr_sep_idx > 0:
                            break
                        next_sep_idx = len(row)
                    doc_wise_buffer.append(row[curr_sep_idx:next_sep_idx])
                    curr_sep_idx = next_sep_idx

            self._state["buffer"] = doc_wise_buffer

        if self.shuffle:
            self._state["rng_state"] = self._state["rng"].getstate()  # the last used state for a shuffle op
            self._state["rng"].shuffle(self._state["buffer"])

    def state_dict(self):

        if self.verbose:
            print(f"({self.process_rank}/{self.num_processes}) BEGIN pqds-pure state_dict function.", flush=True)

        if self.shuffle:
            # rng state has three parts , one int, one list, and another int or None
            rng_0, rng_1, rng_2 = self._state["rng_state"]
            rank_rng_0 = torch.tensor([rng_0], device=self.torch_device)
            rank_rng_1 = torch.tensor(rng_1, device=self.torch_device)
            rank_rng_2 = torch.tensor([rng_2] if rng_2 is not None else [-1], device=self.torch_device)
            # make world size containers for each part
            all_rank_rng_0 = [torch.zeros_like(rank_rng_0) for _ in range(self.num_processes)]
            all_rank_rng_1 = [torch.zeros_like(rank_rng_1) for _ in range(self.num_processes)]
            all_rank_rng_2 = [torch.zeros_like(rank_rng_2) for _ in range(self.num_processes)]
            # gather the parts
            torch.distributed.all_gather(all_rank_rng_0, rank_rng_0)
            if self.verbose:
                print(
                    f"({self.process_rank}/{self.num_processes}) state_dict function: passed rng0 gather",
                    flush=True,
                )
            torch.distributed.all_gather(all_rank_rng_1, rank_rng_1)
            if self.verbose:
                print(
                    f"({self.process_rank}/{self.num_processes}) state_dict function: passed rng1 gather",
                    flush=True,
                )
            torch.distributed.all_gather(all_rank_rng_2, rank_rng_2)
            if self.verbose:
                print(
                    f"({self.process_rank}/{self.num_processes}) state_dict function: passed rng2 gather",
                    flush=True,
                )
            # pack them up
            all_rank_rng_states = (all_rank_rng_0, all_rank_rng_1, all_rank_rng_2)
        else:
            all_rank_rng_states = None

        if self.verbose:
            print(f"({self.process_rank}/{self.num_processes}) state_dict function: passed rng gathers", flush=True)

        # we also need to save independent row_indices for each worker
        row_idx = torch.tensor([len(self._state["buffer"])], device=self.torch_device)
        all_row_indices = [torch.zeros_like(row_idx) for _ in range(self.num_processes)]
        torch.distributed.all_gather(all_row_indices, row_idx)
        all_row_indices = [int(ari.item()) for ari in all_row_indices]

        if self.verbose:
            print(f"({self.process_rank}/{self.num_processes}) state_dict function: passed idx gathers", flush=True)

        # we also need to save the file_idx
        file_idx = torch.tensor([self._state["file_idx"]], device=self.torch_device)

        # for row_group_idx we need to
        # sub 1 to reload the currently buffer'd row group on resume (rather than the next)
        # but only do it on a copy of the value
        row_group_idx = torch.tensor([self._state["row_group_idx"]], device=self.torch_device)
        if len(self._state["buffer"]) > 0:
            row_group_idx -= 1

        all_file_indices = [torch.zeros_like(file_idx) for _ in range(self.num_processes)]
        all_row_group_indices = [torch.zeros_like(row_group_idx) for _ in range(self.num_processes)]
        torch.distributed.all_gather(all_file_indices, file_idx)
        torch.distributed.all_gather(all_row_group_indices, row_group_idx)
        all_file_indices = [int(afi.item()) for afi in all_file_indices]
        all_row_group_indices = [int(argi.item()) for argi in all_row_group_indices]
        # sub 1 to reload the currently buffer'd row group on resume (rather than the next)
        # all_row_group_indices = [argi - 1 for argi in all_row_group_indices]
        if self.verbose:
            print(
                f"({self.process_rank}/{self.num_processes}) state_dict function: passed file and row group gathers",
                flush=True,
            )

        # and finally the fingerprint
        fingerprint_tensor = torch.tensor([int(self._ds_fingerprint, 16)], device=self.torch_device)
        all_fingerprints = [torch.zeros_like(fingerprint_tensor) for _ in range(self.num_processes)]
        torch.distributed.all_gather(all_fingerprints, fingerprint_tensor)
        all_fingerprints = [hex(int(af.item()))[2:] for af in all_fingerprints]
        if self.verbose:
            print(
                f"({self.process_rank}/{self.num_processes}) state_dict function: passed fingerprint gathers",
                flush=True,
            )

        if self.verbose:
            print(f"({self.process_rank}/{self.num_processes}) END pqds-pure state_dict function.", flush=True)
        return {
            "file_idx": all_file_indices,
            "row_group_idx": all_row_group_indices,
            "row_idx": all_row_indices,
            "rng_state": all_rank_rng_states,
            "fingerprint": all_fingerprints,
        }

    def load_state_dict(self, state_dict):
        """Restore this rank's position from a dict produced by ``state_dict``.

        The entry belonging to this rank is selected from each per-rank list, the RNG is
        put back into the state it had before the last shuffle, and the row group that was
        buffered at save time is re-read, re-shuffled and trimmed to the rows that had not
        been consumed yet.

        Args:
            state_dict: Dict with the keys ``fingerprint``, ``file_idx``, ``row_group_idx``,
                ``row_idx`` and ``rng_state``.

        Raises:
            ValueError: If the fingerprint does not match this rank's file list, or if
                the stored ``file_idx`` is out of range.
            AssertionError: If an RNG state is provided while shuffling is disabled.
        """

        def debug(message):
            if self.verbose:
                print(message, flush=True)

        debug("BEGIN pqds-pure load_state_dict function.")

        # A lone process without a process group owns the only entry of each per-rank list.
        if self.num_processes == 1 and not (
            torch.distributed.is_available() and torch.distributed.is_initialized()
        ):
            rank = 0
        else:
            rank = torch.distributed.get_rank()

        # Unpack fingerprint
        fingerprint = state_dict.get("fingerprint")[rank]
        if int(fingerprint, 16) != int(self._ds_fingerprint, 16):
            raise ValueError(
                f"Dataset fingerprint mismatch. Expected {self._ds_fingerprint}, "
                f"got {state_dict.get('fingerprint')}. This may indicate attempting to "
                "load a state from a different dataset."
            )
        debug(f"Dataset fingerprint match: {fingerprint}=={self._ds_fingerprint}")
        # Unpack the file_idx and row_group_idx states
        self._state["file_idx"] = state_dict["file_idx"][rank]
        self._state["row_group_idx"] = state_dict["row_group_idx"][rank]
        if self._state["row_group_idx"] == -1:  # this is a special case for the first row group
            self._state["row_group_idx"] = 0

        debug(f"file_idx: {self._state['file_idx']}, row_group_idx: {self._state['row_group_idx']}")

        # Validate file_idx bounds
        if not (0 <= self._state["file_idx"] < len(self.parquet_files)):
            raise ValueError(
                f"Invalid file_idx {self._state['file_idx']}. " f"Must be between 0 and {len(self.parquet_files)-1}"
            )

        debug("file_idx bounds check passed")

        # Load the current file
        pf = pq.ParquetFile(self.parquet_files[self._state["file_idx"]])

        debug(f"loaded parquet file: {self.parquet_files[self._state['file_idx']]}")

        # Reload RNG before trying to fill the buffer as we may shuffle within the read fn
        rng_restored = False
        if state_dict["rng_state"] is not None:
            assert self.shuffle, "RNG state provided but shuffle is disabled, potential resume mismatch."
            # now we unpack the world of rng states
            all_rank_rng_0, all_rank_rng_1, all_rank_rng_2 = state_dict["rng_state"]
            version = all_rank_rng_0[rank].item()
            if version == -1:
                # Saved before this rank shuffled anything: there is no real RNG state to
                # restore (setstate would reject the placeholder), so start from the seed.
                self._state["rng_state"] = (-1, [-1], None)
                self._state["rng"] = random.Random(self.seed)
            else:
                extra = all_rank_rng_2[rank].item()
                self._state["rng_state"] = (
                    version,
                    tuple(all_rank_rng_1[rank].tolist()),
                    extra if extra != -1 else None,
                )
                self._state["rng"] = random.Random()
                self._state["rng"].setstate(self._state["rng_state"])
                rng_restored = True

        debug("loaded rng state")

        # Then get the row_idx
        row_idx = state_dict["row_idx"][rank]

        if row_idx > 0:
            # Rows of this group were still pending: reload (and re-shuffle) the group.
            self._read_buffer(pf)
            # Mirror __iter__, which advances the index right after buffering a group.
            # Without this the same group is read again, in full, once the trimmed
            # buffer drains, and its samples are served twice.
            self._state["row_group_idx"] += 1
        elif self.shuffle and rng_restored and self._state["row_group_idx"] > 0:
            # The buffer was fully drained at save time, so the stored index already names
            # the next group. Replay the shuffle of the last consumed group only to move
            # the RNG to where the original run was; its rows are dropped by the trim below.
            resume_at = self._state["row_group_idx"]
            self._state["row_group_idx"] = resume_at - 1
            self._read_buffer(pf)
            self._state["row_group_idx"] = resume_at

        debug(f"reloaded buffer: len={len(self._state['buffer'])}")
        debug(f"reloaded row_idx: {row_idx}")

        # to trim the buffer down to point we previously consumed through
        self._state["buffer"] = self._state["buffer"][:row_idx]

        debug(f"trimmed buffer: len={len(self._state['buffer'])}")
        debug("END pqds-pure load_state_dict function.")


class HuggingfaceDataset(IterableDataset):
    def __init__(
        self,
        ds_name_or_path=None,
        seed=12345,
        shuffle=False,
        num_processes=1,
        process_rank=0,
        data_id=None,
        data_signature: dict[str, list[str] | str] = {"keys": ["text"], "format_fn": "pass_text"},
        repetitions=None,
        return_data_id=False,
    ):
        assert ds_name_or_path is not None
        self._ds_name_or_path = ds_name_or_path
        self._seed = seed
        assert not shuffle, "Shuffle not implemented for hfds."
        self._num_processes = num_processes
        self._process_rank = process_rank
        self._data_id = data_id  # This is human readble, the mixture unit
        self._return_data_id = return_data_id
        self._ds_fingerprint = (
            None  # This is not human readable, corresp to the subset of work _this_ process is handling.
        )
        self._data_signature = data_signature
        self._ds_total_length = None
        self._ds_length = None
        self._subds = None
        self._ds_min = None
        self._ds_max = None

        # Here is where we load the dataset from disk (whole thing, but just the memmap ofc)
        if repetitions is not None:
            ds_list = [load_from_disk(ds_name_or_path) for _ in range(repetitions)]
            self._ds: Dataset = concatenate_datasets(ds_list)  # type: ignore
        else:
            self._ds: Dataset = load_from_disk(ds_name_or_path)  # type: ignore

        assert not isinstance(
            self._ds, DatasetDict
        ), "Dataset path should point to a single split, try adding /train ?."

        self._ds_total_length = len(self._ds)

    def __iter__(self):  # type: ignore
        """Build the row shard for this (rank, DataLoader worker) pair and return its iterator.

        Rows are dealt round-robin: shard ``shard_id`` takes rows ``shard_id``,
        ``shard_id + num_shards``, ... of the full dataset, where ``num_shards`` is the
        number of ranks times the number of DataLoader workers.

        Returns:
            A ``HuggingfaceDatasetIterator`` over the selected rows.
        """
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        num_shards = num_workers * self._num_processes
        shard_id = self._process_rank * num_workers + worker_id

        # This is where we shard the dataset into work for each dataparallel rank.
        # Our unit of work is now a "row" of the dataset though, not a file.

        self._worker_id = worker_id

        # max_num_rows = (len(self._ds) // num_shards) * num_shards
        max_num_rows = len(self._ds)
        index_list = list(range(shard_id, max_num_rows, num_shards))

        if not index_list:
            # More shards than rows: this shard is empty.
            self._ds_fingerprint = None
            self._ds_min = 0
            self._ds_max = 0
        else:
            self._ds_fingerprint = hashlib.shake_128(str(index_list).encode()).hexdigest(4)
            self._ds_min = min(index_list)
            self._ds_max = max(index_list)

        subds = self._ds.select(index_list)
        self._subds = subds

        self._ds_length = len(self._subds)

        logger.info(
            f"Rank {self._process_rank}/{self._num_processes}, worker {worker_id} has "
            f"{self._ds_length}/{self._ds_total_length} rows | identifier={self._data_id}:{self._ds_fingerprint} "
            f"| range={self._ds_min}:{self._ds_max} | head={index_list[:3]} | tail={index_list[-3:]}"
        )

        return HuggingfaceDatasetIterator(
            ds=subds,
            data_signature=self._data_signature,
            data_id=self._data_id,
            return_data_id=self._return_data_id,
            fingerprint=self._ds_fingerprint,
            worker_id=worker_id,
            process_rank=self._process_rank,
            num_processes=self._num_processes,
        )

    def __len__(self):
        return self._ds_length


class HuggingfaceDatasetIterator:
    def __init__(
        self,
        ds,
        data_signature: dict[str, list[str] | str],
        data_id=None,
        return_data_id=None,
        fingerprint=None,
        worker_id=None,
        process_rank=None,
        num_processes=None,
    ):
        self._ds = ds
        self._data_signature = data_signature
        self._data_id = data_id
        self._return_data_id = return_data_id
        self._ds_fingerprint = fingerprint
        self._worker_id = worker_id
        self._process_rank = process_rank
        self._num_processes = num_processes

        self._ds_iter = None

    def __len__(self):
        return len(self._ds)

    def __next__(self):
        if self._ds_iter is None:
            self._ds_iter = iter(self._ds)

        row = next(self._ds_iter)

        # the data signature tells us what keys to extract from the row
        row = {k: row[k] for k in self._data_signature["keys"]}
        # then we attach the data_signature to the sample to support
        # heterogeneously sourced batches in the collate_fn
        row["data_signature"] = self._data_signature

        if self._return_data_id:
            row["data_id"] = self._data_id

        return row