from __future__ import annotations

import os
from typing import Optional

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset


class SolarDataset(Dataset):
    """Sliding-window dataset over a comma-separated file of numeric rows.

    The file ``root_path/data_path`` is parsed into a 2-D float array (one
    row per line; the first line is skipped for ``.csv`` files). Rows are
    split in file order into 70% train, 20% test and the remainder as
    validation. The validation and test splits are extended backwards by
    ``seq_len`` rows so that their first window has a full input context.

    When ``scale`` is true the data is standardised with a
    ``StandardScaler`` fitted on the training rows only.

    Each item is a tuple ``(seq_x, seq_y, seq_x_mark, seq_y_mark)``:
    ``seq_x`` holds ``seq_len`` rows, ``seq_y`` holds the last
    ``label_len`` rows of ``seq_x`` followed by the next ``pred_len`` rows,
    and the two marks are all-zero float32 tensors with one column and as
    many rows as the matching sequence.

    Args:
        root_path: Directory containing the data file.
        flag: Which split to serve: ``'train'``, ``'val'`` or ``'test'``.
        size: ``[seq_len, label_len, pred_len]``; required. ``label_len``
            is expected not to exceed ``seq_len``, otherwise the start of
            ``seq_y`` falls before the start of the window.
        features: Accepted for interface compatibility; unused here.
        data_path: File name relative to ``root_path``.
        target: Accepted for interface compatibility; unused here.
        scale: Whether to standardise the data.
        timeenc: Accepted for interface compatibility; unused here.
        freq: Accepted for interface compatibility; unused here.
        seasonal_patterns: Accepted for interface compatibility; unused
            here.

    Raises:
        ValueError: If ``size`` is ``None`` or ``flag`` is not one of the
            three split names.
    """

    def __init__(
        self,
        root_path: str,
        flag: str = 'train',
        size: Optional[list[int]] = None,
        features: str = 'M',
        data_path: str = 'solar_AL.txt',
        target: str = 'OT',
        scale: bool = True,
        timeenc: int = 0,
        freq: str = 't',
        seasonal_patterns: Optional[str] = None,
    ) -> None:
        if size is None:
            raise ValueError(
                'size=[seq_len, label_len, pred_len] must be provided.')

        self.seq_len, self.label_len, self.pred_len = size

        # Explicit check instead of ``assert`` so that validation survives
        # ``python -O`` and matches the ValueError used for ``size``.
        type_map = {'train': 0, 'val': 1, 'test': 2}
        if flag not in type_map:
            raise ValueError(
                f"flag must be one of {sorted(type_map)}, got {flag!r}.")
        self.set_type = type_map[flag]

        self.scale = scale
        self.root_path = root_path
        self.data_path = data_path

        self._read_data()

    def _read_data(self) -> None:
        """Parse the data file and store the selected split.

        Sets ``self.data_x`` and ``self.data_y`` to the same slice of the
        (optionally standardised) array for the split chosen by
        ``self.set_type``.
        """
        file_path = os.path.join(self.root_path, self.data_path)
        rows = []
        with open(file_path, 'r', encoding='utf-8') as f:
            if file_path.endswith('.csv'):
                # Skip the header line; the default avoids StopIteration
                # on an empty file.
                next(f, None)
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no data.
                    continue
                rows.append([float(x) for x in line.split(',')])
        data = np.array(rows, dtype=np.float64)

        n_total = len(data)
        n_train = int(n_total * 0.7)
        n_test = int(n_total * 0.2)
        n_val = n_total - n_train - n_test

        # val/test start ``seq_len`` rows early so their first window has
        # a full input context. Clamp at 0: a negative slice start would
        # wrap around to the end of the array on very short files.
        borders_start = [
            0,
            max(0, n_train - self.seq_len),
            max(0, n_total - n_test - self.seq_len),
        ]
        borders_end = [n_train, n_train + n_val, n_total]

        start = borders_start[self.set_type]
        end = borders_end[self.set_type]

        if self.scale:
            # Fit on the training rows only so that no statistics from the
            # validation/test rows leak into the normalisation.
            scaler = StandardScaler()
            scaler.fit(data[borders_start[0]:borders_end[0]])
            data = scaler.transform(data)

        self.data_x = data[start:end]
        self.data_y = data[start:end]

    def __getitem__(self, index: int):
        """Return the window starting at ``index``.

        Negative indices count from the end, as for sequences.

        Returns:
            ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` as described on the
            class.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
                This also lets plain ``for sample in dataset`` terminate.
        """
        n_windows = len(self)
        if index < 0:
            index += n_windows
        if not 0 <= index < n_windows:
            raise IndexError('SolarDataset index out of range')

        s_begin = index
        s_end = s_begin + self.seq_len
        # seq_y overlaps the last ``label_len`` rows of seq_x and then
        # extends ``pred_len`` rows past it.
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        # All-zero placeholder marks, one column per row of each sequence.
        seq_x_mark = torch.zeros((seq_x.shape[0], 1), dtype=torch.float32)
        seq_y_mark = torch.zeros((seq_y.shape[0], 1), dtype=torch.float32)
        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self) -> int:
        """Return the number of complete windows (0 if the split is too short)."""
        # ``len()`` rejects negative values, so clamp when the split holds
        # fewer than ``seq_len + pred_len`` rows.
        return max(0, len(self.data_x) - self.seq_len - self.pred_len + 1)

    def inverse_transform(self, data):
        """Not supported: the scaler fitted while loading is not retained.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            'SolarDataset does not support inverse_transform.')
