from __future__ import annotations

from dataclasses import asdict, dataclass, fields
from pathlib import Path
from typing import List

import numpy as np
import numpy.typing as npt
import pandas as pd
from gluonts.model.forecast import QuantileForecast
from gluonts.time_feature import get_seasonality


@dataclass
class QuantileForecasts:
    """
    A type-safe wrapper for a list of quantile forecasts, stored as NumPy arrays.

    The per-series arrays (``values``, ``start_dates``, ``item_ids``) are all indexed along their
    first axis by the same time series index.
    """

    values: np.ndarray  # [num_time_series, num_quantiles, prediction_length]
    start_dates: np.ndarray  # one start date per time series (converted to pd.Timestamp in `get`)
    item_ids: np.ndarray  # one identifier per time series
    freq: pd.DateOffset
    quantiles: List[str]  # quantile levels as strings (e.g. "0.5"), aligned with axis 1 of values

    @property
    def prediction_length(self) -> int:
        """
        Returns the prediction length of the quantile forecasts.
        """
        return self.values.shape[-1]

    @property
    def seasonality(self) -> int:
        """
        Returns the seasonality of the forecasts (i.e. how many steps to go back to arrive at the
        value of the previous period).
        """
        return get_seasonality(self.freq.freqstr)

    # ---------------------------------------------------------------------------------------------
    # DATA ACCESS

    def get(self, index: int) -> QuantileForecast:
        """
        Returns the quantile forecast at the specified index. This method should typically only be
        used for visualizing single forecasts.
        """
        return QuantileForecast(
            forecast_arrays=self.values[index],
            start_date=self._to_timestamp(self.start_dates[index], self.freq),
            freq=self.freq.freqstr,
            item_id=self.item_ids[index],
            forecast_keys=self.quantiles,
        )

    @staticmethod
    def _to_timestamp(value: object, freq: pd.DateOffset) -> pd.Timestamp:
        """
        Converts a start date into a timestamp, attaching the frequency where pandas supports it.

        The ``freq`` argument of ``pd.Timestamp`` was deprecated in pandas 1.x and removed in
        pandas 2.0, where passing it raises ``TypeError``; in that case a plain timestamp is
        returned instead. A genuinely invalid ``value`` still raises from the fallback call.
        """
        try:
            return pd.Timestamp(value, freq=freq)
        except TypeError:
            return pd.Timestamp(value)

    @property
    def median(self) -> np.ndarray:
        """
        Returns the median forecasts for all time series. NumPy array of shape [N, T] (N: number
        of forecasts, T: forecast horizon).

        Raises:
            ValueError: If the "0.5" quantile is not among the stored quantiles.
        """
        try:
            i = self.quantiles.index("0.5")
        except ValueError as err:
            raise ValueError(
                f"median requires the '0.5' quantile, but only {self.quantiles} are available"
            ) from err
        return self.values[:, i]

    def __len__(self) -> int:
        return self.values.shape[0]

    def __getitem__(self, index: npt.ArrayLike) -> QuantileForecasts:
        """
        Returns a new set of forecasts restricted to the selected time series.

        The index is applied to the first axis of every per-series array. It should preserve
        that axis (slice, integer array, boolean mask): a scalar integer drops the axis and
        yields an inconsistent object. Use `get` to obtain a single forecast instead.
        """
        return QuantileForecasts(
            values=self.values[index],
            start_dates=self.start_dates[index],
            item_ids=self.item_ids[index],
            freq=self.freq,
            quantiles=self.quantiles,
        )

    # ---------------------------------------------------------------------------------------------
    # STORAGE

    @classmethod
    def load(cls, path: Path) -> QuantileForecasts:
        """
        Loads the quantile forecasts from the specified path.

        Security: the file is unpickled (``allow_pickle=True``), which can execute arbitrary code.
        Only load files written by `save` from a trusted source.

        Args:
            path: The path from where to load the forecasts.
        """
        with path.open("rb") as f:
            content = np.load(f, allow_pickle=True).item()
        # Use `cls` so that subclasses round-trip to their own type.
        return cls(**content)

    def save(self, path: Path) -> None:
        """
        Saves the forecasts to the specified path as a pickled dictionary of the dataclass fields.

        Args:
            path: The path of the file where to save the forecasts to.
        """
        # `asdict` would deep-copy every field (including the potentially large arrays) before
        # serialization; a shallow field mapping yields the same saved content without the copy.
        content = {fld.name: getattr(self, fld.name) for fld in fields(self)}
        with path.open("wb") as f:
            np.save(f, content)
