from typing import Optional

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler


class ForecastingDataset:
    """Sliding-window dataset over a standardized time-series CSV.

    The CSV is read once at construction time. Its ``date`` column is
    dropped, missing values are interpolated, every remaining column is
    standardized with statistics fitted on the training rows only, and the
    rows belonging to the requested split are kept in ``self.data`` as an
    array of shape ``(n_timesteps, n_channels)``.

    Indexing returns fixed-length windows of ``seq_len`` (512) time steps,
    transposed to ``(n_channels, length)``.
    """

    def __init__(
        self,
        file_path: str = "../data/ETTh1.csv",
        forecast_horizon: Optional[int] = 96,
        data_split: str = "train",
        data_stride_len: int = 1,
        task_name: str = "forecasting",
        random_seed: int = 42,
        only_OT: bool = False,
    ):
        """
        Parameters
        ----------
        file_path : str
            Path of the CSV file to load. It must contain a 'date' column
            (and an 'OT' column when ``only_OT`` is True).
        forecast_horizon : int
            Length of the prediction sequence.
        data_split : str
            Split of the dataset, 'train', 'val', or 'test'.
        data_stride_len : int
            Stride length when generating consecutive
            time series windows.
        task_name : str
            The task that the dataset is used for. One of
            'forecasting', or  'imputation'.
        random_seed : int
            Random seed for reproducibility. Stored on the instance; it is
            not used anywhere else in this class.
        only_OT : bool
            If True, keep only the 'OT' column of the CSV.
        """

        self.seq_len = 512
        self.forecast_horizon = forecast_horizon
        self.full_file_path_and_name = file_path
        assert data_split in ['train', 'val', 'test']
        self.data_split = data_split
        self.data_stride_len = data_stride_len
        self.task_name = task_name
        self.random_seed = random_seed
        self.only_OT = only_OT

        # Read data
        self._read_data()

    def _get_borders(
                self,
                totoal_length: int,
                train_ratio: float = 0.6,
                val_ratio: float = 0.1,
                test_ratio: float = 0.3):
        """Return the row slices ``(train, val, test)`` of a chronological split.

        Parameters
        ----------
        totoal_length : int
            Total number of rows in the series. (The spelling is kept so
            existing keyword callers keep working.)
        train_ratio, val_ratio, test_ratio : float
            Fractions of the series per split; they must sum to 1.

        Returns
        -------
        tuple of slice
            The val and test slices start ``seq_len`` rows before their
            split boundary (clamped at row 0), so they overlap the tail of
            the preceding split by one input window.
        """
        # Tolerant comparison: exact float equality rejects valid inputs
        # such as (0.7, 0.1, 0.2), which sums to 0.9999999999999999.
        assert np.isclose(train_ratio + val_ratio + test_ratio, 1.0), \
            "The sum of the ratios must be 1."

        n_train = int(totoal_length * train_ratio)
        n_test = int(totoal_length * test_ratio)
        # Validation takes whatever rows remain after truncation.
        n_val = totoal_length - n_train - n_test

        # A negative slice start would wrap around from the end of the
        # array, so clamp to 0 when a split is shorter than one window.
        val_start = max(0, n_train - self.seq_len)
        test_start = max(0, n_train + n_val - self.seq_len)

        train = slice(0, n_train)
        val = slice(val_start, n_train + n_val)
        test = slice(test_start, totoal_length)

        return train, val, test

    def _read_data(self):
        """Load the CSV, standardize it and keep the rows of ``data_split``.

        Sets ``scaler``, ``length_timeseries_original``, ``data``,
        ``length_timeseries`` and ``n_channels`` on the instance.

        Raises
        ------
        ValueError
            If ``data_split`` is not 'train', 'val' or 'test'.
        """
        self.scaler = StandardScaler()
        df = pd.read_csv(self.full_file_path_and_name)
        self.length_timeseries_original = df.shape[0]

        df.drop(columns=["date"], inplace=True)
        df = df.infer_objects(copy=False).interpolate(method="cubic")

        if self.only_OT:
            df = df[['OT']]

        train_rows, val_rows, test_rows = self._get_borders(totoal_length=len(df))

        # Fit the scaler on the training rows only, then apply it to the
        # whole series so val/test are scaled with training statistics.
        self.scaler.fit(df.iloc[train_rows].values)
        scaled = self.scaler.transform(df.values)

        if self.data_split == "train":
            self.data = scaled[train_rows, :]
        elif self.data_split == "val":
            self.data = scaled[val_rows, :]
        elif self.data_split == "test":
            self.data = scaled[test_rows, :]
        else:
            # Still reachable when the constructor's assert is stripped
            # (python -O) or data_split was changed after construction.
            raise ValueError(f"{self.data_split} is not a valid data split. Choose from 'train', 'val', or 'test'.")

        self.length_timeseries = self.data.shape[0]
        self.n_channels = self.data.shape[1]

    def __getitem__(self, index):
        """Return the window number ``index`` for the configured task.

        Returns
        -------
        tuple
            ``(timeseries, forecast, input_mask)`` for 'forecasting' and
            ``(timeseries, input_mask)`` for 'imputation'. ``timeseries``
            has shape ``(n_channels, seq_len)``, ``forecast`` has shape
            ``(n_channels, forecast_horizon)`` and ``input_mask`` is an
            all-ones vector of length ``seq_len``.

        Raises
        ------
        IndexError
            If the window would start before the first row (negative
            index, or series shorter than one full window).
        ValueError
            If ``task_name`` is neither 'forecasting' nor 'imputation'.
        """
        seq_start = self.data_stride_len * index
        seq_end = seq_start + self.seq_len
        input_mask = np.ones(self.seq_len)

        if self.task_name == "forecasting":
            pred_end = seq_end + self.forecast_horizon

            if pred_end > self.length_timeseries:
                # Shift the whole window back so that it ends on the last
                # row. The input must end exactly forecast_horizon rows
                # before the end, otherwise the forecast has the wrong
                # length.
                pred_end = self.length_timeseries
                seq_end = pred_end - self.forecast_horizon
                seq_start = seq_end - self.seq_len

            if seq_start < 0:
                raise IndexError(f"index {index} is out of range for this dataset")

            timeseries = self.data[seq_start:seq_end, :].T
            forecast = self.data[seq_end:pred_end, :].T

            return timeseries, forecast, input_mask

        elif self.task_name == "imputation":
            if seq_end > self.length_timeseries:
                # Shift the window back so that it ends on the last row.
                seq_end = self.length_timeseries
                seq_start = seq_end - self.seq_len

            if seq_start < 0:
                raise IndexError(f"index {index} is out of range for this dataset")

            timeseries = self.data[seq_start:seq_end, :].T

            return timeseries, input_mask

        raise ValueError(
            f"{self.task_name} is not a valid task name. Choose from 'forecasting' or 'imputation'."
        )

    def __len__(self):
        """Return the number of complete windows available for the task.

        Returns 0 when the series is shorter than a single window.

        Raises
        ------
        ValueError
            If ``task_name`` is neither 'forecasting' nor 'imputation'.
        """
        if self.task_name == "imputation":
            window_len = self.seq_len
        elif self.task_name == "forecasting":
            window_len = self.seq_len + self.forecast_horizon
        else:
            raise ValueError(
                f"{self.task_name} is not a valid task name. Choose from 'forecasting' or 'imputation'."
            )

        # len() rejects negative values, so report an empty dataset instead.
        if self.length_timeseries < window_len:
            return 0
        return (self.length_timeseries - window_len) // self.data_stride_len + 1
