import bisect
from typing import List, Tuple

import torch
from torch.utils.data import Dataset as torch_dataset


class TemporalTensorDataset(torch_dataset[Tuple[torch.Tensor, ...]]):
    """Customized TensorDataset for loading data in fold-time / autoregressive (AR) format.

    Every sample is a window of ``lag + 1`` consecutive rows taken from a single time-series.
    For example, let the data be [[0.1],[0.2],[0.3],[0.4],[0.5],[0.6],[0.7]], made of two
    time-series with index segmentation [(0,3),(4,6)], and let lag = 2.
    The fold-time/AR dataset then has total length 3: index 0 is [[0.1],[0.2],[0.3]], index 1 is
    [[0.2],[0.3],[0.4]] and index 2 (second time-series) is [[0.5],[0.6],[0.7]].

    The key function is __getitem__. Its strategy is to convert the fold-time/AR index back to the
    row index of the original data, e.g. index 2 (fold-time) maps back to index 4 (original data).
    """

    tensors: Tuple[torch.Tensor, ...]

    def __init__(
        self, *tensors: torch.Tensor, lag: int, is_autoregressive: bool, index_segmentation: List[Tuple[int, ...]]
    ) -> None:
        """
        Init method for the dataset.
        Args:
            *tensors: A tuple of tensors. All tensors should follow the temporal format and must share
                the same size along dimension 0.
            lag: user specified lags
            is_autoregressive: whether we use this dataset for Autoregressive temporal models.
            index_segmentation: the index segmentation of the original data, as inclusive
                (start, end) row indices per time-series.
        """
        self.lag = lag
        self.is_autoregressive = is_autoregressive
        self.index_segmentation = index_segmentation
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self._validate_dataset()
        self.lag_index_segmentation_end = self._build_new_segmentation_end_with_lag()

    def _validate_dataset(self) -> None:
        """
        This is to ensure that the minimum time-series length is larger than specified lag.
        Raises:
            ValueError: if the index segmentation is empty.
            AssertionError: if the shortest time-series is not longer than the lag.
        """
        if not self.index_segmentation:
            # min() would raise ValueError here anyway; keep the type but give a usable message.
            raise ValueError("index_segmentation must contain at least one (start, end) pair")
        min_length = min(idxs[1] - idxs[0] + 1 for idxs in self.index_segmentation)
        assert (
            min_length > self.lag
        ), f"Minimum series length ({min_length}) must be higher than specified lag ({self.lag})"

    def _build_new_segmentation_end_with_lag(self) -> List[int]:
        """
        This is to adapt the original index segmentation to generate index segmentation for the fold-time/AR dataset
        and only keep the end segmentation. The new returned segmentation is mainly used for mapping the fold-time/AR index
        back to the index of original data.
        For example, if the original index segmentation is [(0,3),(4,6)],
        and the lag is 2; then the new index segmentation for fold-time/AR dataset is [(0,1), (2,2)]. Since we only keep
        the end index, so the new_seg is [1,2].
        """
        # Each time-series (including the current one) contributes `lag` rows that cannot start a
        # window, so the end index shrinks by lag once per series seen so far.
        return [
            int(seg[1] - series_count * self.lag)
            for series_count, seg in enumerate(self.index_segmentation, start=1)
        ]

    def _find_series_number(self, index: int) -> int:
        """
        This is to search for how many time-series the fold-time/AR index has moved across.
        This is achieved by finding the first end index in the fold-time/AR format that is larger or equal to the
        input fold-time/AR index.
        E.g. if the end index segmentation for AR/Fold-time dataset is [1, 2], and the input fold-time/AR index is 1, then
        the fold-time/AR index has only moved across 0 time-series.
        Args:
            index: data index
        Returns:
            The number of time-series that end strictly before the given fold-time/AR index.
        """

        return bisect.bisect_left(self.lag_index_segmentation_end, index)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, ...]:
        """
        This is to get the data in fold-time/AR format with corresponding index. This is achieved by mapping the index (fold-time/AR) back to the
        original index. Then we can directly return the original data[index_orig:index_orig+lag+1,:].
        The intuition is that the differences between the fold-time/AR and original index happen when the fold-time/AR index
        moves across time-series. For example, if the original index segmentation is [(0,3),(4,6)], and the lag is 2;
        then fold-time/AR index 0 and 1 do not cross the first time-series, so fold-time/AR index = original index.
        However, if the fold-time/AR index is 2, then it moves across the first time-series, so to map it back to the original index,
        we need to account for the additional lag created during the crossing. So original index = fold-time/AR index + lag*number of crossed time-series.
        In the above example, the index 2 in fold-time/AR dataset maps back to index 4 in the original data.
        The number of crossed time-series is found by _find_series_number, which requires the new index segmentation generated by _build_new_segmentation_end_with_lag.
        Args:
            index: The index in fold-time/AR dataset. Negative values count from the end.
        Returns:
            One window per stored tensor: rows [index_orig, index_orig+lag] when is_autoregressive is True,
            otherwise the same window flattened to one dimension.
        Raises:
            IndexError: if the index is outside [-len(self), len(self) - 1].
        """
        num_samples = len(self)
        position = int(index)
        if position < 0:
            # Standard sequence semantics: -1 is the last sample.
            position += num_samples
        if not 0 <= position < num_samples:
            # Without this, slicing past the end silently yields empty tensors, and plain iteration
            # over the dataset (which stops only on IndexError) would never terminate.
            raise IndexError(f"Index {index} is out of range for dataset of length {num_samples}")

        index_orig = position + self._find_series_number(position) * self.lag
        windows = tuple(tensor[index_orig : index_orig + self.lag + 1, :] for tensor in self.tensors)
        if self.is_autoregressive:
            # Return the data with shape [lag+1, node]
            return windows
        # Return the data with shape [node*(lag+1)]
        return tuple(window.flatten() for window in windows)

    def __len__(self) -> int:
        """Return the number of fold-time/AR samples: total rows minus ``lag`` rows per time-series."""
        return self.tensors[0].shape[0] - len(self.index_segmentation) * self.lag