import os
import torch
import numpy as np
import pandas as pd
from datetime import datetime

# Folder that holds the raw dataset files. The placeholder value 'none' has to
# be edited by hand; while it is still in place the module refuses to import.
data_root_path = 'none'
if data_root_path == 'none':
    # Fail fast at import time instead of producing broken file paths below.
    raise ValueError('Please replace `root_path` into the absolute path of `raw_data` folder!')

# Dataset name -> raw data file located under `data_root_path`.
file_path_dict = {
                 'PEMS04':f'{data_root_path}/PEMS04.npz',
                 'SeaLoop':f'{data_root_path}/SeaLoop.csv',
                 'nrel_al':f'{data_root_path}/nrel_al.csv',
                 'EPeMS':f'{data_root_path}/gla_pemsd7.csv'
                }
# Dataset name -> [start timestamp, sampling interval]. The start timestamp is
# used to synthesise a time axis when the raw file carries none; the interval
# string is the bucket width used when deriving the time-of-day feature.
meta_info_dict= {
        'PEMS04': ['2018-01-01 00:00:00', '5m'], # 5min
        'SeaLoop':['2015-01-01 00:00:00','5m'],
        'nrel_al':['2016-01-01 00:00:00','5m'],
        'EPeMS':['2017-01-01 00:00:00','5m']
    }


class StandardScaler():
    """Z-score scaler that works on NumPy arrays and torch tensors.

    With ``channel_wise`` falsy a single ``mean``/``std`` pair is applied to
    every element. With ``channel_wise`` truthy, ``mean``/``std`` are indexed
    along their first axis by node index and broadcast against the input:

    * 2-D input: nodes on axis 1
    * 3-D input: nodes on axis 1
    * 4-D input: nodes on axis 2

    ``stage1_node_idx`` / ``stage2_node_idx`` select the node subset used when
    ``stage`` is ``'stage1'`` / ``'stage2'``; any other ``stage`` value uses
    every node of ``mean``.
    """

    def __init__(self, mean=None, std=None, channel_wise=False, stage1_node_idx = None, stage2_node_idx = None):
        self.mean = mean
        self.std = std
        self.channel_wise = channel_wise
        self.stage1_node_idx = stage1_node_idx
        self.stage2_node_idx = stage2_node_idx
        print('scale channel_wise :', channel_wise)

    def fit(self, data): # data: [T, N]
        """Estimate ``mean``/``std`` from ``data`` and return ``self``.

        Channel-wise mode reduces over axis 0 only; otherwise the statistics
        are taken over all elements.
        """
        if self.channel_wise:
            self.mean = data.mean(axis=0)
            self.std = data.std(axis=0)
        else:
            self.mean = data.mean()
            self.std = data.std()
        return self

    @staticmethod
    def _require_array(data):
        """Reject inputs that are neither ``np.ndarray`` nor ``torch.Tensor``."""
        if not isinstance(data, (np.ndarray, torch.Tensor)):
            raise TypeError(
                f"StandardScaler expects np.ndarray or torch.Tensor, got {type(data).__name__}"
            )

    def _select_stats(self, data, stage):
        """Return the ``(mean, std)`` pair shaped to broadcast against ``data``."""
        if not self.channel_wise:
            return self.mean, self.std

        if stage == 'stage1':
            node_idx = self.stage1_node_idx
        elif stage == 'stage2':
            node_idx = self.stage2_node_idx
        else:
            node_idx = np.arange(self.mean.shape[0])

        ndim = len(data.shape)
        if ndim == 2:
            index = (np.newaxis, node_idx)
        elif ndim == 3:
            index = (np.newaxis, node_idx, np.newaxis)
        elif ndim == 4:
            index = (np.newaxis, np.newaxis, node_idx, np.newaxis)
        else:
            raise ValueError("input dimension error!!!")

        # Explicit check instead of `assert`, which disappears under `python -O`
        # and would let `None` silently act as an extra np.newaxis.
        if node_idx is None:
            raise ValueError(f"node indices for stage {stage!r} were not provided to the scaler")
        return self.mean[index], self.std[index]

    def transform(self, data, stage=None):
        """Return ``(data - mean) / std``.

        Raises:
            TypeError: ``data`` is neither an ndarray nor a tensor.
            ValueError: channel-wise mode with an input that is not 2-, 3- or
                4-dimensional, or with missing node indices for ``stage``.
        """
        self._require_array(data)
        mean, std = self._select_stats(data, stage)
        if isinstance(data, np.ndarray):
            return (data - mean) / std
        return (data - torch.tensor(mean).to(data.device)) / torch.tensor(std).to(data.device)

    def inverse_transform(self, data, stage=None):
        """Return ``data * std + mean`` (inverse of :meth:`transform`).

        Raises the same exceptions as :meth:`transform`.
        """
        self._require_array(data)
        mean, std = self._select_stats(data, stage)
        if isinstance(data, np.ndarray):
            return (data * std) + mean
        return (data * torch.tensor(std).to(data.device)) + torch.tensor(mean).to(data.device)




def load_grid(filename):
    """Read a grid CSV and return its flows as a ``(time, cell, 2)`` array.

    The file must provide the columns ``time``, ``row_id``, ``column_id``,
    ``inflow`` and ``outflow``. The number of cells is derived from the
    largest row/column ids, the rows are cut into consecutive chunks of
    ``rows / cells`` entries (one chunk per cell), and the chunks are stacked
    and transposed so that time becomes the leading axis.
    """
    table = pd.read_csv(filename)
    flows = table[['inflow', 'outflow']]
    n_cells = (table['row_id'].max() + 1) * (table['column_id'].max() + 1)

    # Number of time slots = length of the first cell's run of 'time' values.
    steps = len(table['time'][:int(table.shape[0] / n_cells)])
    per_cell = [
        flows.iloc[lo:lo + steps].values
        for lo in range(0, flows.shape[0], steps)
    ]
    return np.array(per_cell, dtype=float).swapaxes(0, 1)

def load_dyna(filename):
    """Read a ``.dyna`` CSV and return traffic flow as ``(time, entity, 1)``.

    The file must provide the columns ``time``, ``entity_id`` and
    ``traffic_flow``. The entity count is ``max(entity_id) + 1``; the rows are
    cut into consecutive chunks of ``rows / entities`` entries (one chunk per
    entity) and stacked, then transposed so that time is the leading axis.

    NOTE(review): this presumably relies on the rows being grouped by entity
    and ordered by time inside each group - confirm against the data files.
    """
    dynafile = pd.read_csv(filename)
    df = dynafile[['traffic_flow']]
    num_nodes = dynafile['entity_id'].max() + 1

    # Number of time slots = length of the first entity's run of 'time' values.
    len_time = len(dynafile['time'][:int(dynafile.shape[0] / num_nodes)])
    data = [
        df[i:i + len_time].values
        for i in range(0, df.shape[0], len_time)
    ]
    # `np.float` was removed in NumPy 1.24; the builtin `float` is the same
    # dtype (float64) and matches what load_grid uses.
    data = np.array(data, dtype=float)
    return data.swapaxes(0, 1)

def generate_datetime_series(start_datetime_str, interval_str, n, col_name='date'):
    """Build a one-column DataFrame of ``n`` evenly spaced timestamps.

    ``start_datetime_str`` is parsed with the format ``%Y-%m-%d %H:%M:%S`` and
    ``interval_str`` with ``pd.to_timedelta``; row ``k`` holds
    ``start + k * interval``. The column is named ``col_name``.
    """
    origin = datetime.strptime(start_datetime_str, '%Y-%m-%d %H:%M:%S')
    step = pd.to_timedelta(interval_str)
    stamps = []
    for k in range(n):
        stamps.append(origin + k * step)
    return pd.DataFrame({col_name: stamps})

def generate_time_features(df_time, freq):
    """Return an ``(n, 2)`` int64 array of ``[time_of_day, day_of_week]``.

    Args:
        df_time: DataFrame with a datetime-typed ``date`` column.
        freq: Interval string understood by ``pd.to_timedelta``; it is the
            width of one time-of-day bucket.

    ``time_of_day`` is the time elapsed since midnight divided by ``freq``
    (the fractional part is dropped by the final integer cast) and
    ``day_of_week`` uses pandas' convention (Monday = 0).
    """
    dates = df_time['date']

    # tod (time of day): offset from the start of the day, in units of `freq`
    stamps = dates.values
    since_midnight = stamps - stamps.astype('datetime64[D]')
    t_of_d = (since_midnight / pd.to_timedelta(freq)).reshape(-1, 1)

    # dow (day of week): vectorised accessor instead of a per-row lambda whose
    # stray positional `1` was bound to a Series.apply option, not to an axis.
    d_of_w = dates.dt.dayofweek.values.reshape(-1, 1)

    return np.concatenate([t_of_d, d_of_w], axis=-1).astype(np.int64)

def load_raw_data(raw_data_path):
    """Load a raw dataset file, dispatching on the path suffix.

    Supported suffixes:

    * ``npz``  - returns the ``data`` entry of the archive
    * ``grid`` - delegated to ``load_grid``
    * ``dyna`` - delegated to ``load_dyna``
    * ``csv``  - first column is taken as the timestamp, every other column
      as one series; missing values are replaced by ``0.0`` and a trailing
      axis of size 1 is appended
    * ``h5``   - read with ``pd.read_hdf``; the index is taken as the
      timestamp and a trailing axis of size 1 is appended

    Returns:
        ``(data, df_time)`` where ``df_time`` is a DataFrame with a single
        ``date`` column for ``csv``/``h5`` inputs and ``None`` otherwise.

    Raises:
        ValueError: the suffix is not one of the supported ones.
    """
    df_time = None
    if raw_data_path.endswith('npz'):
        # Context manager closes the archive's file handle once the array is read.
        with np.load(raw_data_path) as archive:
            data = archive['data']
    elif raw_data_path.endswith('grid'):
        data = load_grid(raw_data_path)
    elif raw_data_path.endswith('dyna'):
        data = load_dyna(raw_data_path)
    elif raw_data_path.endswith('csv'):
        df_raw = pd.read_csv(raw_data_path).fillna(0.0)
        # The first column is treated as the timestamp whatever its header is.
        val_cols = df_raw.columns[1:]
        df_raw = df_raw.rename(columns={df_raw.columns[0]: 'date'})
        data = np.expand_dims(df_raw[val_cols].values, axis=-1)
        print('data.shape:', data.shape)
        df_raw['date'] = pd.to_datetime(df_raw['date'])
        df_time = df_raw[['date']]
    elif raw_data_path.endswith('h5'):
        ori = pd.read_hdf(raw_data_path)
        data = np.expand_dims(ori.values, axis=-1)  # (length, N, 1)
        df_time = pd.DataFrame({'date': ori.index})
    else:
        # Previously fell through to an UnboundLocalError on `data`.
        raise ValueError(f'unsupported raw data file type: {raw_data_path}')
    return data, df_time



def generate_sliding_window_samples(data, time_data, index_ranges, input_length, predict_length, stride=1):
    """Cut sliding windows out of several index ranges and split them into x/y.

    For every ``(start, end)`` pair in ``index_ranges`` the slices
    ``data[start:end]`` and ``time_data[start:end]`` are windowed with a
    window of ``input_length + predict_length`` steps advanced by ``stride``.
    Windows of all ranges are concatenated along the sample axis.

    Returns:
        ``(x_data, x_time, y_data, y_time)``: the first ``input_length`` steps
        of each window and the remaining steps. When ``predict_length`` is 0
        the y outputs are the x outputs themselves.

    Raises:
        ValueError: a range is shorter than one window.
    """
    span = input_length + predict_length
    value_chunks = []
    time_chunks = []

    for lo, hi in index_ranges:
        values = data[lo:hi]  # shape: (L_i, ...)
        stamps = time_data[lo:hi]

        n_steps = values.shape[0]
        if n_steps < span:
            raise ValueError(f"data length {n_steps} is less than window size {span}")

        # Same window origins for values and time features.
        starts = range(0, n_steps - span + 1, stride)
        value_windows = np.stack([values[s:s + span] for s in starts], axis=0)  # (n_i, span, ...)
        time_windows = np.stack([stamps[s:s + span] for s in starts], axis=0)

        print(value_windows.shape)
        print(time_windows.shape)
        value_chunks.append(value_windows)
        time_chunks.append(time_windows)

    all_values = np.concatenate(value_chunks, axis=0)  # (n_samples, span, ...)
    all_times = np.concatenate(time_chunks, axis=0)  # (n_samples, span, ...)

    x_data = all_values[:, 0:input_length]
    x_time = all_times[:, 0:input_length]
    if predict_length > 0:
        y_data = all_values[:, input_length:]
        y_time = all_times[:, input_length:]
    else:
        y_data = x_data
        y_time = x_time

    return x_data, x_time, y_data, y_time

def load_full_data_time(dataset_name, in_dim = 1):
    """Load the full series of a dataset together with its time features.

    The raw file registered for ``dataset_name`` is read, only the first
    ``in_dim`` entries of the last axis are kept, and time features are
    derived from the timestamps (synthesised from the dataset's meta info
    when the raw file carries none). The results are written to
    ``datasets/<dataset_name>/`` as cache files.

    Returns:
        ``(full_data, full_time, df_time, full_valid_mask)``;
        ``full_valid_mask`` is ``None`` when the data comes from the raw file.
    """
    data_cache_path = f'datasets/{dataset_name}/full_data.npz'
    time_cache_path = f'datasets/{dataset_name}/full_time.npz'
    timestamp_cache_path = f'datasets/{dataset_name}/full_timestamp.csv'
    if not os.path.exists(f'datasets/{dataset_name}/'):
        os.makedirs(f'datasets/{dataset_name}')

    # Reading the cache is switched off here, so the branch below only runs
    # if this flag is flipped by hand.
    load_cache = False
    cache_files = (data_cache_path, time_cache_path, timestamp_cache_path)
    if all(os.path.exists(path) for path in cache_files) and load_cache:
        full_data = np.load(data_cache_path)['data']
        full_valid_mask = np.load(data_cache_path)['valid_mask']
        full_time = np.load(time_cache_path)['time_data']
        df_time = pd.read_csv(timestamp_cache_path)
        df_time['date'] = pd.to_datetime(df_time['date'])
        print('full data loaded:', full_data.shape)
        print('full time_data loaded:', full_time.shape)
        print('full time_stamp loaded:', df_time.shape)
        return full_data, full_time, df_time, full_valid_mask

    print(f'loading raw data of {dataset_name}...')
    raw_values, df_time = load_raw_data(file_path_dict[dataset_name])
    full_valid_mask = None

    full_data = raw_values[...,:in_dim]
    print(f'Full Data shape:{full_data.shape}')

    start_time_str, freq = meta_info_dict[dataset_name]
    if df_time is None:
        # No timestamps in the raw file: build them from the start time and interval.
        df_time = generate_datetime_series(start_time_str, freq, full_data.shape[0], 'date')

    full_time = generate_time_features(df_time, freq)
    print(f"Full Time shape: {full_time.shape}")

    np.savez_compressed(data_cache_path, data = full_data, valid_mask = full_valid_mask)
    np.savez_compressed(time_cache_path, time_data = full_time)
    df_time.to_csv(timestamp_cache_path, index=False)
    print('full data and time_data saved!')

    return full_data, full_time, df_time, full_valid_mask


def extract_weeks_and_adjust_test_start(df_time, val_start_idx):
    """Find the whole-week span of a series and snap a split index to a week start.

    The ``date`` column is parsed and sorted, the sampling interval is taken
    from the first two timestamps (rounded to whole minutes), and the first
    row falling on a Monday at hour 0, minute 0 becomes the start of the
    usable span. The span covers as many complete weeks as fit after that row.

    Args:
        df_time: DataFrame with a ``date`` column.
        val_start_idx: Row index (in the sorted series) of the tentative split.

    Returns:
        ``(valid_indices, adjusted_test_start)``: the row indices of the
        whole-week span, and the first row of the week following the one that
        contains ``val_start_idx``. ``([], None)`` when no Monday-midnight row
        exists.

    Raises:
        ValueError: fewer than two rows, or a sampling interval that rounds to
            zero minutes or less.
        SystemError: ``val_start_idx`` lies outside the whole-week span.
    """
    df_time = df_time.copy()
    df_time['date'] = pd.to_datetime(df_time['date'])
    df_time = df_time.sort_values('date').reset_index(drop=True)

    if len(df_time) < 2:
        raise ValueError("the length of time series is insufficient!")

    delta = (df_time.loc[1, 'date'] - df_time.loc[0, 'date']).total_seconds() / 60
    freq_minutes = int(round(delta))
    if freq_minutes <= 0:
        # Guard the division below; duplicated or sub-minute timestamps land here.
        raise ValueError(f"sampling interval must be at least one minute, got {delta} minutes")

    slices_per_day = int(24 * 60 / freq_minutes)
    slices_per_week = 7 * slices_per_day

    # find idx of the first monday 00:00 (vectorised instead of a row-by-row scan)
    dates = df_time['date']
    is_week_start = (
        (dates.dt.weekday == 0) & (dates.dt.hour == 0) & (dates.dt.minute == 0)
    ).to_numpy()
    if not is_week_start.any():
        return [], None
    start_idx = int(is_week_start.argmax())

    remaining = len(df_time) - start_idx
    num_weeks = remaining // slices_per_week
    end_idx = start_idx + num_weeks * slices_per_week

    valid_indices = list(range(start_idx, end_idx))

    if val_start_idx < start_idx or val_start_idx >= end_idx:
        # Exception type kept for existing callers; a message is added.
        raise SystemError(
            f"val_start_idx {val_start_idx} is outside the whole-week span [{start_idx}, {end_idx})"
        )

    # Move the split forward to the beginning of the next full week.
    week_offset = (val_start_idx - start_idx) // slices_per_week
    adjusted_test_start = start_idx + (week_offset + 1) * slices_per_week

    return valid_indices, adjusted_test_start


def load_stage_division(dataset_name, data_length, df_time, slices_per_week):
    """Load, or create and cache, the stage split indices of a dataset.

    The split is stored in
    ``datasets/<name>/<name>_stage_division.npz``. When that file is missing
    a tentative split at 60% (``PEMS04``/``PEMS08``) or 70% (others) of
    ``data_length`` is snapped to a week boundary by
    ``extract_weeks_and_adjust_test_start``; stage 2 then starts
    ``slices_per_week`` rows before the test start and the result is saved.
    A text description of the split is written next to the cache file if it
    does not exist yet.

    Returns:
        ``[stage1_start_idx, stage2_start_idx, test_start_idx]``. Values read
        from the cache file are the arrays stored there; freshly computed
        values are plain numbers.

    Raises:
        ValueError: no cache exists for ``EPeMS``, or no week-aligned split
            could be derived from ``df_time``.
    """
    stage_division_cache_path = f'datasets/{dataset_name}/{dataset_name}_stage_division.npz'
    if os.path.exists(stage_division_cache_path):
        # Context manager closes the archive once the three entries are read.
        with np.load(stage_division_cache_path) as cached:
            stage1_start_idx = cached['stage1_start_idx']
            stage2_start_idx = cached['stage2_start_idx']
            test_start_idx = cached['test_start_idx']
    else:
        if dataset_name =='EPeMS':
            raise ValueError('the stage division of EPeMS should be loaded from STEV repo.')

        print('create stage_division ...')
        val_start_idx = int(data_length*0.6) if dataset_name in ['PEMS04','PEMS08'] else int(data_length*0.7)
        print('val_start_idx:',val_start_idx)
        _, test_start_idx = extract_weeks_and_adjust_test_start(df_time, val_start_idx)
        if test_start_idx is None:
            # The helper signals "no Monday 00:00 row" with None; subtracting
            # from it would only surface as an opaque TypeError.
            raise ValueError(f'no week-aligned split could be derived for {dataset_name}')
        stage1_start_idx = 0
        stage2_start_idx = test_start_idx - slices_per_week
        np.savez_compressed(stage_division_cache_path,stage1_start_idx=0, stage2_start_idx=stage2_start_idx, test_start_idx=test_start_idx)

    stage_division = [stage1_start_idx, stage2_start_idx, test_start_idx]

    description_txt_path = f'datasets/{dataset_name}/{dataset_name}_stage_division_description.txt'

    if not os.path.exists(description_txt_path):
        print('write txt:')
        with open(description_txt_path,'w') as f:
            f.write(f'stage1:[{stage1_start_idx} - {stage2_start_idx})\n')
            f.write(f'stage2_valid:[{stage2_start_idx} - {test_start_idx})\n')
            f.write(f'stage2_test:[{test_start_idx} - {data_length})\n')

    return stage_division


def build_stage_node_set(all_num_nodes, new_add_rate=0.2, retire_rate=0.05, seed=1):
    """Randomly split node ids into stage-1, stage-2, new, remaining and retired sets.

    NumPy's global random generator is seeded with ``seed``. Each node is
    marked as newly added when its uniform draw is below ``new_add_rate``;
    the other nodes form stage 1. Each stage-1 node is then marked as retired
    when a second draw is below ``retire_rate``. Stage 2 is the sorted union
    of the non-retired stage-1 nodes and the newly added ones.

    Returns:
        ``(stage1_node_set, stage2_node_set, newadd_set, remain_set,
        retire_set)``; ``stage2_node_set`` is a sorted list, the others are
        index arrays.
    """
    print('all_nodes:',all_num_nodes)
    np.random.seed(seed)

    is_new = np.random.random(all_num_nodes) < new_add_rate
    newadd_set = np.flatnonzero(is_new)
    stage1_node_set = np.flatnonzero(~is_new)

    # Second draw, one value per stage-1 node, decides which of them retire.
    is_retired = np.random.random(len(stage1_node_set)) < retire_rate
    remain_set = stage1_node_set[np.flatnonzero(~is_retired)]
    retire_set = stage1_node_set[np.flatnonzero(is_retired)]

    stage2_node_set = sorted(set(remain_set) | set(newadd_set))

    n_stage1 = len(stage1_node_set)
    print('stage1_node_set:',f'len={n_stage1}')
    print('remain_set:',f'len={len(remain_set)}')
    print('newadd_set:',f'len={len(newadd_set)}')
    print('retire_set:',f'len={len(retire_set)}')
    print('stage:',n_stage1,'->',len(stage2_node_set), f'(+{len(newadd_set)}-{n_stage1-len(remain_set)})')

    return stage1_node_set, stage2_node_set, newadd_set, remain_set, retire_set

def load_node_division(dataset_name, all_num_nodes, retire_rate=0.05, new_add_rate=0.2, seed=1):
    """Load, or create and cache, the node split of a dataset.

    ``EPeMS`` uses a fixed cache file name; every other dataset uses a name
    that encodes ``seed``, ``new_add_rate`` and ``retire_rate``. When the
    cache file is missing the split is produced by ``build_stage_node_set``
    and saved. A text description is written next to it if not present yet.

    Returns:
        ``(stage1_node_set, stage2_node_set, newadd_set, remain_set,
        retire_set)``.
    """
    if dataset_name in ['EPeMS']:
        node_division_cache_path = f'datasets/{dataset_name}/{dataset_name}_node_division.npz'
        description_txt_path = f'datasets/{dataset_name}/{dataset_name}_node_division_description.txt'
    else:
        node_division_cache_path = f'datasets/{dataset_name}/{dataset_name}_stage_node_division_seed{seed}_newAdd{new_add_rate}_retire{retire_rate}_remain.npz'
        description_txt_path = f'datasets/{dataset_name}/{dataset_name}_node_division_seed{seed}_newAdd{new_add_rate}_retire{retire_rate}_remain_description.txt'

    if not os.path.exists(node_division_cache_path):
        # Nothing cached yet: draw a fresh split and persist it.
        stage1_node_set, stage2_node_set, newadd_set, remain_set, retire_set = build_stage_node_set(all_num_nodes, new_add_rate, retire_rate, seed)
        np.savez_compressed(
            node_division_cache_path,
            stage1_node_set=stage1_node_set,
            stage2_node_set=stage2_node_set,
            newadd_set=newadd_set,
            remain_set=remain_set,
            retire_set=retire_set,
        )
    else:
        cached = np.load(node_division_cache_path)
        print(f'{dataset_name}: node division loaded!')
        stage1_node_set = cached['stage1_node_set']
        stage2_node_set = cached['stage2_node_set']
        newadd_set = cached['newadd_set']
        remain_set = cached['remain_set']
        retire_set = cached['retire_set']

    if not os.path.exists(description_txt_path):
        report = [
            f'node_num: {len(stage1_node_set)} -> {len(stage2_node_set)}({len(stage1_node_set)} - {len(retire_set)} + {len(newadd_set)})\n',
            f'stage1_node_idx: len={len(stage1_node_set)}\n',
            f'\t{stage1_node_set}\n',
            f'stage2_node_idx: len={len(stage2_node_set)}\n',
            f'\t{stage2_node_set}\n',
            f'remain_in_stage2_node_idx: len={len(remain_set)}\n',
            f'\t{remain_set}\n',
            f'retire_in_stage2_node_idx: len={len(retire_set)}\n',
            f'\t{retire_set}\n',
            f'newadd_in_stage2_node_idx: len={len(newadd_set)}\n',
            f'\t{newadd_set}\n',
        ]
        with open(description_txt_path,'w') as f:
            f.write(''.join(report))

    return stage1_node_set, stage2_node_set, newadd_set, remain_set, retire_set



def get_full_scaler_info(full_data, stage2_start_idx, test_start_idx, remain_idx, retire_idx, newadd_idx, channel_wise_norm = False):
    """Compute normalisation statistics from the rows before ``test_start_idx``.

    ``full_data`` is a 2-D array (rows x nodes). Each node group contributes
    only the rows it is read from here:

    * remaining nodes: rows ``[0, test_start_idx)``
    * newly added nodes: rows ``[stage2_start_idx, test_start_idx)``
    * retired nodes: rows ``[0, stage2_start_idx)``

    With ``channel_wise_norm`` the final mean/std are per-node arrays of
    length ``N`` (nodes in none of the groups stay at 0); otherwise they are
    scalars over all contributing values pooled together. The stage-1 and
    stage-2 statistics are always scalars over that stage's rows and nodes.

    Returns:
        ``(final_mean, final_std, stage1_mean, stage1_std, stage2_mean,
        stage2_std)``.
    """
    _, num_nodes = full_data.shape
    has_retired = len(retire_idx) > 0

    remain_block = full_data[0:test_start_idx, remain_idx]
    newadd_block = full_data[stage2_start_idx:test_start_idx, newadd_idx]
    retire_block = full_data[0:stage2_start_idx, retire_idx] if has_retired else None

    if channel_wise_norm:
        final_mean = np.zeros((num_nodes,),dtype=np.float64)
        final_std = np.zeros((num_nodes,),dtype=np.float64)
        # Written in the order remain -> newadd -> retire.
        final_mean[remain_idx] = remain_block.mean(axis=0)
        final_std[remain_idx] = remain_block.std(axis=0)
        final_mean[newadd_idx] = newadd_block.mean(axis=0)
        final_std[newadd_idx] = newadd_block.std(axis=0)
        if has_retired:
            final_mean[retire_idx] = retire_block.mean(axis=0)
            final_std[retire_idx] = retire_block.std(axis=0)
    else:
        pieces = [remain_block.flatten(), newadd_block.flatten()]
        if has_retired:
            pieces.append(retire_block.flatten())
        pooled = np.concatenate(pieces, axis=0)
        final_mean = pooled.mean()
        final_std = pooled.std()

    # Node sets that are present during each stage.
    if has_retired:
        stage1_node_idx = sorted(np.concatenate([remain_idx, retire_idx], axis=0).tolist())
    else:
        stage1_node_idx = remain_idx
    stage2_node_idx = sorted(np.concatenate([remain_idx, newadd_idx], axis=0).tolist())

    stage1_data = full_data[0:stage2_start_idx, stage1_node_idx].flatten()
    stage2_data = full_data[stage2_start_idx:test_start_idx, stage2_node_idx].flatten()
    stage1_mean, stage1_std = stage1_data.mean(), stage1_data.std()
    stage2_mean, stage2_std = stage2_data.mean(), stage2_data.std()

    return final_mean, final_std, stage1_mean, stage1_std, stage2_mean, stage2_std


def load_stage_dataloader(dataset_name, full_data, full_time_data, stage2_start_idx, test_start_idx, stage1_node_idx, stage2_node_idx,
                               stage='full', scaler=None, batch_size=32, valid_mask=None, input_length =12, predict_length =12,
                               channel_wise_norm=False, norm_y = False, scaler_mean= None, scaler_std = None, scaler_type='zscore'):
    """Build train/val(/test) DataLoaders for one stage of the data.

    ``full_data`` is unpacked as a 3-D array ``(L, N, C)``. The stage decides
    which rows and nodes are used and where the splits fall:

    * ``'full'``: all rows and nodes; train/val/test split at
      ``stage2_start_idx`` and ``test_start_idx``.
    * ``'stage1'``: rows ``[0, test_start_idx)`` of ``stage1_node_idx``;
      train/val split at ``stage2_start_idx``; no test split.
    * ``'stage2'``: rows ``[stage2_start_idx, L)`` of ``stage2_node_idx``;
      the splits are a dataset-dependent number of days, where a day is
      ``full_time_data[..., 0].max() + 1`` rows.

    Each split is cut into sliding windows (stride 1) of
    ``input_length + predict_length`` rows and the inputs are normalised with
    ``scaler`` when given, otherwise with a scaler built here. Targets are
    normalised only when ``norm_y`` is set. ``scaler_mean``/``scaler_std``
    are honoured only with ``channel_wise_norm``. ``scaler_type`` is accepted
    but not used here.

    Returns:
        ``(dataloader, self_scaler, train_indices)``: a dict with
        ``'train_loader'``, ``'val_loader'`` and, when a test split exists,
        ``'test_loader'``; the scaler built here (returned even when an
        external ``scaler`` was applied); and ``list(range(0, split_line1))``.
        Test batches carry a fourth boolean tensor (the target validity mask)
        when ``valid_mask`` is given.

    Raises:
        ValueError: unknown ``stage``, or a split shorter than one window.
    """
    L, N, C = full_data.shape
    print('full_data.shape:',full_data.shape)
    if stage == "full":
        time_start, time_end = 0, L
        split_line1, split_line2 = stage2_start_idx, test_start_idx
        node_idx = list(range(N))
    elif stage == "stage1":
        time_start, time_end = 0, test_start_idx
        split_line1, split_line2 = stage2_start_idx, test_start_idx
        node_idx = stage1_node_idx
    elif stage == "stage2":
        time_start, time_end = stage2_start_idx, L
        slices_per_day = full_time_data[...,0].max() + 1
        print('slices_per_day:', slices_per_day)
        if dataset_name in ['EElectricity','EWeather']:
            train_days = 7
            valid_days = 3
        elif dataset_name in ['EPeMS']:
            train_days = 3
            valid_days = 2
        else:
            train_days = 6
            valid_days = 1
        split_line1 = slices_per_day * train_days
        split_line2 = slices_per_day * (train_days + valid_days)
        node_idx = stage2_node_idx
    else:
        # Previously an unknown stage surfaced as a NameError on split_line1.
        raise ValueError(f"unknown stage {stage!r}; expected 'full', 'stage1' or 'stage2'")

    train_indices = list(range(0,split_line1))

    ready_data = full_data[time_start:time_end, node_idx]
    ready_time_data = full_time_data[time_start:time_end]
    ready_valid_mask = valid_mask[time_start:time_end, node_idx] if valid_mask is not None else None

    train_data = ready_data[0:split_line1]
    train_time_data = ready_time_data[0:split_line1]
    train_valid_mask = ready_valid_mask[0:split_line1] if valid_mask is not None else None

    val_data = ready_data[split_line1:split_line2]
    val_time_data = ready_time_data[split_line1:split_line2]
    val_valid_mask = ready_valid_mask[split_line1:split_line2] if valid_mask is not None else None

    if split_line2 < (time_end-time_start):
        test_data = ready_data[split_line2:]
        test_time_data = ready_time_data[split_line2:]
        test_valid_mask = ready_valid_mask[split_line2:] if valid_mask is not None else None
    else: # for stage1
        test_data = None
        test_time_data = None
        test_valid_mask = None

    if channel_wise_norm:
        if scaler_mean is None and scaler_std is None:
            scaler_mean = train_data.mean(axis=0)
            scaler_std = train_data.std(axis=0)
        self_scaler = StandardScaler(mean=scaler_mean, std=scaler_std, channel_wise=channel_wise_norm, stage1_node_idx=stage1_node_idx, stage2_node_idx=stage2_node_idx)
    else:
        # Global statistics are always re-estimated from the training split.
        scaler_mean = train_data.mean()
        scaler_std = train_data.std()
        self_scaler = StandardScaler(mean=scaler_mean, std=scaler_std, channel_wise=channel_wise_norm)

    def generate_samples(data, time_data, valid_mask, input_length, predict_length):
        """Window one split and separate the input part from the target part."""
        window_size = input_length + predict_length
        length = data.shape[0]
        if length < window_size:
            raise ValueError(f"data length {length} is less than window size {window_size}")
        starts = range(0, length - window_size + 1)

        def split_windows(series):
            windows = np.stack([series[i:i+window_size] for i in starts], axis=0)  # (n, window_size, ...)
            print(windows.shape)
            head = windows[:, 0:input_length]
            tail = windows[:, input_length:] if predict_length>0 else head
            return head, tail

        x_data, y_data = split_windows(data)
        x_time, y_time = split_windows(time_data)
        if valid_mask is not None:
            x_valid_mask, y_valid_mask = split_windows(valid_mask)
        else:
            x_valid_mask = None
            y_valid_mask = None
        return x_data, x_time, y_data, y_time, x_valid_mask, y_valid_mask

    x_train, x_train_time, y_train, y_train_time, x_train_vmask, y_train_vmask = generate_samples(train_data, train_time_data, train_valid_mask, input_length, predict_length)
    x_val, x_val_time, y_val, y_val_time, x_val_vmask, y_val_vmask  = generate_samples(val_data, val_time_data, val_valid_mask, input_length, predict_length)
    print(f"Trainset:\tx-{x_train.shape} x_time-{x_train_time.shape}\ty-{y_train.shape}")
    print(f"Valset:  \tx-{x_val.shape} x_time-{x_val_time.shape}\ty-{y_val.shape}")

    # Defined up front so every later reference is bound even without a test split.
    x_test = x_test_time = y_test = y_test_vmask = None
    if test_data is not None:
        # The input-side mask used to be unpacked into `x_val_vmask` by mistake.
        x_test, x_test_time, y_test, y_test_time, x_test_vmask, y_test_vmask  = generate_samples(test_data, test_time_data, test_valid_mask, input_length, predict_length)
        print(f"Testset:\tx-{x_test.shape} x_time-{x_test_time.shape}\ty-{y_test.shape}")

    # An externally supplied scaler takes precedence over the one built above.
    active_scaler = self_scaler if scaler is None else scaler
    x_train = active_scaler.transform(x_train, stage=stage)
    x_val = active_scaler.transform(x_val, stage=stage)
    if test_data is not None:
        x_test = active_scaler.transform(x_test, stage=stage)
    if norm_y:
        y_train = active_scaler.transform(y_train, stage=stage)
        y_val = active_scaler.transform(y_val, stage=stage)
        if test_data is not None:
            y_test = active_scaler.transform(y_test, stage=stage)

    trainset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_train), torch.LongTensor(x_train_time), torch.FloatTensor(y_train)
    )
    valset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_val), torch.LongTensor(x_val_time), torch.FloatTensor(y_val)
    )
    dataloader = {
        'train_loader': torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True),
        'val_loader': torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False),
    }

    if x_test is not None:
        test_tensors = [torch.FloatTensor(x_test), torch.LongTensor(x_test_time), torch.FloatTensor(y_test)]
        if test_valid_mask is not None:
            test_tensors.append(torch.BoolTensor(y_test_vmask))
        testset = torch.utils.data.TensorDataset(*test_tensors)
        dataloader['test_loader'] = torch.utils.data.DataLoader(
            testset, batch_size=batch_size, shuffle=False
        )

    return dataloader, self_scaler, train_indices
