import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from utils.timefeatures import time_features
import warnings
from utils.constants import Constants

# Silence every Python warning for the whole process. Note that this also hides
# pandas deprecation and chained-assignment warnings raised by code in this module.
warnings.filterwarnings('ignore')


class Dataset_mekong_target_station(Dataset):
    """Sliding-window dataset over the time series of a single station CSV file.

    The CSV at ``root_path/data_path`` must contain a ``Timestamp`` column and
    the ``target`` column. Every retained data column is transformed with the
    matching entry of ``scaler_dict`` (scalers are only applied here, never
    fitted). Rows are split into train / val / test ranges from
    ``args.num_vali`` and ``args.num_test``; the val and test ranges start
    ``seq_len`` rows before their nominal beginning.
    """

    def __init__(self, args, root_path, flag='train', size=None,
                 features='M', data_path='Stung Treng.csv',
                 target='Value', scaler_dict=None, timeenc=0, freq='d'):
        """Store the configuration and load the data immediately.

        :param args: object providing ``cycle``, ``num_vali`` and ``num_test``;
            it is also handed to ``Constants`` to obtain ``all_stations``.
        :param root_path: directory containing the CSV file.
        :param flag: one of ``'train'``, ``'val'``, ``'test'``.
        :param size: ``[seq_len, label_len, pred_len]``; when ``None`` the
            defaults ``384 / 96 / 96`` are used.
        :param features: ``'S'`` keeps only the target column; any other value
            keeps all columns (minus those dropped by the NaN check).
        :param data_path: CSV file name relative to ``root_path``.
        :param target: name of the target column.
        :param scaler_dict: mapping ``column name -> scaler`` where each scaler
            exposes ``transform`` and ``inverse_transform``.
        :param timeenc: ``0`` for month/day/weekday/hour marks, ``1`` to use
            ``time_features``.
        :param freq: frequency string forwarded to ``time_features``.
        """
        # size [seq_len, label_len, pred_len]
        self.args = args
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scaler_dict = scaler_dict
        self.timeenc = timeenc
        self.freq = freq
        self.cycle = args.cycle
        self.num_vali = args.num_vali
        self.num_test = args.num_test

        self.root_path = root_path
        self.data_path = data_path
        self.all_stations = Constants(args).all_stations
        self.__read_data__()

    def _build_time_marks(self, timestamps):
        """Turn a Series of timestamps into the per-row time-mark array.

        :param timestamps: pandas Series of values parseable by ``pd.to_datetime``.
        :return: for ``timeenc == 0`` an ``(n, 4)`` array with the columns
            month, day, weekday, hour; for ``timeenc == 1`` the transposed
            output of ``time_features``.
        :raises NotImplementedError: for any other ``timeenc`` value.
        """
        stamps = pd.to_datetime(timestamps)
        if self.timeenc == 0:
            parts = (stamps.dt.month, stamps.dt.day, stamps.dt.weekday, stamps.dt.hour)
            marks = np.stack([part.to_numpy() for part in parts], axis=1)
            # The ``.dt`` accessors return int32 on recent pandas; keep the
            # int64 dtype the previous ``apply``-based code produced. Float
            # results (NaT present) are left untouched.
            if np.issubdtype(marks.dtype, np.integer):
                marks = marks.astype(np.int64)
            return marks
        if self.timeenc == 1:
            marks = time_features(pd.to_datetime(stamps.values), freq=self.freq)
            return marks.transpose(1, 0)
        raise NotImplementedError

    def __read_data__(self):
        """Read the CSV, scale it and keep the slice belonging to this split.

        NOTE(review): the original note listed the CSV columns as
        ['Timestamp', 'Discharge.Daily', 'Water.Level', 'Rainfall.Manual'];
        the code itself only requires 'Timestamp' and ``self.target``.
        """
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('Timestamp')
        if self.features == 'S':
            df_raw = df_raw[['Timestamp', self.target]]
        else:
            df_raw = df_raw[['Timestamp'] + cols + [self.target]]
            # check NaN value and remove the col if it has
            # (only the first three data columns are inspected)
            cols_to_check = df_raw.columns[1:4]
            nan_cols = [col for col in cols_to_check if df_raw[col].isna().any()]
            df_raw = df_raw.drop(columns=nan_cols)
            assert len(df_raw.columns) >= 2, "data all columns have NaN value"

        total_rows = len(df_raw)
        num_train = total_rows - self.num_test - self.num_vali
        border1s = [0, num_train - self.seq_len, total_rows - self.num_test - self.seq_len]
        border2s = [num_train, num_train + self.num_vali, total_rows]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        # Each column is transformed by its own scaler, looked up by column name.
        data = np.hstack([
            self.scaler_dict[col].transform(df_raw[[col]].values)
            for col in df_raw.columns[1:]
        ])

        self.data_stamp = self._build_time_marks(df_raw['Timestamp'].iloc[border1:border2])
        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        # add cycle: position of every row inside a period of ``self.cycle``
        # rows, counted from the first row of the whole file.
        self.cycle_index = (np.arange(len(data)) % self.cycle)[border1:border2]

    def __getitem__(self, index):
        """Return the window starting at ``index``.

        :param index: first row of the input window inside this split.
        :return: ``(seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index)`` where
            ``seq_x`` covers ``seq_len`` rows, ``seq_y`` covers the last
            ``label_len`` rows of the input plus ``pred_len`` following rows,
            and ``cycle_index`` is the cycle position of the row right after
            the input window.
        """
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]
        cycle_index = self.cycle_index[s_end]
        return seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index

    def __len__(self):
        """Number of complete (input, prediction) windows in this split."""
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo the scaling of ``data`` using the target column's scaler."""
        return self.scaler_dict[self.target].inverse_transform(data)


class Dataset_mekong_for_gnn(Dataset):
    """Sliding-window dataset stacking the target column of every station.

    For each name in ``Constants(args).all_stations`` the file
    ``root_path/<station>.csv`` is read and its ``target`` column kept. All
    stations are cut to the same number of trailing rows (that of the shortest
    file), scaled with ``scaler_dict[target]`` and placed side by side, one
    column per station. Time marks are taken from the first station's
    timestamps.
    """

    def __init__(self, args, root_path, flag='train', size=None,
                 features='M', data_path='Stung Treng.csv',
                 target='Value', scaler_dict=None, timeenc=0, freq='d'):
        """Store the configuration and load the data immediately.

        :param args: object providing ``num_vali`` and ``num_test``; it is also
            handed to ``Constants`` to obtain ``all_stations``.
        :param root_path: directory containing one ``<station>.csv`` per station.
        :param flag: one of ``'train'``, ``'val'``, ``'test'``.
        :param size: ``[seq_len, label_len, pred_len]``; when ``None`` the
            defaults ``384 / 96 / 96`` are used.
        :param features: stored on the instance; not used when reading data.
        :param data_path: stored on the instance; not used when reading data.
        :param target: name of the column read from every station file.
        :param scaler_dict: mapping whose ``target`` entry exposes
            ``transform`` and ``inverse_transform``.
        :param timeenc: ``0`` for month/day/weekday/hour marks, ``1`` to use
            ``time_features``.
        :param freq: frequency string forwarded to ``time_features``.
        """
        # size [seq_len, label_len, pred_len]
        self.args = args
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scaler_dict = scaler_dict
        self.timeenc = timeenc
        self.freq = freq
        self.num_vali = args.num_vali
        self.num_test = args.num_test

        self.root_path = root_path
        self.data_path = data_path
        self.all_stations = Constants(args).all_stations
        self.__read_data__()

    def _build_time_marks(self, timestamps):
        """Turn a Series of timestamps into the per-row time-mark array.

        :param timestamps: pandas Series of values parseable by ``pd.to_datetime``.
        :return: for ``timeenc == 0`` an ``(n, 4)`` array with the columns
            month, day, weekday, hour; for ``timeenc == 1`` the transposed
            output of ``time_features``.
        :raises NotImplementedError: for any other ``timeenc`` value.
        """
        stamps = pd.to_datetime(timestamps)
        if self.timeenc == 0:
            parts = (stamps.dt.month, stamps.dt.day, stamps.dt.weekday, stamps.dt.hour)
            marks = np.stack([part.to_numpy() for part in parts], axis=1)
            # The ``.dt`` accessors return int32 on recent pandas; keep the
            # int64 dtype the previous ``apply``-based code produced. Float
            # results (NaT present) are left untouched.
            if np.issubdtype(marks.dtype, np.integer):
                marks = marks.astype(np.int64)
            return marks
        if self.timeenc == 1:
            marks = time_features(pd.to_datetime(stamps.values), freq=self.freq)
            return marks.transpose(1, 0)
        raise NotImplementedError

    def __read_data__(self):
        """Read every station file, align, scale and slice to this split.

        NOTE(review): the original note listed the CSV columns as
        ['Timestamp', 'Discharge.Daily', 'Water.Level', 'Rainfall.Manual'];
        the code itself only requires 'Timestamp' and ``self.target``.

        :raises ValueError: if ``self.all_stations`` yields no station.
        """
        station_frames = []
        min_num_train = None
        for station in self.all_stations:
            df_raw = pd.read_csv(os.path.join(self.root_path, str(station) + '.csv'))
            df_raw = df_raw[['Timestamp', self.target]]
            df_raw.columns = ['Timestamp', station]
            station_frames.append(df_raw)
            num_train = len(df_raw) - self.num_test - self.num_vali
            if min_num_train is None:
                min_num_train = num_train
            else:
                min_num_train = min(min_num_train, num_train)

        if not station_frames:
            raise ValueError("all_stations is empty: no station data to load")

        # Keep the same number of trailing rows for every station so that the
        # columns can be stacked row by row.
        len_df_raw = min_num_train + self.num_vali + self.num_test
        aligned = [
            frame.iloc[-len_df_raw:].reset_index(drop=True)
            for frame in station_frames
        ]

        border1s = [0, min_num_train - self.seq_len, len_df_raw - self.num_test - self.seq_len]
        border2s = [min_num_train, min_num_train + self.num_vali, len_df_raw]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        # The target scaler is applied to each station column separately;
        # column 1 of every aligned frame is that station's target series.
        scaler = self.scaler_dict[self.target]
        data = np.hstack([
            scaler.transform(frame.iloc[:, [1]].values) for frame in aligned
        ])

        # Time marks come from the first station's timestamps.
        self.data_stamp = self._build_time_marks(
            aligned[0]['Timestamp'].iloc[border1:border2])
        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]

    def __getitem__(self, index):
        """Return the window starting at ``index``.

        :param index: first row of the input window inside this split.
        :return: ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` where ``seq_x``
            covers ``seq_len`` rows and ``seq_y`` covers the last
            ``label_len`` rows of the input plus ``pred_len`` following rows.
        """
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]

        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        """Number of complete (input, prediction) windows in this split."""
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo the scaling of ``data`` using the target column's scaler."""
        return self.scaler_dict[self.target].inverse_transform(data)


class Dataset_mekong_stations_list(Dataset):
    """Sliding-window dataset returning one array per station.

    For each name in ``stations_list`` the file ``root_path/<station>.csv`` is
    read, reduced to the selected columns, cut to the same number of trailing
    rows as the shortest station and scaled column by column with
    ``scaler_dict``. Unlike a stacked layout, each station keeps its own
    array. Time marks are taken from the first station's timestamps.
    """

    def __init__(self, args, root_path, flag='train', size=None,
                 features='M', data_path='Stung Treng.csv',
                 target='Value', scaler_dict=None, timeenc=0, freq='d', stations_list=None):
        """Store the configuration and load the data immediately.

        :param args: object providing ``cycle``, ``num_vali`` and ``num_test``.
        :param root_path: directory containing one ``<station>.csv`` per station.
        :param flag: one of ``'train'``, ``'val'``, ``'test'``.
        :param size: ``[seq_len, label_len, pred_len]``; when ``None`` the
            defaults ``384 / 96 / 96`` are used.
        :param features: ``'S'`` keeps only the target column; any other value
            keeps all columns (minus those dropped by the NaN check).
        :param data_path: stored on the instance; not used when reading data.
        :param target: name of the target column.
        :param scaler_dict: mapping ``column name -> scaler`` where each scaler
            exposes ``transform`` and ``inverse_transform``.
        :param timeenc: ``0`` for month/day/weekday/hour marks, ``1`` to use
            ``time_features``.
        :param freq: frequency string forwarded to ``time_features``.
        :param stations_list: station names to load; must not be empty.
        :raises ValueError: if ``stations_list`` is ``None`` or empty.
        """
        # size [seq_len, label_len, pred_len]
        self.args = args
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scaler_dict = scaler_dict
        self.timeenc = timeenc
        self.freq = freq
        self.cycle = args.cycle
        self.num_vali = args.num_vali
        self.num_test = args.num_test

        self.root_path = root_path
        self.data_path = data_path
        self.stations_list = stations_list
        self.__read_data__()

    def _build_time_marks(self, timestamps):
        """Turn a Series of timestamps into the per-row time-mark array.

        :param timestamps: pandas Series of values parseable by ``pd.to_datetime``.
        :return: for ``timeenc == 0`` an ``(n, 4)`` array with the columns
            month, day, weekday, hour; for ``timeenc == 1`` the transposed
            output of ``time_features``.
        :raises NotImplementedError: for any other ``timeenc`` value.
        """
        stamps = pd.to_datetime(timestamps)
        if self.timeenc == 0:
            parts = (stamps.dt.month, stamps.dt.day, stamps.dt.weekday, stamps.dt.hour)
            marks = np.stack([part.to_numpy() for part in parts], axis=1)
            # The ``.dt`` accessors return int32 on recent pandas; keep the
            # int64 dtype the previous ``apply``-based code produced. Float
            # results (NaT present) are left untouched.
            if np.issubdtype(marks.dtype, np.integer):
                marks = marks.astype(np.int64)
            return marks
        if self.timeenc == 1:
            marks = time_features(pd.to_datetime(stamps.values), freq=self.freq)
            return marks.transpose(1, 0)
        raise NotImplementedError

    def _load_station_frame(self, station):
        """Read one station CSV and keep 'Timestamp' plus the selected columns."""
        df_raw = pd.read_csv(os.path.join(self.root_path, str(station) + '.csv'))
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('Timestamp')
        if self.features == 'S':
            return df_raw[['Timestamp', self.target]]
        df_raw = df_raw[['Timestamp'] + cols + [self.target]]
        # check NaN value and remove the col if it has
        # (only the first three data columns are inspected)
        cols_to_check = df_raw.columns[1:4]
        nan_cols = [col for col in cols_to_check if df_raw[col].isna().any()]
        df_raw = df_raw.drop(columns=nan_cols)
        assert len(df_raw.columns) >= 2, "data all columns have NaN value"
        return df_raw

    def _scale_frame(self, frame):
        """Scale every data column of ``frame`` with the scaler named after it."""
        return np.hstack([
            self.scaler_dict[col].transform(frame[[col]].values)
            for col in frame.columns[1:]
        ])

    def __read_data__(self):
        """Read every station file, align, scale and slice to this split.

        NOTE(review): the original note listed the CSV columns as
        ['Timestamp', 'Discharge.Daily', 'Water.Level', 'Rainfall.Manual'];
        the code itself only requires 'Timestamp' and ``self.target``.

        :raises ValueError: if ``self.stations_list`` is ``None`` or empty.
        """
        if self.stations_list is None:
            raise ValueError("stations_list must be a non-empty sequence of station names")

        station_frames = []
        min_num_train = None
        for station in self.stations_list:
            df_raw = self._load_station_frame(station)
            station_frames.append(df_raw)
            num_train = len(df_raw) - self.num_test - self.num_vali
            if min_num_train is None:
                min_num_train = num_train
            else:
                min_num_train = min(min_num_train, num_train)

        if not station_frames:
            raise ValueError("stations_list must be a non-empty sequence of station names")

        # Keep the same number of trailing rows for every station so that one
        # index addresses the same position in all of them.
        len_df_raw = min_num_train + self.num_vali + self.num_test
        aligned = [frame.iloc[-len_df_raw:] for frame in station_frames]
        scaled = [self._scale_frame(frame) for frame in aligned]

        border1s = [0, min_num_train - self.seq_len, len_df_raw - self.num_test - self.seq_len]
        border2s = [min_num_train, min_num_train + self.num_vali, len_df_raw]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        # Time marks come from the first station's timestamps.
        self.data_stamp = self._build_time_marks(
            aligned[0]['Timestamp'].iloc[border1:border2])
        self.data_x_list = [data[border1:border2] for data in scaled]
        self.data_y_list = [data[border1:border2] for data in scaled]

    def __getitem__(self, index):
        """
        :param index: first row of the input window inside this split.
        :return: (list): [input_res, true_res, seq_x_mark, seq_y_mark]
            input_res (list of numpy array): [x0, x1, ..., xn], one per station
            true_res (list of numpy array): [y0, y1, ..., yn], one per station
            seq_x_mark / seq_y_mark: time marks of the input / label+prediction rows
        """
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        input_res = [station_x[s_begin:s_end] for station_x in self.data_x_list]
        true_res = [station_y[r_begin:r_end] for station_y in self.data_y_list]

        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return [input_res, true_res, seq_x_mark, seq_y_mark]

    def __len__(self):
        """Number of complete (input, prediction) windows in this split."""
        return len(self.data_x_list[0]) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo the scaling of ``data`` using the target column's scaler."""
        return self.scaler_dict[self.target].inverse_transform(data)
