import os
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
#from tools.tool import StandardScaler
import warnings
import math
import glob
from scipy.io import arff
from statsmodels.tsa.stattools import adfuller
warnings.filterwarnings('ignore')

class Dataset_forecast(Dataset):
    """Sliding-window forecasting dataset backed by a single CSV file.

    The rows are split chronologically 70% / 10% / 20% into train /
    validation / test.  Each sample is a pair ``(seq_x, seq_y)`` where
    ``seq_x`` holds ``pred_len`` consecutive rows and ``seq_y`` holds the
    ``pred_len`` rows that immediately follow (the input length is tied to
    the prediction length).

    Args:
        root_path: Directory that contains the CSV file.
        tasks: Stored on the instance as ``self.tasks``; not used here.
        dataset: Accepted for signature compatibility with the sibling
            dataset classes; not used here.
        file_name: CSV file name, relative to ``root_path``.
        flag: One of ``'TRAIN'``, ``'VAL'``, ``'TEST'``.
        features: ``'S'`` keeps only the ``target`` column; ``'M'`` and
            ``'MS'`` keep every column except the first one (presumably the
            date column -- confirm against the data files).
        target: Column used when ``features == 'S'``.
        pred_len: Length of both the input window and the forecast window.
        scale: When true, standardize with statistics fitted on the
            training rows only.
    """

    def __init__(self, root_path, tasks,  dataset, file_name = 'ETTh1.csv', flag = 'TRAIN',
                 features='S', target='OT', pred_len = 96, scale=True):
        # The history window deliberately has the same length as the horizon.
        self.seq_len = pred_len
        self.pred_len = pred_len

        assert flag in ['TRAIN', 'TEST', 'VAL']
        type_map = {'TRAIN': 0, 'VAL': 1, 'TEST': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.tasks = tasks

        self.root_path = root_path
        self.data_path = file_name
        self.__read_data__()

    def __read_data__(self):
        """Load the CSV, pick the feature columns, scale, and slice the split.

        Raises:
            ValueError: If ``self.features`` is not ``'S'``, ``'M'`` or ``'MS'``.
        """
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

        # Chronological 7:1:2 split; validation absorbs the rounding remainder.
        length = len(df_raw)
        num_train = int(length * 0.7)
        num_test = int(length * 0.2)
        num_vali = length - num_train - num_test

        # VAL and TEST start ``seq_len`` rows early so that their first
        # forecast window still has a full history window in front of it.
        border1s = [0, num_train - self.seq_len, num_train + num_vali - self.seq_len]
        border2s = [num_train, num_train + num_vali, num_train + num_vali + num_test]
        # Clamp at 0: a negative start would wrap around to the end of the
        # array when used as a slice bound on very short files.
        border1 = max(0, border1s[self.set_type])
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            # Previously this fell through and crashed with an
            # UnboundLocalError on ``df_data`` a few lines later.
            raise ValueError(
                "features must be one of 'S', 'M', 'MS'; got {!r}".format(self.features)
            )

        if self.scale:
            # Fit on training rows only so no validation/test statistics leak.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]

    def __getitem__(self, index):
        """Return ``(history, future)`` windows starting at row ``index``."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]

        return seq_x, seq_y

    def __len__(self):
        """Number of complete (history, future) pairs; never negative."""
        # ``len()`` raises ValueError on a negative result, so a split that is
        # shorter than one full sample must report 0 instead.
        return max(0, len(self.data_x) - self.seq_len - self.pred_len + 1)

    def inverse_transform(self, data):
        """Undo the standardization applied in ``__read_data__``."""
        return self.scaler.inverse_transform(data)

def gen_ano_train_data(all_train_data):
    """Stack a mapping of variable-length series into one padded array.

    Every value of ``all_train_data`` is passed through
    ``pad_nan_to_target`` with the length of the longest series as the
    target along axis 0 (the helper is not defined in this part of the
    file; presumably it pads with NaN -- confirm).  The padded series are
    stacked in mapping iteration order and a new axis is inserted at
    position 2.

    Args:
        all_train_data: Mapping from a key to a sequence supporting ``len``.

    Returns:
        The stacked array with an extra axis inserted at index 2.
    """
    series_list = [all_train_data[key] for key in all_train_data]
    longest = np.max([len(series) for series in series_list])
    padded = [pad_nan_to_target(series, longest, axis=0) for series in series_list]
    stacked = np.stack(padded)
    return np.expand_dims(stacked, 2)

class Dataset_kpi(Dataset):
    """Windowed anomaly-detection dataset read from the KPI phase-2 files.

    ``'TRAIN'`` and ``'VAL'`` read ``phase2_train.csv``; any other flag reads
    ``phase2_ground_truth.hdf``.  Rows are grouped by ``'KPI ID'``, each
    group's ``value`` column is standardized on its own, and fixed-length
    windows of ``(value, label)`` are cut from every group.  Training and
    validation windows overlap (stride ``window_size // 10``); test windows
    do not (stride ``window_size``).  The first 80% of the training-file
    windows form the TRAIN split and the rest the VAL split.

    Args:
        root_path: Directory that contains the two phase-2 files.
        tasks: Stored on the instance as ``self.tasks``; not used here.
        dataset: Accepted for signature compatibility; not used here.
        file_name: Stored as ``self.data_path``; the file names actually read
            are fixed (see above).
        flag: ``'TRAIN'``, ``'VAL'``, or anything else for the test split.
        features: Accepted for signature compatibility; not used here.
        target: Accepted for signature compatibility; not used here.
        window_size: Number of time steps per window.
        scale: Stored as ``self.scale``.  NOTE: values are always
            standardized regardless of this flag.
    """

    def __init__(self, root_path, tasks,  dataset, file_name = 'ETTh1.csv', flag = 'TRAIN',
                 features='S', target='OT', window_size = 128, scale=True):
        self.window_size = window_size
        self.flag = flag
        self.scale = scale
        self.tasks = tasks

        self.root_path = root_path
        self.data_path = file_name
        self.__read_data__()

    def __read_data__(self):
        """Read the split's file and build the per-KPI window lists."""
        self.scaler = StandardScaler()
        if self.flag == 'TRAIN' or self.flag == 'VAL':
            df_raw = pd.read_csv(os.path.join(self.root_path, 'phase2_train.csv'))
            # Overlapping windows for training; at least 1 so that a
            # window_size below 10 does not produce a zero range step.
            stride = max(1, self.window_size // 10)
        else:
            df_raw = pd.read_hdf(os.path.join(self.root_path, 'phase2_ground_truth.hdf'))
            # Non-overlapping windows for evaluation.
            stride = self.window_size
        df_raw = df_raw.set_index(['KPI ID', 'timestamp']).sort_index()

        self.x = []
        self.y = []
        self.group = 0  # number of distinct KPI series seen
        for _, df in df_raw.groupby(level=0):
            data = np.expand_dims(df['value'].to_numpy(), -1)
            label = df['label'].to_numpy()
            # The scaler is refitted for every KPI, so each series is
            # normalized with its own mean and standard deviation.
            self.scaler.fit(data)
            value = self.scaler.transform(data)
            self.group += 1
            for i in range(0, len(data) - self.window_size, stride):
                self.x.append(value[i:i + self.window_size])
                self.y.append(label[i:i + self.window_size])

        split = int(len(self.x) * 0.8)
        if self.flag == 'TRAIN':
            self.data_x = self.x[:split]
            self.data_y = self.y[:split]
        elif self.flag == 'VAL':
            self.data_x = self.x[split:]
            self.data_y = self.y[split:]
        else:
            self.data_x = self.x
            self.data_y = self.y

    def __getitem__(self, index):
        """Return the ``index``-th ``(values, labels)`` window pair."""
        seq_x = self.data_x[index]
        seq_y = self.data_y[index]
        return seq_x, seq_y

    def __len__(self):
        """Number of windows in this split."""
        # Every stored window is full length, so none has to be withheld;
        # the former ``- 1`` dropped the last sample and returned -1 when
        # the split was empty.
        return len(self.data_x)

    def inverse_transform(self, data):
        """Undo standardization using the scaler's current statistics.

        The scaler is refitted per KPI while loading, so after construction
        it only holds the statistics of the last KPI group processed.
        """
        return self.scaler.inverse_transform(data)

class Dataset_yahoo(Dataset):
    """Windowed anomaly-detection dataset built from the A1-A4 benchmark CSVs.

    Files are discovered under ``root_path`` in the ``A1Benchmark`` ..
    ``A4Benchmark`` folders.  Each file is split chronologically
    70% / 10% / 20% and cut into windows of ``window_size`` steps:
    overlapping for train and validation (stride ``window_size // 10``),
    non-overlapping for test.  Windows from all files are pooled per split.

    Args:
        root_path: Directory that contains the four benchmark folders.
        tasks: Stored on the instance as ``self.tasks``; not used here.
        dataset: Accepted for signature compatibility; not used here.
        file_name: Stored as ``self.data_path``; not used for loading.
        flag: ``'TRAIN'``, ``'VAL'``, or anything else for the test split.
        features: Accepted for signature compatibility; not used here.
        target: Accepted for signature compatibility; not used here.
        window_size: Number of time steps per window.
        scale: When true, standardize with a scaler fitted on the pooled
            training windows.
    """

    def __init__(self, root_path, tasks,  dataset, file_name = 'ETTh1.csv', flag = 'TRAIN',
                 features='S', target='OT', window_size = 128, scale=True):
        self.window_size = window_size
        self.flag = flag
        self.scale = scale
        self.tasks = tasks

        self.root_path = root_path
        self.data_path = file_name
        self.__read_data__()

    def _sorted_files(self, pattern):
        """Return the paths under ``root_path`` matching ``pattern``, sorted."""
        return sorted(glob.glob(os.path.join(self.root_path, pattern)))

    def __read_data__(self):
        """Window every benchmark file and keep the split selected by ``flag``."""
        self.scaler = StandardScaler()

        # A1/A2 name their label column 'is_anomaly', A3/A4 name it 'anomaly'.
        sources = [
            (self._sorted_files('A1Benchmark/real_*.csv'), 'is_anomaly'),
            (self._sorted_files('A2Benchmark/synthetic_*.csv'), 'is_anomaly'),
            (self._sorted_files('A3Benchmark/A3Benchmark-TS*.csv'), 'anomaly'),
            (self._sorted_files('A4Benchmark/A4Benchmark-TS*.csv'), 'anomaly'),
        ]

        win = self.window_size
        # At least 1 so a window_size below 10 does not give a zero range step.
        dense_step = max(1, win // 10)

        train_data, train_labels = [], []
        val_data, val_labels = [], []
        test_data, test_labels = [], []

        for files, label_col in sources:
            for fn in files:
                df = pd.read_csv(fn)
                value = df['value'].values
                label = df[label_col].values

                length = len(value)
                num_train = int(length * 0.7)
                num_test = int(length * 0.2)
                num_vali = length - num_train - num_test

                # VAL and TEST start one window early so their first window
                # ends right after the preceding split.
                val_start = num_train - win
                test_start = num_train + num_vali - win
                spans = (
                    (range(0, val_start, dense_step), train_data, train_labels),
                    (range(val_start, test_start, dense_step), val_data, val_labels),
                    (range(test_start, length - win, win), test_data, test_labels),
                )
                for starts, data_out, label_out in spans:
                    for i in starts:
                        data_out.append(value[i:i + win])
                        label_out.append(label[i:i + win])

        if self.flag == 'TRAIN':
            self.data_x = train_data
            self.data_y = train_labels
        elif self.flag == 'VAL':
            self.data_x = val_data
            self.data_y = val_labels
        else:
            self.data_x = test_data
            self.data_y = test_labels

        if self.scale:
            # The scaler sees a (n_windows, window_size) matrix, so the
            # statistics are computed per position inside the window.
            self.scaler.fit(train_data)
            self.data_x = self.scaler.transform(self.data_x)
        # Append a trailing axis of size 1 to every window.
        self.data_x = np.expand_dims(self.data_x, -1)

    def __getitem__(self, index):
        """Return the ``index``-th ``(values, labels)`` window pair."""
        seq_x = self.data_x[index]
        seq_y = self.data_y[index]
        return seq_x, seq_y

    def __len__(self):
        """Number of windows in this split."""
        # Every stored window is full length; the former ``- 1`` silently
        # dropped the last sample and returned -1 for an empty split.
        return len(self.data_x)

    def inverse_transform(self, data):
        """Undo the standardization fitted on the training windows."""
        return self.scaler.inverse_transform(data)

class Dataset_classifiction_multi(Dataset):
    """Multivariate time-series classification dataset read from ARFF files.

    ``'TRAIN'`` and ``'VAL'`` both read ``<file_name>_TRAIN.arff`` (first 80%
    of the rows for TRAIN, the rest for VAL); any other flag reads
    ``<file_name>_TEST.arff`` in full.  The last column is the class label;
    labels are mapped to consecutive integers in ``np.unique`` order.

    Args:
        root_path: Base data directory.
        tasks: Stored on the instance as ``self.tasks``; not used here.
        dataset: Sub-directory of ``root_path`` that holds the ARFF files.
        file_name: File-name prefix, without the ``_TRAIN``/``_TEST`` suffix.
        flag: ``'TRAIN'``, ``'VAL'``, or anything else for the test split.
        features: Accepted for signature compatibility; not used here.
        target: Accepted for signature compatibility; not used here.
        pred_len: Accepted for signature compatibility; not used here.
        scale: When true, each sample is z-scored on access.
    """

    def __init__(self, root_path, tasks, dataset, file_name='ETTh1', flag='TRAIN', features = 'M', target = 'OT', pred_len = 96, scale=True):
        self.flag = flag
        self.scale = scale
        self.tasks = tasks
        self.root_path = root_path
        self.data_path = dataset
        self.file_name = file_name
        self.__read_data__()

    def __read_data__(self):
        """Load the ARFF file for this split and build the label mapping."""
        # Created for interface parity with the other datasets; it is never
        # fitted in this class (scaling happens per sample in __getitem__).
        self.scaler = StandardScaler()
        # Validation data is carved out of the TRAIN file.
        if self.flag == 'TRAIN' or self.flag == 'VAL':
            string = 'TRAIN'
        else:
            string = 'TEST'

        raw, _ = arff.loadarff(os.path.join(self.root_path, self.data_path + '/' + self.file_name+ '_' + string + '.arff'))
        df_raw = pd.DataFrame(raw)

        cols = list(df_raw.columns)
        cols_data = df_raw[cols[:-1]]
        # Only the first data cell of each row is read; it is expected to be
        # a nested collection with one record per dimension.
        data = np.array([np.array([np.array([*v]) for v in values[0]]) for values in cols_data.values])
        # Swap the last two axes -- presumably (sample, dim, time) ->
        # (sample, time, dim); confirm against the model's expected layout.
        data = data.transpose(0, 2, 1)
        cols_target = df_raw[cols[-1]]

        split = int(0.8 * len(data))
        if self.flag == 'TRAIN':
            self.data_x = data[0:split]
            self.data_y = cols_target.values[0:split]
        elif self.flag == 'VAL':
            self.data_x = data[split:]
            self.data_y = cols_target.values[split:]
        else:
            self.data_x = data
            self.data_y = cols_target.values

        # The mapping is built from every label in the loaded file, not just
        # from the rows kept for this split.
        arr, index = np.unique(cols_target.values, return_counts=True)
        self.nb_classes = len(index)
        dic = [i for i in range(len(arr))]
        self.dict = dict(zip(arr, dic))

    def __getitem__(self, index):
        """Return ``(sample, class_index)`` for row ``index``.

        Samples longer than 1000 steps along the first axis are truncated.
        With ``scale`` enabled the sample is z-scored using its own mean and
        standard deviation taken over all elements.
        """
        x = self.data_x[index]
        if x.shape[0]>1000:
            x = x[0:1000]
        if self.scale:
            std = np.std(x)
            # A constant sample has std == 0; dividing by it turned the
            # whole sample into NaN.  Leave such samples merely centred.
            x = (x - np.mean(x)) / (std if std > 0 else 1.0)
        y = self.data_y[index]
        y = self.dict[y]

        return x, y

    def __len__(self):
        """Number of samples in this split."""
        return len(self.data_x)

    def inverse_transform(self, data):
        """Delegate to the scaler's ``inverse_transform``.

        NOTE: this class never fits ``self.scaler``, so this call is expected
        to fail unless a caller fits the scaler first.
        """
        return self.scaler.inverse_transform(data)
    
class Dataset_classifiction(Dataset):
    """Univariate time-series classification dataset read from ARFF files.

    ``'TRAIN'`` and ``'VAL'`` share ``<file_name>_TRAIN.arff`` (first 80% of
    the rows versus the remainder); any other flag uses
    ``<file_name>_TEST.arff`` in full.  All columns but the last are cast to
    float features and the last column is the float class label.  When
    ``scale`` is true the features are standardized with a scaler fitted on
    the complete TRAIN file.
    """

    def __init__(self, root_path, tasks, dataset, file_name='ETTh1', flag='TRAIN', features = 'M', target = 'OT', pred_len = 96, scale=True):
        # ``features``, ``target`` and ``pred_len`` are accepted but unused.
        self.flag = flag
        self.scale = scale
        self.tasks = tasks
        self.root_path = root_path
        self.data_path = dataset
        self.file_name = file_name
        self.__read_data__()

    def _load_split(self, split_name):
        """Read ``<file_name>_<split_name>.arff`` into a DataFrame."""
        relative = self.data_path + '/' + self.file_name + '_' + split_name + '.arff'
        records, _ = arff.loadarff(os.path.join(self.root_path, relative))
        return pd.DataFrame(records)

    def __read_data__(self):
        """Load the split, optionally standardize it, and map labels to ints."""
        self.scaler = StandardScaler()

        from_train_file = self.flag == 'TRAIN' or self.flag == 'VAL'
        # The TRAIN file is always read: it is either the data itself or the
        # reference the scaler is fitted on for the test split.
        train_frame = self._load_split('TRAIN')
        frame = train_frame if from_train_file else self._load_split('TEST')

        columns = list(frame.columns)
        feature_cols, label_col = columns[:-1], columns[-1]
        reference = train_frame[feature_cols].astype(float)
        feature_frame = frame[feature_cols].astype(float)
        targets = frame[label_col].astype(float).values

        if self.scale:
            self.scaler.fit(reference.values)
            data = self.scaler.transform(feature_frame.values)
        else:
            data = feature_frame.values

        cut = int(0.8 * len(data))
        if self.flag == 'TRAIN':
            self.data_x, self.data_y = data[0:cut], targets[0:cut]
        elif self.flag == 'VAL':
            self.data_x, self.data_y = data[cut:], targets[cut:]
        else:
            self.data_x, self.data_y = data, targets

        # Class ids follow np.unique's sorted order over every label in the
        # loaded file (not only the rows kept for this split).
        classes, counts = np.unique(targets, return_counts=True)
        self.nb_classes = len(counts)
        self.dict = dict(zip(classes, range(len(classes))))

    def __getitem__(self, index):
        """Return ``(series, class_index)``; series gets a trailing unit axis."""
        series = self.data_x[index]
        # Cap the series at its first 1000 steps.
        series = series[0:1000] if series.shape[0] > 1000 else series
        class_index = self.dict[self.data_y[index]]
        return np.expand_dims(series, -1), class_index

    def __len__(self):
        """Number of samples in this split."""
        return len(self.data_x)

    def inverse_transform(self, data):
        """Map standardized features back through the fitted scaler."""
        return self.scaler.inverse_transform(data)
