
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generator, Optional, Tuple

import pandas as pd
import torch.utils.data

from gluonts.dataset import Dataset, DataEntry
from gluonts.dataset.field_names import FieldName
from tqdm import tqdm


def periods_between(
    start: pd.Period,
    end: pd.Period,
) -> int:
    """
    Return the number of periods from ``start`` through ``end``, with both
    endpoints counted. The step size is the frequency of ``start``,
    including its multiple (e.g. ``2`` for ``"2H"``).

    If ``start`` lies after ``end`` the result is ``0``.

    For example:

        >>> start = pd.Period("2021-01-01 00", freq="2H")
        >>> end = pd.Period("2021-01-01 11", "2H")
        >>> periods_between(start, end)
        6

        >>> start = pd.Period("2021-03-03 23:00", freq="30T")
        >>> end = pd.Period("2021-03-04 03:29", freq="30T")
        >>> periods_between(start, end)
        9
    """
    if start > end:
        return 0

    # The difference of two periods is expressed in units of the base
    # frequency, so it has to be divided by the frequency multiple.
    base_units = (end - start).n
    multiple = start.freq.n
    full_steps = base_units // multiple
    return full_steps + 1


def to_positive_slice(slice_: slice, length: int) -> slice:
    """
    Rewrite negative ``start``/``stop`` bounds of ``slice_`` as offsets from
    the beginning of a sequence of the given ``length``.

    ``None`` and non-negative bounds are returned untouched, and the step is
    carried over as is. A negative bound that would still be negative after
    adding ``length`` trips an assertion.
    """

    def _from_front(bound):
        # Only negative bounds count from the end and need shifting.
        if bound is not None and bound < 0:
            bound += length
            assert bound >= 0
        return bound

    return slice(
        _from_front(slice_.start),
        _from_front(slice_.stop),
        slice_.step,
    )


def to_integer_slice(slice_: slice, start: pd.Period) -> slice:
    """
    Return an equivalent slice with integer bounds, given the start period
    of the sequence it will apply to.

    Bounds that are ``None`` or ``int`` are passed through unchanged. A
    ``pd.Period`` bound is converted to the number of steps separating it
    from ``start``, where one step is the frequency of ``start`` including
    its multiple (e.g. two hours for ``"2H"``). A ``pd.Period`` used as
    ``stop`` is treated as inclusive, i.e. the period itself is part of the
    slice. The step of ``slice_`` is preserved.

    Parameters
    ----------
    slice_
        Slice whose bounds are ``None``, ``int`` or ``pd.Period``.
    start
        Period of the first element of the sequence being sliced.

    Raises
    ------
    ValueError
        If a bound is neither ``None``, ``int`` nor ``pd.Period``.
    AssertionError
        If a ``pd.Period`` bound resolves to a negative offset.
    """
    plain_types = (int, type(None))

    if isinstance(slice_.start, plain_types) and isinstance(
        slice_.stop, plain_types
    ):
        return slice_

    def _as_offset(bound, inclusive):
        if isinstance(bound, pd.Period):
            # Period differences are expressed in units of the base
            # frequency; divide by the multiple to count actual steps.
            steps = (bound - start).n // start.freq.n
            if inclusive:
                steps += 1
            assert steps >= 0
            return steps
        if isinstance(bound, plain_types):
            return bound
        raise ValueError(
            "Can only use None, int, or pd.Period for slicing, got type "
            f"{type(bound)}"
        )

    return slice(
        _as_offset(slice_.start, inclusive=False),
        _as_offset(slice_.stop, inclusive=True),
        slice_.step,
    )


def slice_data_entry(
    entry: DataEntry, slice_: slice, prediction_length: int = 0
) -> DataEntry:
    """
    Return a shallow copy of ``entry`` restricted to ``slice_`` along the
    last (time) axis.

    ``slice_`` may use ``None``, ``int`` or ``pd.Period`` bounds; it is first
    normalized to non-negative integer bounds. The target and the past
    dynamic features are cut with that slice. The dynamic real/categorical
    features are cut with a slice whose stop is pushed ``prediction_length``
    steps further, unless the stop is open-ended. The start field is moved
    forward by the slice start.

    Parameters
    ----------
    entry
        Data entry to slice; it is not modified.
    slice_
        Range of time steps to keep.
    prediction_length
        Extra steps kept at the end of the dynamic features.
    """
    integer_slice = to_integer_slice(slice_, entry[FieldName.START])
    slice_ = to_positive_slice(
        integer_slice, entry[FieldName.TARGET].shape[-1]
    )

    if slice_.stop is None:
        slice_extended = slice_
    else:
        slice_extended = slice(
            slice_.start, slice_.stop + prediction_length, slice_.step
        )

    sliced_entry = dict(entry)

    if slice_.start is not None:
        shift = slice_.start
        if shift < 0:
            shift += entry["target"].shape[-1]
        sliced_entry[FieldName.START] += shift

    target = sliced_entry[FieldName.TARGET]
    if len(target.shape) == 1:
        sliced_entry[FieldName.TARGET] = target[slice_]
    else:
        sliced_entry[FieldName.TARGET] = target[:, slice_]

    # Each optional feature field is cut along its last axis with the slice
    # paired with it here.
    windows_by_field = (
        (FieldName.FEAT_DYNAMIC_REAL, slice_extended),
        (FieldName.FEAT_DYNAMIC_CAT, slice_extended),
        (FieldName.PAST_FEAT_DYNAMIC_REAL, slice_),
    )
    for field_name, window in windows_by_field:
        if field_name in sliced_entry:
            sliced_entry[field_name] = sliced_entry[field_name][:, window]

    return sliced_entry


class AbstractBaseSplitter(ABC):
    """
    Base class for all splitters.

    Subclasses implement ``training_entry`` and ``test_pair``; the remaining
    methods build on those two to split whole datasets.
    """

    @abstractmethod
    def training_entry(self, entry: DataEntry) -> DataEntry:
        """Return the training portion of ``entry``."""
        pass

    @abstractmethod
    def test_pair(
        self, entry: DataEntry, prediction_length: int, offset: int = 0
    ) -> Tuple[DataEntry, DataEntry]:
        """
        Return the ``(input, label)`` pair cut from ``entry`` for a window of
        ``prediction_length`` steps, shifted by ``offset`` steps.
        """
        pass

    def split(
        self, dataset: Dataset
    ) -> Tuple["TrainingDataset", "TestTemplate"]:
        """
        Wrap ``dataset`` into a training dataset and a test template that
        both use this splitter.
        """
        return (
            TrainingDataset(dataset=dataset, splitter=self),
            TestTemplate(dataset=dataset, splitter=self),
        )

    def generate_training_entries(
        self, dataset: Dataset
    ) -> Generator[DataEntry, None, None]:
        """Yield the training portion of every entry in ``dataset``."""
        yield from map(self.training_entry, dataset)

    def generate_test_pairs(
        self,
        dataset: Dataset,
        prediction_length: int,
        windows: int = 1,
        distance: Optional[int] = None,
        max_history: Optional[int] = None,
    ) -> Generator[Tuple[DataEntry, DataEntry], None, None]:
        """
        Yield ``(input, label)`` pairs for every entry in ``dataset``.

        Parameters
        ----------
        dataset
            Entries to cut test windows from.
        prediction_length
            Length of the prediction interval of each window.
        windows
            Number of windows generated per entry.
        distance
            Number of steps between the starts of consecutive windows;
            defaults to ``prediction_length``.
        max_history
            If given, inputs longer than this are truncated to their last
            ``max_history`` steps. Inputs that are already shorter are
            yielded unchanged.
        """
        if distance is None:
            distance = prediction_length

        for entry in dataset:
            for window in range(windows):
                test = self.test_pair(
                    entry,
                    prediction_length=prediction_length,
                    offset=window * distance,
                )
                test_input, test_label = test[0], test[1]

                # Only truncate when there is enough history: a negative
                # slice start reaching before the beginning of the series
                # is rejected by the slicing helpers.
                if (
                    max_history is not None
                    and test_input[FieldName.TARGET].shape[-1] >= max_history
                ):
                    test_input = slice_data_entry(
                        test_input, slice(-max_history, None)
                    )

                yield test_input, test_label


@dataclass
class OffsetSplitter(AbstractBaseSplitter):
    """
    Splitter that cuts every series at a fixed integer position.

    Parameters
    ----------
    offset
        Position at which the training data ends.
        A positive value is the number of observations, counted from the
        start of each series, that go into the training slice; a negative
        value is the number of observations at the end of each series that
        are left out of the training slice.
    """

    offset: int

    def training_entry(self, entry: DataEntry) -> DataEntry:
        """Return ``entry`` cut off at the configured offset."""
        training_slice = slice(None, self.offset)
        return slice_data_entry(entry, training_slice)

    def test_pair(
        self, entry: DataEntry, prediction_length: int, offset: int = 0
    ) -> Tuple[DataEntry, DataEntry]:
        """
        Return the ``(input, label)`` pair for the window that starts
        ``offset`` steps after the configured split position and spans
        ``prediction_length`` steps.
        """
        split_at = self.offset + offset
        if self.offset < 0:
            # A negative split position counts from the end of the series.
            split_at += entry[FieldName.TARGET].shape[-1]
        label_end = split_at + prediction_length
        assert (
            label_end <= entry[FieldName.TARGET].shape[-1]
        ), "Not enough data to generate some of the windows; try splitting data at an earlier offset"

        input_slice = slice(None, split_at)
        # A label end of zero is turned into an open-ended slice.
        label_slice = slice(split_at, label_end if label_end else None)

        input_entry = slice_data_entry(
            entry, input_slice, prediction_length=prediction_length
        )
        label_entry = slice_data_entry(
            entry, label_slice, prediction_length=prediction_length
        )
        return input_entry, label_entry


@dataclass
class DateSplitter(AbstractBaseSplitter):
    """
    Splitter that cuts every series at a given ``pandas.Period``.

    Training entries produced by this class contain the observations up to
    and including ``date``.

    Parameters
    ----------
    date
        ``pandas.Period`` at which the training data ends.
    """

    date: pd.Period

    def training_entry(self, entry: DataEntry) -> DataEntry:
        """Return ``entry`` restricted to the periods up to ``date``."""
        n_training = periods_between(entry["start"], self.date)
        return slice_data_entry(entry, slice(None, n_training))

    def test_pair(
        self, entry: DataEntry, prediction_length: int, offset: int = 0
    ) -> Tuple[DataEntry, DataEntry]:
        """
        Return the ``(input, label)`` pair for the window that starts
        ``offset`` steps after ``date`` and spans ``prediction_length``
        steps.
        """
        window_start = periods_between(entry["start"], self.date) + offset
        window_stop = window_start + prediction_length
        assert (
            window_stop <= entry[FieldName.TARGET].shape[-1]
        ), "Not enough data to generate some of the windows; try splitting data at an earlier date"

        input_entry = slice_data_entry(
            entry,
            slice(None, window_start),
            prediction_length=prediction_length,
        )
        label_entry = slice_data_entry(
            entry,
            slice(window_start, window_stop),
            prediction_length=prediction_length,
        )
        return input_entry, label_entry


@dataclass
class TestData:
    """
    Iterable wrapper around test data.

    Iterating a ``TestData`` object yields pairs ``(input, label)``:
    ``input`` is what a model receives, ``label`` is the future ground truth
    the model is expected to predict.

    Parameters
    ----------
    dataset:
        Whole dataset used for testing.
    splitter:
        Splitter that knows how to slice training and test data.
    prediction_length
        Length of the prediction interval in test data.
    windows
        Number of test windows to generate for each original dataset entry.
    distance
        Number of steps between the starts of consecutive test windows
        generated for the same dataset entry.
    max_history
        If given, all entries in the *test*-set have a max-length of
        `max_history`. This can be used to produce smaller file-sizes.
    """

    dataset: Dataset
    splitter: AbstractBaseSplitter
    prediction_length: int
    windows: int = 1
    distance: Optional[int] = None
    max_history: Optional[int] = None

    def __iter__(self) -> Generator[Tuple[DataEntry, DataEntry], None, None]:
        # Pair generation is delegated entirely to the splitter.
        pairs = self.splitter.generate_test_pairs(
            dataset=self.dataset,
            prediction_length=self.prediction_length,
            windows=self.windows,
            distance=self.distance,
            max_history=self.max_history,
        )
        yield from pairs

    def __len__(self):
        # One pair per window for every dataset entry.
        n_entries = len(self.dataset)
        return n_entries * self.windows

    @property
    def input(self) -> "InputDataset":
        """View that iterates over the input part of each pair."""
        return InputDataset(self)

    @property
    def label(self) -> "LabelDataset":
        """View that iterates over the label part of each pair."""
        return LabelDataset(self)


@dataclass
class InputDataset:
    """Iterable view over the input half of each pair in a ``TestData``."""

    test_data: TestData

    def __len__(self):
        return len(self.test_data)

    def __iter__(self):
        for model_input, _ in self.test_data:
            yield model_input


@dataclass
class LabelDataset:
    """Iterable view over the label half of each pair in a ``TestData``."""

    test_data: TestData

    def __len__(self):
        return len(self.test_data)

    def __iter__(self):
        for _, ground_truth in self.test_data:
            yield ground_truth


@dataclass
class TestTemplate:
    """
    Factory for test data built from a dataset and a splitter.

    Parameters
    ----------
    dataset:
        Whole dataset used for testing.
    splitter:
        Splitter that knows how to slice training and test data.
    """

    dataset: Dataset
    splitter: AbstractBaseSplitter

    def generate_instances(
        self,
        prediction_length: int,
        windows: int = 1,
        distance: Optional[int] = None,
        max_history: Optional[int] = None,
    ) -> TestData:
        """
        Build a ``TestData`` object over this template's dataset and
        splitter; iterating it yields the input and label parts.

        Parameters
        ----------
        prediction_length
            Length of the prediction interval in test data.
        windows
            Number of test windows to generate for each original dataset
            entry.
        distance
            Number of steps between the starts of consecutive test windows
            generated for the same dataset entry.
        max_history
            If given, all entries in the *test*-set have a max-length of
            `max_history`. This can be used to produce smaller file-sizes.
        """
        return TestData(
            dataset=self.dataset,
            splitter=self.splitter,
            prediction_length=prediction_length,
            windows=windows,
            distance=distance,
            max_history=max_history,
        )


class TrainingDataset(torch.utils.data.Dataset):
    """
    Indexable dataset holding the training portion of every entry.

    All entries of ``dataset`` are passed through
    ``splitter.training_entry`` once, at construction time, and the results
    are kept in ``splitted_data``. Indexing and ``len`` both operate on that
    materialized list.

    Parameters
    ----------
    dataset
        Dataset whose entries are split.
    splitter
        Splitter used to extract the training portion of each entry.
    """

    dataset: Dataset
    splitter: AbstractBaseSplitter

    def __init__(self, dataset: Dataset, splitter: AbstractBaseSplitter):
        super().__init__()
        self.dataset = dataset
        self.splitter = splitter
        self.splitted_data = [
            self.splitter.training_entry(entry)
            for entry in tqdm(self.dataset, desc="Splitting Dataset")
        ]

    def __getitem__(self, index):
        return self.splitted_data[index]

    def __len__(self) -> int:
        # Report the size of what __getitem__ actually indexes; this also
        # works when the wrapped dataset does not implement __len__.
        return len(self.splitted_data)


def split(
    dataset: Dataset, *, offset: Optional[int] = None, date: pd.Period = None
) -> Tuple[TrainingDataset, TestTemplate]:
    """
    Split ``dataset`` into a training dataset and a test template.

    Exactly one of ``offset`` and ``date`` has to be given: ``offset``
    selects an ``OffsetSplitter``, ``date`` selects a ``DateSplitter``.
    """
    offset_given = offset is not None
    date_given = date is not None
    assert (
        offset_given != date_given
    ), "You need to provide ``offset`` or ``date``, but not both."

    splitter = OffsetSplitter(offset) if offset_given else DateSplitter(date)
    return splitter.split(dataset)
