from typing import List, Optional

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from data.preprocessing import Pipeline


class MyDataset(Dataset):
    """Sliding-window dataset over a multivariate time series stored in a CSV file.

    The CSV is read with its first column as the index, and the columns named in
    ``feature_names`` are extracted (in that order) into a 2-D array of shape
    ``[time_series_len, n_features]``. The array is optionally passed through
    ``pipeline.preprocess``. Item ``i`` is the window of ``seq_len`` consecutive rows
    (along the first axis) starting at row ``i * stride``. Only complete windows are
    exposed.
    """

    def __init__(
        self,
        file_path_data: str,
        feature_names: List[str],
        seq_len: int,
        stride: int,
        pipeline: Optional["Pipeline"] = None,
    ) -> None:
        """Load the CSV, select features and optionally preprocess them.

        Args:
            file_path_data: Path to a CSV file whose first column is used as the index.
            feature_names: Columns to extract, in the order they should appear.
            seq_len: Number of consecutive rows per window; must be positive.
            stride: Step (in rows) between the starts of consecutive windows; must be positive.
            pipeline: Optional object whose ``preprocess`` method is applied to the raw
                feature array. ``None`` keeps the raw data.

        Raises:
            ValueError: If ``seq_len`` or ``stride`` is not positive.
        """
        super().__init__()
        # Validate before touching the file so bad configs fail fast and clearly,
        # instead of as a ZeroDivisionError / empty windows much later.
        if seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {seq_len}")
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")

        self.feature_names = feature_names
        self.seq_len = seq_len
        self.stride = stride

        df: pd.DataFrame = pd.read_csv(file_path_data, index_col=0)
        # Shape: [time_series_len, n_features], columns ordered as in feature_names.
        self.data_unprocessed: np.ndarray = df[feature_names].to_numpy()

        # Explicit `is not None`: a pipeline object defining __len__/__bool__ must
        # never be skipped just because it evaluates falsy.
        self.data: np.ndarray = (
            pipeline.preprocess(self.data_unprocessed) if pipeline is not None else self.data_unprocessed
        )
        # NOTE(review): windows are cut along axis 0 of self.data; presumably the
        # pipeline preserves [time_series_len, n_features] — confirm for each Pipeline.

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return the ``index``-th window as ``{"x_0": tensor}``.

        The tensor is ``float32`` and holds rows
        ``[index * stride, index * stride + seq_len)`` of the (preprocessed) data.
        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``. This is
                also what lets plain ``for x in dataset`` iteration terminate.
        """
        n_windows = len(self)
        if index < 0:
            index += n_windows
        if not 0 <= index < n_windows:
            raise IndexError(f"index out of range for dataset of length {n_windows}")
        start = self.stride * index
        x_0 = torch.as_tensor(self.data[start : start + self.seq_len], dtype=torch.float)
        return dict(x_0=x_0)

    def __len__(self) -> int:
        """Number of complete ``seq_len``-row windows spaced ``stride`` rows apart."""
        # Clamp at 0: a series shorter than one window yields an empty dataset
        # rather than a negative length (which len() rejects with ValueError).
        return max(0, (len(self.data) - self.seq_len) // self.stride + 1)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(len={len(self)},"
            f"seq_len={self.seq_len},"
            f"stride={self.stride})"
        )
