# Copyright authors of TSPulse

import os

from torch.utils.data import DataLoader, Dataset

from .dset import *


class DataLoaders:
    """Build and hold train/val/test ``DataLoader`` objects for one dataset class.

    The dataset class is instantiated once per split as
    ``datasetCls(**dataset_kwargs, split=<"train"|"val"|"test">)``. The
    resulting loaders are stored on ``self.train``, ``self.valid`` and
    ``self.test``; a split whose dataset has length 0 is stored as ``None``.
    """

    def __init__(
        self,
        datasetCls,
        dataset_kwargs: dict,
        batch_size: int,
        workers: int = 0,
        collate_fn=None,
        shuffle_train=True,
        shuffle_val=False,
    ):
        """Create the three loaders eagerly.

        Args:
            datasetCls: Callable returning a sized dataset; must accept the
                entries of ``dataset_kwargs`` plus a ``split`` keyword.
            dataset_kwargs: Keyword arguments forwarded to ``datasetCls``.
                Any ``"split"`` entry is ignored because the split is set
                per loader. The caller's dict is not modified.
            batch_size: Batch size used for every loader.
            workers: ``num_workers`` passed to each ``DataLoader``.
            collate_fn: Optional ``collate_fn`` passed to each ``DataLoader``.
            shuffle_train: Whether the train loader shuffles.
            shuffle_val: Whether the validation loader shuffles.
        """
        super().__init__()
        self.datasetCls = datasetCls
        self.batch_size = batch_size

        # Work on a filtered copy so the caller's dict is never mutated.
        self.dataset_kwargs = {
            key: value for key, value in dataset_kwargs.items() if key != "split"
        }
        self.workers = workers
        self.collate_fn = collate_fn
        self.shuffle_train, self.shuffle_val = shuffle_train, shuffle_val

        self.train = self.train_dataloader()
        self.valid = self.val_dataloader()
        self.test = self.test_dataloader()

    def train_dataloader(self):
        """Return a loader over the ``"train"`` split (``None`` if empty)."""
        return self._make_dloader("train", shuffle=self.shuffle_train)

    def val_dataloader(self):
        """Return a loader over the ``"val"`` split (``None`` if empty)."""
        return self._make_dloader("val", shuffle=self.shuffle_val)

    def test_dataloader(self):
        """Return a non-shuffling loader over the ``"test"`` split (``None`` if empty)."""
        return self._make_dloader("test", shuffle=False)

    def _make_dloader(self, split, shuffle=False):
        """Instantiate the dataset for ``split`` and wrap it in a ``DataLoader``.

        Returns ``None`` when the dataset reports length 0, so callers must
        check for ``None`` before using a loader.
        """
        dataset = self.datasetCls(**self.dataset_kwargs, split=split)
        if len(dataset) == 0:
            return None
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.workers,
            collate_fn=self.collate_fn,
        )

    @classmethod
    def add_cli(cls, parser):
        """Register the ``--batch_size`` and ``--workers`` options on ``parser``."""
        parser.add_argument("--batch_size", type=int, default=128)
        parser.add_argument(
            "--workers",
            type=int,
            default=6,
            help="number of parallel workers for pytorch dataloader",
        )

    def add_dl(self, test_data, batch_size=None, **kwargs):
        """Return a ``DataLoader`` for ``test_data``.

        Args:
            test_data: A ``DataLoader`` (returned unchanged), a ``Dataset``
                (wrapped in a new loader), or raw data. Raw data is first
                converted with ``self.train.dataset.new(test_data)``, so that
                path requires a non-empty train split whose dataset class
                provides a ``new`` method.
            batch_size: Batch size for the new loader; defaults to
                ``self.batch_size``.
            **kwargs: Extra ``DataLoader`` keyword arguments. They override
                the ``num_workers`` / ``collate_fn`` defaults taken from this
                object.

        Returns:
            A ``torch.utils.data.DataLoader``.
        """
        # Check if test_data is already a DataLoader.
        if isinstance(test_data, DataLoader):
            return test_data

        # Fall back to the configured batch size.
        if batch_size is None:
            batch_size = self.batch_size

        # Wrap raw data in a Dataset via the train dataset's own factory.
        if not isinstance(test_data, Dataset):
            test_data = self.train.dataset.new(test_data)

        # torch's DataLoader has no ``new`` method, so build the loader
        # explicitly and reuse this object's worker / collate settings.
        loader_kwargs = {"num_workers": self.workers, "collate_fn": self.collate_fn}
        loader_kwargs.update(kwargs)
        return DataLoader(test_data, batch_size=batch_size, **loader_kwargs)


##### ----------------------------------------------------------------------------------------- ######


def get_dls(params):
    """Build the ``DataLoaders`` bundle for the dataset named by ``params.dset``.

    Supported names: ``etth1``, ``etth2``, ``ettm1``, ``ettm2``, ``weather``
    and ``electricity``.

    Args:
        params: Namespace-like object. Reads ``dset``, ``data_mount_path``,
            ``context_points``, ``target_points`` and ``features``.
            ``scale`` is read only for ``"etth1"``; every other dataset is
            built with ``scale=True``. Optional ``batch_size`` (default 8)
            and ``num_workers`` (default 1) are used when present and not
            ``None``.

    Returns:
        A ``DataLoaders`` instance with three extra attributes set from the
        first train sample: ``vars`` (``shape[1]`` of the sample's input),
        ``len`` (``params.context_points``) and ``c``.

    Raises:
        ValueError: If ``params.dset`` is not a supported name, or the train
            split is empty so no sample is available to inspect.
    """
    batch_size = getattr(params, "batch_size", None)
    if batch_size is None:
        batch_size = 8

    # NOTE(review): DataLoaders.add_cli registers ``--workers`` while this
    # reads ``num_workers`` -- confirm which attribute callers actually set.
    workers = getattr(params, "num_workers", None)
    if workers is None:
        workers = 1

    # dataset name -> (sub-directory of data_mount_path, csv file, freq or None)
    dataset_files = {
        "etth1": ("ETT-small", "ETTh1.csv", None),
        "etth2": ("ETT-small", "ETTh2.csv", None),
        "ettm1": ("ETT-small", "ETTm1.csv", None),
        "ettm2": ("ETT-small", "ETTm2.csv", None),
        "weather": ("weather", "weather.csv", "10_minutes"),
        "electricity": ("electricity", "electricity.csv", "hourly"),
    }
    if params.dset not in dataset_files:
        raise ValueError(
            f"Unknown dataset {params.dset!r}; expected one of {sorted(dataset_files)}"
        )
    folder, data_path, freq = dataset_files[params.dset]

    # Resolve the dataset class lazily so only the one that is needed is looked up.
    if params.dset in ("etth1", "etth2"):
        dataset_cls = Dataset_ETT_hour
    elif params.dset in ("ettm1", "ettm2"):
        dataset_cls = Dataset_ETT_minute
    else:
        dataset_cls = Dataset_Custom

    dataset_kwargs = {
        "root_path": os.path.join(params.data_mount_path, folder),
        "data_path": data_path,
        "features": params.features,
        # NOTE(review): only etth1 honours params.scale; the other datasets
        # have always been built with scale=True -- confirm this is intended.
        "scale": params.scale if params.dset == "etth1" else True,
        "size": [params.context_points, 0, params.target_points],
    }
    if freq is not None:
        dataset_kwargs["freq"] = freq

    dls = DataLoaders(
        datasetCls=dataset_cls,
        dataset_kwargs=dataset_kwargs,
        batch_size=batch_size,
        workers=workers,
    )

    if dls.train is None:
        raise ValueError(
            f"Train split of dataset {params.dset!r} is empty; cannot infer dimensions"
        )

    # dataset is assumed to have dimension len x nvars
    ret_obj = dls.train.dataset[0]
    if isinstance(ret_obj, dict):
        first_obj = ret_obj["past_values"]
    else:
        first_obj = ret_obj[0]

    if isinstance(first_obj, dict):
        first_obj = first_obj["context"]

    dls.vars, dls.len = first_obj.shape[1], params.context_points
    # NOTE(review): this indexes row 1 of the input sample, so it needs at
    # least two rows; an earlier version used dls.train.dataset[0][1].shape[0]
    # instead -- confirm which quantity ``c`` is meant to hold.
    dls.c = first_obj[1].shape[0]

    return dls
