# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections.abc import AsyncIterable, Iterable
from typing import Any, Union

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch.utils.data import Dataset

from ..config.data_args import DataArguments
from ..extras.types import DatasetInfo, HFDataset, Sample


class DataEngine(Dataset):
    """Data engine.

    Resolves the dataset description given by ``DataArguments.dataset``, loads every
    described dataset and builds a flat index of ``(dataset_name, sample_index)`` pairs
    over all of them, so samples can be fetched through one global index.
    """

    def __init__(self, data_args: DataArguments) -> None:
        self.args = data_args
        """Data arguments."""
        self.datasets: dict[str, HFDataset] = {}
        """Dict of (dataset_name, dataset)"""
        self.dataset_infos: dict[str, DatasetInfo] = {}
        """Dict of (dataset_name, dataset_info)"""
        self.data_index: list[tuple[str, int]] = []
        """List of (dataset_name, sample_index)"""
        self.streaming: bool = False
        """Whether dataset is streaming."""
        # Order matters: loading reads ``dataset_infos``; indexing reads ``datasets``
        # and the ``streaming`` flag that loading sets.
        self.get_dataset_info()
        self.load_dataset()
        self.build_data_index()

    def get_dataset_info(self) -> None:
        """Get dataset info from data arguments.

        ``self.args.dataset`` is interpreted, in this order, as:

        1. a ``.yaml`` file under ``self.args.dataset_dir`` that holds the dataset infos;
        2. a ``.yaml`` file on the Hugging Face Hub, written as
           ``<namespace>/<repo>/<path/in/repo>.yaml``;
        3. a local file or directory under ``self.args.dataset_dir``;
        4. otherwise, a Hugging Face Hub dataset id.

        Cases 3 and 4 produce a single dataset named ``"default"``.
        """
        local_path = os.path.join(self.args.dataset_dir, self.args.dataset)
        if self.args.dataset.endswith(".yaml") and os.path.isfile(local_path):  # local file
            self.dataset_infos = OmegaConf.load(local_path)
        elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
            # A Hub repo id has at most two components (``namespace/name``); everything after
            # it is the file path inside the repo and may include subfolders. Splitting only
            # at the last "/" would wrongly fold those subfolders into the repo id.
            parts = self.args.dataset.split("/")
            num_repo_parts = min(2, len(parts) - 1)
            repo_id = "/".join(parts[:num_repo_parts])
            filename = "/".join(parts[num_repo_parts:])
            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
            self.dataset_infos = OmegaConf.load(filepath)
        elif os.path.exists(local_path):  # local file(s)
            self.dataset_infos = {"default": {"file_name": self.args.dataset}}
        else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
            self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}

    def load_dataset(self) -> None:
        """Load datasets according to dataset info.

        Entries with an ``hf_hub_url`` are loaded with ``datasets.load_dataset`` using the
        entry's ``split`` (default ``"train"``) and ``streaming`` (default ``False``) keys.
        All other entries are delegated to ``DataLoaderPlugin``. If any entry sets
        ``streaming``, the whole engine is marked as streaming.
        """
        for key, value in self.dataset_infos.items():
            split = value.get("split", "train")
            streaming = value.get("streaming", False)
            self.streaming |= streaming
            if "hf_hub_url" in value:
                # The bare name resolves to the module-level ``datasets.load_dataset``, not this method.
                self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
            else:  # data loader plugin
                from ..plugins.data_plugins.loader import DataLoaderPlugin

                self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)

    def build_data_index(self) -> None:
        """Build dataset index.

        Appends ``(dataset_name, sample_index)`` entries for every loaded dataset to
        ``self.data_index``. When a dataset's info sets ``size`` or ``weight``, that
        dataset's entries are first adjusted by ``DataIndexPlugin``.
        """
        for dataset_name, dataset in self.datasets.items():
            dataset_info = self.dataset_infos[dataset_name]
            size = dataset_info.get("size")
            weight = dataset_info.get("weight")
            if self.streaming:
                # Once the engine is streaming, dataset lengths are not used: every dataset
                # gets a fixed 1000 placeholder entries whose sample_index is -1.
                data_index = [(dataset_name, -1) for _ in range(1000)]
            else:
                data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]

            if size or weight:  # data index plugin
                from ..plugins.data_plugins.loader import DataIndexPlugin

                data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)

            self.data_index.extend(data_index)

    def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
        """Convert dataset sample.

        If the dataset's info names a ``converter``, the raw sample is passed through it.
        The result always carries a ``_dataset_name`` key identifying its source dataset.

        Args:
            raw_sample (dict[str, Any]): Raw dataset sample.
            dataset_name (str): Dataset name.

        Returns:
            Sample: Dataset sample.
        """
        converter = self.dataset_infos[dataset_name].get("converter")
        if converter is not None:
            from ..plugins.data_plugins.converter import get_converter

            return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
        else:
            return {"_dataset_name": dataset_name, **raw_sample}

    def _get_sample(self, dataset_name: str, sample_index: int) -> Sample:
        """Fetch sample ``sample_index`` of ``dataset_name`` and convert it."""
        return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)

    def __len__(self) -> int:
        """Get dataset length.

        Returns:
            int: Dataset length, or ``-1`` for a streaming engine. Note that the builtin
            ``len()`` rejects negative values, so ``len(engine)`` raises ``ValueError``
            when streaming; only a direct ``__len__()`` call observes the ``-1``.
        """
        if self.streaming:
            return -1
        else:
            return len(self.data_index)

    def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
        """Get dataset item.

        Args:
            index (int | Any): A position in ``self.data_index``, or any other index
                understood by ``DataSelectorPlugin``.

        Returns:
            Sample | list[Sample]: One sample for an ``int`` index; otherwise one sample
            or a list of samples, depending on what the selector plugin returns.

        Raises:
            ValueError: If the engine is streaming.
        """
        if self.streaming:
            raise ValueError("Streaming dataset does not support index access.")

        if isinstance(index, int):
            dataset_name, sample_index = self.data_index[index]
            return self._get_sample(dataset_name, sample_index)

        from ..plugins.data_plugins.loader import DataSelectorPlugin

        # The selector resolves the index to either a single (dataset_name, sample_index)
        # pair or a list of such pairs.
        selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
        if isinstance(selected_index, list):
            return [self._get_sample(dataset_name, sample_index) for dataset_name, sample_index in selected_index]

        dataset_name, sample_index = selected_index
        return self._get_sample(dataset_name, sample_index)

    def __iter__(self) -> Iterable:
        """Get dataset iterator.

        Not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: implement iteration for streaming and map-style datasets (add shuffle for the latter).
        raise NotImplementedError()

    def __aiter__(self) -> AsyncIterable:
        """Get dataset async iterator.

        Not implemented yet. This is deliberately a plain (non-``async``) method:
        ``async for`` calls ``__aiter__`` synchronously and requires an async iterator
        back. A non-generator ``async def`` would return a coroutine instead, so the
        loop would fail with ``TypeError`` before this body ever ran.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: implement async iteration for streaming and map-style datasets (add shuffle for the latter).
        raise NotImplementedError()


if __name__ == "__main__":
    # Smoke test: build the engine from command-line arguments and show the first sample.
    from ..config.parser import get_args

    parsed_data_args, *_unused_args = get_args()
    engine = DataEngine(data_args=parsed_data_args)
    first_sample = engine[0]
    print(first_sample)
