# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
from dataclasses import dataclass

from functools import partial
from typing import Any, Callable, Optional

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from huggingface_hub import HfApi
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.config import JobConfig
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str, split: str):
    """Open the English ("en") configuration of C4 as a streaming dataset.

    Args:
        dataset_path: Path or hub repo id handed to ``load_dataset``.
        split: Name of the split to stream.

    Returns:
        Whatever ``load_dataset`` returns for a streaming request.
    """
    load_kwargs = {"name": "en", "split": split, "streaming": True}
    return load_dataset(dataset_path, **load_kwargs)


def _process_c4_text(sample: dict[str, Any]) -> str:
    """Return the value stored under the ``"text"`` key of a C4 sample."""
    text = sample["text"]
    return text


def _process_fineweb_text(sample: dict[str, Any]) -> str:
    """Return the value stored under the ``"text"`` key of a fineweb sample."""
    text = sample["text"]
    return text


def _load_dclm_dataset(dataset_path: str, split: str = "train"):
    """Load the DCLM baseline dataset in streaming mode.

    Args:
        dataset_path: Accepted so the function matches the loader call
            convention (loaders are invoked with the resolved path). It is
            not used: the hub repo id is fixed to
            ``mlfoundations/dclm-baseline-1.0``.
        split: Split to stream. Defaults to ``"train"``, which is what the
            function always loaded before this parameter existed. Accepting
            it lets callers bind it (e.g. ``partial(..., split="train")``)
            without a ``TypeError`` at call time.

    Returns:
        Whatever ``load_dataset`` returns for a streaming request.
    """
    return load_dataset(
        "mlfoundations/dclm-baseline-1.0",
        name="default",
        split=split,
        streaming=True,
    )

def _process_dclm_text(sample: dict[str, Any]) -> str:
    """Return the value stored under the ``"text"`` key of a DCLM sample."""
    text = sample["text"]
    return text

def _load_olmo_dataset(dataset_path: str):
    """Stream the "train" split of the ``allenai/olmo-mix-1124`` hub dataset.

    ``dataset_path`` is accepted to match the loader call convention but is
    not used; the hub repo id and the ``"default"`` configuration are fixed.
    """
    repo_id = "allenai/olmo-mix-1124"
    return load_dataset(repo_id, name="default", split="train", streaming=True)

def _process_olmo_text(sample: dict[str, Any]) -> str:
    """Return the value stored under the ``"text"`` key of an OLMo-mix sample."""
    text = sample["text"]
    return text


def _load_fineweb_dataset(dataset_path: str):
    """Stream the "train" split of the ``HuggingFaceFW/fineweb`` hub dataset.

    ``dataset_path`` is accepted to match the loader call convention but is
    not used; the hub repo id is fixed.
    """
    repo_id = "HuggingFaceFW/fineweb"
    return load_dataset(repo_id, split="train", streaming=True)


def _load_fineweb_edu_dataset(dataset_path: str):
    """Stream the "train" split of the ``HuggingFaceFW/fineweb-edu`` hub dataset.

    ``dataset_path`` is accepted to match the loader call convention but is
    not used; the hub repo id is fixed.
    """
    repo_id = "HuggingFaceFW/fineweb-edu"
    return load_dataset(repo_id, split="train", streaming=True)


class UniformSamplingDataset(Stateful):
    """
    A wrapper that provides uniform sampling across multiple streaming datasets.

    This class mimics the interface of a single streaming dataset while internally
    managing multiple datasets and sampling uniformly from them in round-robin fashion.

    Args:
        datasets: List of streaming datasets to sample from uniformly
    """

    def __init__(self, datasets: list):
        self.datasets = datasets
        self._iterators = None
        self._start_offset = 0
        self.epoch = 0

    def __iter__(self):
        """Iterate through all datasets in round-robin fashion for uniform sampling."""
        # Initialize iterators for all datasets
        self._iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            samples_yielded_this_round = 0
            last_yield_idx = None

            # Round-robin through all datasets starting from start_offset
            for step in range(len(self._iterators)):
                idx = (self._start_offset + step) % len(self._iterators)
                iterator = self._iterators[idx]
                try:
                    sample = next(iterator)
                    samples_yielded_this_round += 1
                    last_yield_idx = idx
                    yield sample
                except StopIteration:
                    # This dataset is exhausted, continue with others
                    continue

            # If no samples were yielded in this round, all datasets are exhausted
            if samples_yielded_this_round == 0:
                break
            # Start next round from the dataset following the last successful yield
            if last_yield_idx is not None:
                self._start_offset = (last_yield_idx + 1) % len(self._iterators)

    def take(self, n: int):
        """Take n samples, distributing them across individual datasets."""
        samples_per_dataset = n // len(self.datasets)
        remainder = n % len(self.datasets)

        taken_datasets = []
        for i, dataset in enumerate(self.datasets):
            # Distribute samples as evenly as possible
            take_count = samples_per_dataset + (1 if i < remainder else 0)
            taken_datasets.append(dataset.take(take_count))

        return UniformSamplingDataset(taken_datasets)

    def skip(self, n: int):
        """Skip n samples, distributing them across individual datasets."""
        samples_per_dataset = n // len(self.datasets)
        remainder = n % len(self.datasets)

        skipped_datasets = []
        for i, dataset in enumerate(self.datasets):
            # Distribute skips as evenly as possible
            skip_count = samples_per_dataset + (1 if i < remainder else 0)
            skipped_datasets.append(dataset.skip(skip_count))

        return UniformSamplingDataset(skipped_datasets)

    def shuffle(self, seed: int | None = None):
        """Shuffle each individual dataset."""
        shuffled_datasets = [dataset.shuffle(seed=seed) for dataset in self.datasets]
        return UniformSamplingDataset(shuffled_datasets)

    # Optional epoch interface for compatibility with re-loop logic
    def set_epoch(self, epoch: int):
        self.epoch = int(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch") and hasattr(ds, "epoch"):
                try:
                    ds.set_epoch(epoch)
                except Exception:
                    pass

    # Stateful API for checkpointing
    def state_dict(self):
        datasets_state = []
        for ds in self.datasets:
            if hasattr(ds, "state_dict") and callable(
                getattr(ds, "state_dict")  # noqa: B009
            ):
                try:
                    datasets_state.append(ds.state_dict())
                except Exception:
                    datasets_state.append(None)
            else:
                datasets_state.append(None)
        return {
            "start_offset": self._start_offset,
            "epoch": self.epoch,
            "datasets": datasets_state,
        }

    def load_state_dict(self, state):
        if not state:
            return
        self._start_offset = int(state.get("start_offset", 0))
        self.epoch = int(state.get("epoch", 0))
        ds_states = state.get("datasets", [])
        for ds, ds_state in zip(self.datasets, ds_states):
            if ds_state is None:
                continue
            if hasattr(ds, "load_state_dict") and callable(
                getattr(ds, "load_state_dict")  # noqa: B009
            ):
                try:
                    ds.load_state_dict(ds_state)
                except Exception:
                    pass


def _is_hf_repo_id(path: str) -> bool:
    """Guess whether ``path`` names a hub dataset rather than a local path.

    Heuristic: the string contains a slash (as in ``org/name``) and nothing
    exists at that location on the local filesystem.
    """
    if "/" not in path:
        return False
    return not os.path.exists(path)


def _infer_split_from_dataset_name(dataset_name: str) -> str:
    """Derive the split from the dataset's registry name (case-insensitive).

    Returns ``"validation"`` when the name contains "validation" or ends with
    "_val" / "-val"; otherwise returns ``"train"``.
    """
    lowered = dataset_name.lower()
    looks_like_validation = "validation" in lowered or lowered.endswith(
        ("_val", "-val")
    )
    return "validation" if looks_like_validation else "train"


def _create_streaming_dataset_for_file(file_uri: str, token: Optional[str]):
    """Open a single data file as a streaming dataset, chosen by file suffix.

    The suffix check is case-insensitive. ``.parquet`` files use the
    "parquet" builder; json/jsonl files (optionally ``.gz``/``.zst``/``.zstd``
    compressed) use the "json" builder.

    Returns:
        The streaming dataset, or ``None`` when the suffix is not recognised.
    """
    json_suffixes = (
        ".json",
        ".jsonl",
        ".json.gz",
        ".jsonl.gz",
        ".jsonl.zst",
        ".jsonl.zstd",
        ".json.zst",
        ".json.zstd",
    )
    normalized = file_uri.lower()

    if normalized.endswith(".parquet"):
        builder = "parquet"
    elif normalized.endswith(json_suffixes):
        builder = "json"
    else:
        return None

    return load_dataset(
        builder, data_files=[file_uri], streaming=True, token=token, split="train"
    )


def create_single_shard_dataset_for_repo(
    repo_id: str,
    shard_index: int,
    num_shards: int,
    token: Optional[str],
    split: str = "train",
    seed: int = 42,
    max_files_per_shard: int = 25,
):
    """Build a streaming dataset over one shard's share of a hub repo's files.

    The repo's parquet/json(l) files are listed, filtered by filename hints
    for ``split``, shuffled deterministically with ``seed`` and divided into
    ``num_shards`` contiguous slices. Each file of slice ``shard_index`` is
    opened as its own streaming dataset.

    Args:
        repo_id: Hub dataset repo id.
        shard_index: Which slice to build; must satisfy
            ``0 <= shard_index < num_shards``.
        num_shards: Total number of slices.
        token: Hub access token, or ``None``.
        split: ``"validation"`` keeps files whose name hints at validation;
            any other value drops files hinting at validation/test/sample.
            If the filter leaves nothing, all data files are used.
        seed: Seed for the file shuffle.
        max_files_per_shard: Upper bound on files opened for this shard;
            files beyond it are not read.

    Returns:
        A ``UniformSamplingDataset`` over the shard's files. If no file of the
        shard could be opened, falls back to
        ``load_dataset(repo_id, split=split, streaming=True)``, which covers
        the whole split and is NOT restricted to this shard.

    Raises:
        ValueError: If ``shard_index`` is outside ``[0, num_shards)``.
    """
    # Also rejects negative indices, which would otherwise slice from the
    # wrong end of the file list instead of failing.
    if num_shards <= 0 or not 0 <= shard_index < num_shards:
        raise ValueError(
            f"shard_index ({shard_index}) must be non-negative and less than "
            f"num_shards ({num_shards})"
        )

    supported_suffixes = (
        ".parquet",
        ".json",
        ".jsonl",
        ".json.gz",
        ".jsonl.gz",
        ".jsonl.zst",
        ".jsonl.zstd",
        ".json.zst",
        ".json.zstd",
    )
    validation_hints = ("validation", "valid", "val")
    non_train_hints = validation_hints + ("test", "sample")

    api = HfApi()
    repo_info = api.repo_info(repo_id, repo_type="dataset", token=token)

    # Gather candidate files
    all_files: list[str] = [
        sibling.rfilename
        for sibling in repo_info.siblings
        if sibling.rfilename.lower().endswith(supported_suffixes)
    ]

    # Optional split filtering by filename hints; fallback to all if filter is too aggressive
    def file_matches_split(name: str, split_name: str) -> bool:
        lname = name.lower()
        if split_name == "validation":
            return any(hint in lname for hint in validation_hints)
        # train: exclude obvious non-train splits
        return not any(hint in lname for hint in non_train_hints)

    candidate_files = [f for f in all_files if file_matches_split(f, split)]
    if not candidate_files:
        candidate_files = all_files

    # Build HF hub uris
    candidate_uris = [f"hf://datasets/{repo_id}/{f}" for f in candidate_files]

    # Stable shuffle, then shard. A private generator gives the same
    # permutation as seeding the module-level one, without resetting the
    # process-wide random state for unrelated code.
    random.Random(seed).shuffle(candidate_uris)

    files_per_shard, remainder = divmod(len(candidate_uris), num_shards)
    start_idx = shard_index * files_per_shard + min(shard_index, remainder)
    current_shard_size = files_per_shard + (1 if shard_index < remainder else 0)
    end_idx = start_idx + current_shard_size
    shard_files = candidate_uris[start_idx:end_idx][:max_files_per_shard]

    if len(shard_files) < current_shard_size:
        logger.warning(
            f"Repo {repo_id} shard {shard_index}/{num_shards}: using only "
            f"{len(shard_files)} of {current_shard_size} assigned files "
            f"(max_files_per_shard={max_files_per_shard})"
        )

    logger.info(
        f"Repo {repo_id} split={split} shard {shard_index}/{num_shards}: {len(shard_files)} files"
    )
    if shard_files:
        logger.info(f"Sample files: {shard_files[:3]}...")

    individual_datasets = []
    for uri in shard_files:
        ds = _create_streaming_dataset_for_file(uri, token=token)
        if ds is not None:
            individual_datasets.append(ds)
    if not individual_datasets:
        # Fallback to monolithic streaming load if file inference failed
        logger.warning(
            f"Falling back to load_dataset({repo_id}, split={split}, streaming=True)"
        )
        return load_dataset(repo_id, split=split, streaming=True, token=token)

    logger.info(
        f"Created {len(individual_datasets)} individual streaming datasets for uniform sampling"
    )
    return UniformSamplingDataset(individual_datasets)


@dataclass
class DatasetConfig:
    """Registry entry describing how one named dataset is located and read."""

    # Default location, used when the caller supplies no path override.
    path: str
    # Called with the resolved path; returns the dataset object to iterate.
    loader: Callable
    # Called with one sample; returns the text that gets tokenized.
    text_processor: Callable


# Add your dataset here - more information at docs/datasets.md
# Each key is the lowercase name used in the job config. Each loader is called
# with a single positional argument: the resolved dataset path.
DATASETS = {
    "dclm": DatasetConfig(
        path="mlfoundations/dclm-baseline-1.0",
        # Registered directly: the loader already streams the "train" split,
        # so no `split` keyword needs to be bound here.
        loader=_load_dclm_dataset,
        text_processor=_process_dclm_text,
    ),
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
    "c4_validation": DatasetConfig(
        path="allenai/c4",
        loader=partial(_load_c4_dataset, split="validation"),
        text_processor=_process_c4_text,
    ),
    "fineweb": DatasetConfig(
        path="HuggingFaceFW/fineweb",
        loader=_load_fineweb_dataset,
        text_processor=_process_fineweb_text,
    ),
    "olmo": DatasetConfig(
        # Must name the OLMo repo itself: the path decides which hub repo's
        # files are listed and sharded, so it has to match what the loader reads.
        path="allenai/olmo-mix-1124",
        loader=_load_olmo_dataset,
        text_processor=_process_olmo_text,
    ),
    "fineweb-edu": DatasetConfig(
        path="HuggingFaceFW/fineweb-edu",
        loader=_load_fineweb_edu_dataset,
        text_processor=_process_fineweb_text,
    ),
}


def _validate_dataset(
    dataset_name: str, dataset_path: str | None = None
) -> tuple[str, Callable, Callable]:
    """Look up a registered dataset and resolve the path to load it from.

    Args:
        dataset_name: Key into ``DATASETS``.
        dataset_path: Optional override; when falsy the registered default
            path is used.

    Returns:
        ``(path, loader, text_processor)`` for the dataset.

    Raises:
        ValueError: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    if dataset_name in DATASETS:
        entry = DATASETS[dataset_name]
        resolved_path = dataset_path if dataset_path else entry.path
        logger.info(f"Preparing {dataset_name} dataset from {resolved_path}")
        return resolved_path, entry.loader, entry.text_processor

    raise ValueError(
        f"Dataset {dataset_name} is not supported. "
        f"Supported datasets are: {list(DATASETS.keys())}"
    )


class HuggingFaceDataset(IterableDataset, Stateful):
    """Iterable dataset that tokenizes text samples into fixed-length sequences.

    Samples are tokenized (with BOS/EOS), appended to a token buffer, and
    emitted as ``({"input": tokens[:-1]}, tokens[1:])`` pairs, each built from
    ``seq_len + 1`` consecutive buffered tokens.

    Args:
        dataset_name: Registry name of the dataset (case-insensitive).
        dataset_path: Optional path override; ``None`` uses the registered path.
        tokenizer: Tokenizer used to encode each sample's text.
        seq_len: Length of each emitted input/label sequence.
        dp_rank: Data-parallel rank of this process.
        dp_world_size: Number of data-parallel ranks.
        infinite: When True, restart from the beginning once data runs out.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: str | None,
        tokenizer: BaseTokenizer,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        split = _infer_split_from_dataset_name(dataset_name)
        token = os.getenv("HF_TOKEN")
        # If path looks like an HF hub dataset repo, build a file-sharded, round-robin dataset
        if _is_hf_repo_id(path):
            try:
                ds = create_single_shard_dataset_for_repo(
                    repo_id=path,
                    shard_index=dp_rank,
                    num_shards=dp_world_size,
                    token=token,
                    split=split,
                )
            except Exception as e:
                logger.warning(
                    f"Falling back to default loader for {path} due to error: {e}"
                )
                ds = dataset_loader(path)
        else:
            ds = dataset_loader(path)

        self.dataset_name = dataset_name
        # Only a UniformSamplingDataset is already limited to this rank's
        # files. Anything else (the default loader, or the builder's own
        # monolithic fallback) still covers the full dataset and must be split
        # across ranks; otherwise every rank would read the same samples.
        if isinstance(ds, UniformSamplingDataset):
            self._data = ds
        else:
            self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._token_buffer: list[int] = []

    def _get_data_iter(self):
        """Return an iterator over the data positioned at the resume point."""
        # For map-style datasets, resume by skipping to the correct index
        # For iterable-style datasets, the underlying iterator already points to the correct index
        if isinstance(self._data, Dataset):
            if self._sample_idx == len(self._data):
                return iter([])
            else:
                return iter(self._data.skip(self._sample_idx))

        return iter(self._data)

    def __iter__(self):
        """Yield ``({"input": tensor}, label_tensor)`` pairs of length ``seq_len``."""
        # One extra token so that inputs and labels can be offset by one.
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(
                    sample_text, add_bos=True, add_eos=True
                )
                self._token_buffer.extend(sample_tokens)
                self._sample_idx += 1

                while len(self._token_buffer) >= max_buffer_token_len:
                    x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self._token_buffer = self._token_buffer[max_buffer_token_len:]
                    inputs = x[:-1]
                    label = x[1:]
                    yield {"input": inputs}, label

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")
                # Ensures re-looping a dataset loaded from a checkpoint works correctly
                if not isinstance(self._data, Dataset):
                    if hasattr(self._data, "set_epoch") and hasattr(
                        self._data, "epoch"
                    ):
                        self._data.set_epoch(self._data.epoch + 1)

    def load_state_dict(self, state_dict):
        """Restore the token buffer and the position within the data.

        Map-style data restores ``sample_idx``; iterable data delegates to the
        underlying dataset's ``load_state_dict`` using the ``"data"`` entry.
        """
        self._token_buffer = state_dict["token_buffer"]

        if isinstance(self._data, Dataset):
            self._sample_idx = state_dict["sample_idx"]
        else:
            assert "data" in state_dict
            self._data.load_state_dict(state_dict["data"])

    def state_dict(self):
        """Return the token buffer plus the position within the data."""
        _state_dict = {"token_buffer": self._token_buffer}

        if isinstance(self._data, Dataset):
            _state_dict["sample_idx"] = self._sample_idx
        else:
            # Save the iterable dataset's state to later efficiently resume from it
            # https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration
            _state_dict["data"] = self._data.state_dict()

        return _state_dict


def build_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: BaseTokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets.

    Dataset name, path, local batch size and sequence length are read from
    ``job_config.training``.
    """
    training_cfg = job_config.training
    name, path, local_batch_size, sequence_length = (
        training_cfg.dataset,
        training_cfg.dataset_path,
        training_cfg.local_batch_size,
        training_cfg.seq_len,
    )

    dataset_kwargs = dict(
        dataset_name=name,
        dataset_path=path,
        tokenizer=tokenizer,
        seq_len=sequence_length,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )
    training_dataset = HuggingFaceDataset(**dataset_kwargs)

    return ParallelAwareDataloader(
        dataset=training_dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=local_batch_size,
    )


def build_hf_validation_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: BaseTokenizer,
    job_config: JobConfig,
) -> ParallelAwareDataloader:
    """Build a validation data loader for HuggingFace datasets.

    Dataset name, path, local batch size and sequence length are read from
    ``job_config.validation``; the dataset is never re-looped.
    """
    validation_cfg = job_config.validation
    name, path, local_batch_size, sequence_length = (
        validation_cfg.dataset,
        validation_cfg.dataset_path,
        validation_cfg.local_batch_size,
        validation_cfg.seq_len,
    )

    dataset_kwargs = dict(
        dataset_name=name,
        dataset_path=path,
        tokenizer=tokenizer,
        seq_len=sequence_length,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=False,
    )
    validation_dataset = HuggingFaceDataset(**dataset_kwargs)

    return ParallelAwareDataloader(
        dataset=validation_dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=local_batch_size,
    )
