import glob, copy, os, math
from typing import List, Dict, Optional, Union
from collections.abc import Sequence
from multiprocess import Pool, RLock
from tqdm.auto import tqdm
import torch
import datasets
import pytorch_lightning as pl


class HFDataModule(pl.LightningDataModule):
    """LightningDataModule backed by one or more Hugging Face datasets.

    Workflow:
      * ``prepare_data``: for every configured raw dataset whose preprocessed
        caches are not found, load it with ``datasets.load_dataset`` and call
        the subclass hook ``_preprocess`` once per split, giving it the arrow
        path (inside the dataset builder's cache dir) to cache results to.
      * ``setup``: load the cached arrow files, concatenate datasets along the
        same split, set torch format on the selected columns and optionally
        wrap every split with ``tfm_dataset_cls``.

    Subclasses must implement ``_preprocess`` and ``_collate``.
    NOTE(review): ``_dataloader`` reads ``self.config.batch_size`` and
    ``self.config.dataloader_num_workers``, but this class never assigns
    ``self.config``; subclasses presumably must set it — confirm.
    """

    def __init__(
        self,
        dataset: Union[str, dict, list[Union[str, dict]]],
        # info(s) to load huggingface dataset(s)
        # if str: datasets.load_dataset(dataset)
        # if dict: datasets.load_dataset(**dataset)
        # if list: load and concatenate all datasets specified
        cache_name_template: str,
        # pattern name of preprocessed arrow file to find or save;
        # must contain "{split}" and end with ".arrow"
        cache_dir: Optional[str] = None,
        # datasets.load_dataset(cache_dir=cache_dir)
        using_features: Optional[Union[list[str], dict[str, list[str]]]] = None,
        # name of columns will be used (converted into tensor) from preprocessed dataset,
        # passed to Dataset.set_format. If None, all data columns will be used.
        # If a dict (split -> columns), splits not listed keep all columns.
        tfm_dataset_cls: Optional[type] = None,
        # User defined dataset class which wraps the preprocessed dataset
        # (called as tfm_dataset_cls(dataset) for every split)
    ):
        super().__init__()
        # A str is itself a Sequence, so exclude it explicitly; otherwise a
        # single dataset name would be iterated character by character.
        if isinstance(dataset, Sequence) and not isinstance(dataset, str):
            dataset_specs = list(dataset)
        else:
            dataset_specs = [dataset]

        self.raw_datasets_kwargs = []
        for spec in dataset_specs:
            # Copy dict specs so adding "cache_dir" never mutates the caller's object.
            kwargs = {"path": spec} if isinstance(spec, str) else dict(spec)
            if cache_dir:
                kwargs["cache_dir"] = cache_dir
            self.raw_datasets_kwargs.append(kwargs)

        assert "{split}" in cache_name_template, (
            "cache_name_template must contain '{split}'"
        )
        assert cache_name_template.endswith(".arrow"), (
            "cache_name_template must end with '.arrow'"
        )
        self.cache_name_template = cache_name_template
        self.using_features = using_features
        self.tfm_dataset_cls = tfm_dataset_cls

    def prepare_data(self):
        """Preprocess and cache every raw dataset whose caches are missing."""
        for load_dataset_kwargs in self.raw_datasets_kwargs:

            # Save loading time if data of all splits is already prepared
            data_dir, caches = self.find_caches(load_dataset_kwargs)
            if caches:  # cached processed datasets exist
                continue

            # Load raw dataset dict
            raw_dsets = datasets.load_dataset(**load_dataset_kwargs)
            if isinstance(raw_dsets, datasets.Dataset):
                # load_dataset returned a single split (a "split" kwarg was
                # given); wrap it so the loop below handles both cases.
                split = load_dataset_kwargs["split"]
                raw_dsets = datasets.DatasetDict({f"{split}": raw_dsets})

            # Preprocess the dataset and cache the results
            for split, dset in raw_dsets.items():
                self._preprocess(
                    dataset=dset,
                    split=split,
                    cache_file_path=os.path.join(
                        data_dir, self.cache_name_template.format(split=split)
                    ),
                )

    def setup(self, stage=None):
        """Load cached preprocessed datasets into ``self.datasets``.

        Args:
            stage: Lightning stage name; not used by this implementation.

        Raises:
            FileNotFoundError: if the preprocessed caches of some dataset
                cannot be found (e.g. ``prepare_data`` has not been run).
        """
        # Load cached preprocessed datasets
        processed_dsetdicts = []
        for load_dataset_kwargs in self.raw_datasets_kwargs:
            _, cache_files_dict = self.find_caches(load_dataset_kwargs)
            if not cache_files_dict:
                raise FileNotFoundError(
                    f"Preprocessed caches not found for dataset {load_dataset_kwargs}; "
                    "run prepare_data() first."
                )
            processed_dsetdicts.append(self.load_caches(cache_files_dict))

        # Concatenate loaded datasets along the same split
        self.datasets = self.concatenate_dsetdicts(processed_dsetdicts)
        if isinstance(self.using_features, dict):
            for split, dset in self.datasets.items():
                # Splits without an entry fall back to all columns (None).
                dset.set_format("torch", columns=self.using_features.get(split))
        else:
            self.datasets.set_format("torch", columns=self.using_features)

        # (Optional) Add dynamic transformation by wrap dataset with torch dataset
        if self.tfm_dataset_cls:
            self.datasets = {
                split: self.tfm_dataset_cls(dset)
                for split, dset in self.datasets.items()
            }

    def train_dataloader(self):
        """Return the "train" dataloader.

        Shuffling is disabled when PL_FAULT_TOLERANT_TRAINING == "1".
        """
        fault_tolerant = os.environ.get("PL_FAULT_TOLERANT_TRAINING") == "1"
        return self._dataloader(self.datasets["train"], shuffle=not fault_tolerant)

    def val_dataloader(self):
        """Return the unshuffled "validation" dataloader."""
        return self._dataloader(self.datasets["validation"], shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled "test" dataloader."""
        return self._dataloader(self.datasets["test"], shuffle=False)

    # #####################
    #   Abstract Methods
    # #####################

    def _preprocess(self, dataset: datasets.Dataset, split: str, cache_file_path: str):
        """Preprocess ``dataset`` (one split) and cache it at ``cache_file_path``.

        Must be implemented by subclasses. The cache may be sharded; files are
        later discovered by inserting ``*`` before ``.arrow`` in the name.
        """
        # Preprocess dataset if haven't and cache the results
        raise NotImplementedError

    def _collate(self, samples: List[Dict]):
        """Collate a list of samples into a batch; used as DataLoader collate_fn."""
        raise NotImplementedError

    # #####################
    #   Helpers
    # #####################

    def find_caches(self, load_dataset_kwargs: dict):
        """Locate preprocessed arrow caches of every split of one dataset.

        Args:
            load_dataset_kwargs: kwargs for ``datasets.load_dataset``.

        Returns:
            ``(data_dir, cache_files)`` where ``data_dir`` is the dataset
            builder's cache dir and ``cache_files`` maps each split name to the
            list of matching arrow files, or is ``False`` if any split has no
            cache file.
        """
        # Get data directory where arrow files will be placed and splits of dataset
        load_builder_kwargs = copy.deepcopy(load_dataset_kwargs)
        # load_dataset_builder doesn't accept "split"
        load_builder_kwargs.pop("split", None)
        dataset_builder = datasets.load_dataset_builder(**load_builder_kwargs)
        splits = dataset_builder.info.splits  # dict
        if splits is None:
            # don't know how to make dataset_builder.info.splits correct when loading from data file
            splits = {"test": None}
        data_dir = dataset_builder.cache_dir

        # Find cache files of preprocessed datasets
        cache_files = {}
        for split in splits:
            cache_name = self.cache_name_template.format(split=split)
            # insert * before ".arrow" to capture shards suffix (e.g. 00000_of_00008) if multi processing used in preprocessing
            cache_name_pattern = cache_name.replace(".arrow", "*.arrow")
            _cache_files = glob.glob(os.path.join(data_dir, cache_name_pattern))
            if not _cache_files:  # not found
                return data_dir, False
            cache_files[split] = _cache_files
        return data_dir, cache_files

    def load_caches(
        self,
        cache_files,  # (dict of) cache_files_to_be_loaded_and_concatenated:list[str]
        max_num_workers=None,
    ):  # return (dict of) dataset
        """Load arrow cache file(s) as dataset(s).

        Args:
            cache_files: a path, a list of paths (shards concatenated in sorted
                order into one dataset), or a dict split -> path(s).
            max_num_workers: upper bound on worker processes used when loading
                several files; defaults to ``os.cpu_count()`` (or 1 if unknown).

        Returns:
            A ``datasets.DatasetDict`` for dict input, else a ``datasets.Dataset``.

        Raises:
            ValueError: if an empty list of files is given.
        """
        # Load cache files as a single dataset for each split individually
        if isinstance(cache_files, dict):
            return datasets.DatasetDict(
                {
                    split: self.load_caches(_cache_files, max_num_workers=max_num_workers)
                    for split, _cache_files in cache_files.items()
                }
            )

        # Load the single cache file as dataset directly
        if isinstance(cache_files, list) and len(cache_files) == 1:
            cache_files = cache_files[0]
        if isinstance(cache_files, str):
            cache_file = cache_files
            print(f"loading cache: {cache_file}", end="\t", flush=True)
            dset = datasets.Dataset.from_file(cache_file)
            print("loaded")
            return dset

        if not cache_files:
            raise ValueError("load_caches received no cache files to load")

        # Parallelly load cache files and concatenate them as a single dataset
        cache_files = sorted(cache_files)

        def _load_caches(files):
            # Runs in a worker process: load and concatenate one chunk of files.
            dsets = [datasets.Dataset.from_file(f) for f in files]
            if len(dsets) == 1:
                return dsets[0]
            return datasets.concatenate_datasets(dsets)

        # os.cpu_count() may return None when the count is undeterminable.
        max_num_workers = max_num_workers or os.cpu_count() or 1
        chunk_size = math.ceil(len(cache_files) / max_num_workers)
        parts = [
            cache_files[i : i + chunk_size]
            for i in range(0, len(cache_files), chunk_size)
        ]
        print(
            f"loading the following caches with {len(parts)} parallel processes: {cache_files}",
            flush=True,
        )
        with Pool(len(parts), initargs=(RLock(),), initializer=tqdm.set_lock) as pool:
            results = [pool.apply_async(_load_caches, (files,)) for files in parts]
            # Collect in submission order so the sorted file order is preserved.
            dset = datasets.concatenate_datasets([r.get() for r in results])
            print("All caches listed above are loaded.")
        return dset

    def concatenate_dsetdicts(self, dset_dicts):
        """Concatenate a list of DatasetDicts split-wise.

        A single dict is returned unchanged; otherwise datasets sharing a split
        name are concatenated in list order into a new ``datasets.DatasetDict``.
        """
        if len(dset_dicts) == 1:
            return dset_dicts[0]

        flat_dset_dict = {}
        for dset_dict in dset_dicts:
            for split, dset in dset_dict.items():
                flat_dset_dict.setdefault(split, []).append(dset)
        return datasets.DatasetDict(
            {
                split: datasets.concatenate_datasets(dsets)
                for split, dsets in flat_dset_dict.items()
            }
        )

    def _dataloader(self, dataset, pin_memory=True, **kwargs):
        """Build a DataLoader over ``dataset`` using ``self._collate``.

        Extra ``kwargs`` (e.g. ``shuffle``) are forwarded to DataLoader.
        NOTE(review): relies on ``self.config`` (batch_size,
        dataloader_num_workers), which must be provided by a subclass.
        """
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            collate_fn=self._collate,
            pin_memory=pin_memory,
            num_workers=self.config.dataloader_num_workers,
            **kwargs,
        )
