import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from utils.timefeatures import time_features
import warnings
from utils.constants import Constants

# NOTE(review): with no category/module filter this suppresses *every* warning for the
# whole interpreter as soon as this module is imported (including pandas
# FutureWarning / SettingWithCopyWarning), which can hide real problems elsewhere.
warnings.filterwarnings('ignore')


class Dataset_mekong_LamaH_for_gnn(Dataset):
    """Sliding-window dataset over several LamaH station files, one column per station.

    Each station listed in ``Constants(args).all_stations`` is read from
    ``<root_path>/ID_<station>.csv`` (``;``-separated, with ``YYYY``/``MM``/``DD``
    date columns and a ``target`` value column). The per-station series are trimmed
    to a common length by keeping their most recent rows, stacked column-wise in
    station order, transformed with ``scaler_dict[target]``, and split
    chronologically into train / val / test.
    """

    def __init__(self, args, root_path, flag='train', size=None,
                 features='M', data_path='ID_1.csv',
                 target='Value', scaler_dict=None, timeenc=0, freq='d'):
        """Store the configuration and load the requested split.

        Args:
            args: config object; ``args.num_vali`` / ``args.num_test`` are the number
                of rows held out for the val / test splits, and ``args`` is passed to
                ``Constants`` to obtain ``all_stations``.
            root_path: directory containing the ``ID_<station>.csv`` files.
            flag: split to expose: ``'train'``, ``'val'`` or ``'test'``.
            size: ``[seq_len, label_len, pred_len]``; defaults to ``[384, 96, 96]``.
            features: stored on the instance; not used by this class's own methods.
            data_path: stored on the instance; file names are derived from station IDs.
            target: value column read from every station file, and the key of the
                scaler in ``scaler_dict``.
            scaler_dict: mapping whose ``target`` entry provides ``transform`` and
                ``inverse_transform`` (presumably already fitted). Required in
                practice despite the ``None`` default.
            timeenc: 0 for integer calendar features, 1 for ``time_features``.
            freq: frequency string forwarded to ``time_features`` when ``timeenc == 1``.
        """
        self.args = args
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len, self.label_len, self.pred_len = size[0], size[1], size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scaler_dict = scaler_dict
        self.timeenc = timeenc
        self.freq = freq
        self.num_vali = args.num_vali
        self.num_test = args.num_test

        self.root_path = root_path
        self.data_path = data_path
        self.all_stations = Constants(args).all_stations
        self.__read_data__()

    def _load_station(self, station):
        """Return a frame with columns ``['Timestamp', station]`` for one station file."""
        path = os.path.join(self.root_path, 'ID_' + str(station) + '.csv')
        df_raw = pd.read_csv(path, sep=';')
        # pd.to_datetime only assembles dates from columns named year/month/day,
        # so the YYYY/MM/DD headers are renamed first.
        date_parts = df_raw[['YYYY', 'MM', 'DD']].rename(
            columns={'YYYY': 'year', 'MM': 'month', 'DD': 'day'})
        # Build the frame by column name instead of relabelling positionally, so the
        # timestamp and value columns cannot end up swapped.
        return pd.DataFrame({
            'Timestamp': pd.to_datetime(date_parts),
            station: df_raw[self.target],
        })

    def __read_data__(self):
        """Load all stations and set ``data_x``, ``data_y`` and ``data_stamp`` for this split.

        ``data_x`` / ``data_y`` have one row per day and one column per station (in
        ``all_stations`` order), transformed by ``scaler_dict[target]``.
        ``data_stamp`` holds the time features of the same rows.

        Raises:
            ValueError: if ``all_stations`` is empty.
            NotImplementedError: if ``timeenc`` is neither 0 nor 1.
        """
        stations = list(self.all_stations)
        if not stations:
            raise ValueError('all_stations is empty; there are no station files to load')
        frames = [self._load_station(station) for station in stations]

        # The shortest file fixes the common length; the train split gets whatever
        # is left after the val and test hold-outs.
        # NOTE(review): rows are aligned by position from the end of each file, which
        # assumes every station file ends on the same date -- not checked here.
        len_df_raw = min(len(frame) for frame in frames)
        min_num_train = len_df_raw - self.num_vali - self.num_test
        # Explicit start index: a zero length yields no rows (iloc[-0:] would keep all).
        frames = [frame.iloc[len(frame) - len_df_raw:].reset_index(drop=True)
                  for frame in frames]

        scaler = self.scaler_dict[self.target]
        # The shared scaler is applied to one (n, 1) column at a time (presumably it
        # was fitted on a single feature); the results are stacked side by side.
        data = np.hstack([scaler.transform(frame[[station]].values)
                          for station, frame in zip(stations, frames)])

        # Val/test windows start seq_len rows early so their first sample has a full
        # input history taken from the preceding split.
        border1s = [0, min_num_train - self.seq_len, len_df_raw - self.num_test - self.seq_len]
        border2s = [min_num_train, min_num_train + self.num_vali, len_df_raw]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        stamps = frames[0]['Timestamp'].iloc[border1:border2]
        if self.timeenc == 0:
            # Integer [month, day, weekday, hour] per row; hour is always 0 because
            # the timestamps are assembled from dates only.
            data_stamp = np.column_stack([
                stamps.dt.month, stamps.dt.day, stamps.dt.weekday, stamps.dt.hour,
            ]).astype(np.int64)
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(stamps.values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)
        else:
            raise NotImplementedError

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for the window at ``index``.

        ``seq_x`` covers rows ``[index, index + seq_len)``; ``seq_y`` covers the last
        ``label_len`` of those rows followed by the next ``pred_len`` rows.
        """
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]

        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]
        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        """Number of complete (input, prediction) windows in this split."""
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo the target scaler's transform on ``data``."""
        return self.scaler_dict[self.target].inverse_transform(data)