import pandas as pd
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler
from utils.timefeatures import time_features
import warnings

# Silences ALL warnings, process-wide, as a side effect of importing this module.
# NOTE(review): this also hides deprecation warnings from pandas/torch; consider
# scoping it with warnings.catch_warnings() where it is actually needed.
warnings.filterwarnings('ignore')

class Dataset_Custom(Dataset):
    """Sliding-window forecasting dataset over one split of a tabular time series.

    ``raw_data`` must hold the timestamps as strings in a ``date`` column, and
    that column is expected to come first: the multivariate modes ('M' / 'MS')
    use every column after the first one as features. Each item is the tuple
    ``(seq_x, seq_y, seq_x_mark, seq_y_mark)``: the input window, the target
    window (last ``label_len`` input steps plus ``pred_len`` future steps), and
    the time marks of both windows.
    """

    # Maps the ``flag`` argument to the index into ``border1s`` / ``border2s``.
    _SPLIT_IDS = {'train': 0, 'val': 1, 'test': 2, 'pred': 3}
    _FEATURE_MODES = ('M', 'MS', 'S')
    _TIME_ENCODINGS = (0, 1)

    def __init__(self, raw_data: pd.DataFrame, border1s: list, border2s: list, flag='train',
                 size=None, features='S', target='OT', scale=True, timeenc=0, freq='h'):
        """
        Load one split of ``raw_data`` and prepare it for windowed sampling.

        :param raw_data:    full data frame (``date`` column first, then the value columns)
        :param border1s:    start row of each split, indexed by split id (train/val/test/pred)
        :param border2s:    end row (exclusive) of each split, indexed by split id
        :param flag:        which split to load: 'train', 'val', 'test' or 'pred'
        :param size:        [seq_len, label_len, pred_len]; defaults to [384, 96, 96]
        :param features:    'M' or 'MS' load every column after the first; 'S' loads only ``target``
        :param target:      column used as the single variable in the 'S' task
        :param scale:       whether to z-score normalise the values
        :param timeenc:     time-mark encoding: 0 = integer month/day/weekday/hour,
                            1 = ``time_features`` encoding
        :param freq:        time-feature granularity, passed to ``time_features`` when timeenc=1
        :raises ValueError: if ``flag``, ``features`` or ``timeenc`` is not a supported value
        """
        # info
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len, self.label_len, self.pred_len = size[0], size[1], size[2]

        # Validate up front so a bad argument fails here with a clear message
        # instead of as an UnboundLocalError deep inside __read_data__.
        if flag not in self._SPLIT_IDS:
            raise ValueError(f"flag must be one of {list(self._SPLIT_IDS)}, got {flag!r}")
        if features not in self._FEATURE_MODES:
            raise ValueError(f"features must be one of {list(self._FEATURE_MODES)}, got {features!r}")
        if timeenc not in self._TIME_ENCODINGS:
            raise ValueError(f"timeenc must be one of {list(self._TIME_ENCODINGS)}, got {timeenc!r}")

        self.set_type = self._SPLIT_IDS[flag]
        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.__read_data__(raw_data, border1s, border2s)

    def __read_data__(self, df_raw, border1s, border2s):
        """Slice this split out of ``df_raw``, optionally scale it, and build its time marks.

        Sets ``self.scaler``, ``self.data_x``, ``self.data_y`` and ``self.data_stamp``.

        :raises ValueError: if ``self.features`` / ``self.timeenc`` is unsupported
        """
        self.scaler = StandardScaler()
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features in ('M', 'MS'):
            # Column 0 holds the timestamps; every other column is a variable.
            df_data = df_raw[df_raw.columns[1:]]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            raise ValueError(f"unsupported features mode {self.features!r}")

        if self.scale:
            # Fit the z-score statistics on the training rows only and apply them
            # to the whole series: in practice the mean/std of future (val/test)
            # data are not available. The 'pred' split fits on all rows instead.
            if self.set_type == 3:
                self.scaler.fit(df_data.values)
            else:
                train_data = df_data[border1s[0]:border2s[0]]
                self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        # Timestamps are stored as strings in the 'date' column. Working on a
        # Series (not a sliced DataFrame) avoids assigning into a possible copy.
        dates = pd.to_datetime(df_raw['date'][border1:border2])
        self.data_stamp = self._encode_time(dates, self.timeenc, self.freq)

        # In 'pred' mode the input windows must stop pred_len rows before border2,
        # so that every target window still fits inside data_y.
        x_end = border2 - self.pred_len if self.set_type == 3 else border2
        self.data_x = data[border1:x_end]
        self.data_y = data[border1:border2]

    @staticmethod
    def _encode_time(dates, timeenc, freq):
        """Return the per-step time-mark array for a Series of datetimes.

        timeenc=0 yields an int64 array with columns [month, day, weekday, hour];
        timeenc=1 delegates to ``time_features`` and transposes its result
        (presumably from (n_features, n_steps) to (n_steps, n_features) — confirm).

        :raises ValueError: for any other ``timeenc``
        """
        if timeenc == 0:
            marks = pd.DataFrame({
                'month': dates.dt.month,
                'day': dates.dt.day,
                'weekday': dates.dt.weekday,
                'hour': dates.dt.hour,
            })
            # Cast explicitly: pandas >= 2.0 returns int32 from the .dt fields.
            return marks.to_numpy(dtype='int64')
        if timeenc == 1:
            stamp = time_features(pd.to_datetime(dates.values), freq=freq)
            return stamp.transpose(1, 0)
        raise ValueError(f"unsupported timeenc {timeenc!r}")

    def __getitem__(self, index):
        """Return the ``index``-th window as (seq_x, seq_y, seq_x_mark, seq_y_mark).

        seq_x covers rows [index, index + seq_len). seq_y starts label_len rows
        before the end of seq_x and extends pred_len rows past it.
        """
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        """Number of complete windows; 0 when the split is shorter than one window."""
        if self.set_type == 3:
            # data_x was already trimmed by pred_len in __read_data__.
            n = len(self.data_x) - self.seq_len + 1
        else:
            n = len(self.data_x) - self.seq_len - self.pred_len + 1
        # len() must not be negative; a too-short split simply has no windows.
        return max(0, n)

    def inverse_transform(self, data):
        """Undo the z-score scaling using the statistics fitted in __read_data__."""
        return self.scaler.inverse_transform(data)
