# import pickle
# from torch.utils.data import DataLoader, Dataset
# import pandas as pd
# import numpy as np
# import torch
# from gluonts.dataset.repository.datasets import get_dataset
# from gluonts.dataset.multivariate_grouper import MultivariateGrouper

# def NormMinMax(data):
#     """Min-Max Normalizer.

#     Args:
#       - data: raw data

#     Returns:
#       - norm_data: normalized data
#       - min_val: minimum values (for renormalization)
#       - max_val: maximum values (for renormalization)
#     """

#     min_val = np.min(data, axis=0)
#     data = data - min_val

#     max_val = np.max(data, axis=0)
#     norm_data = data / (max_val + 1e-7)

#     return norm_data, min_val, max_val

# class Forecasting_Dataset(Dataset):
#     def __init__(self, datatype, mode="train", history_length=72):
            
#         self.val_sample = 5
#         if datatype=='electricity':
#             dataset = get_dataset("electricity_nips", regenerate=False)
#             # 168
#             self.history_length = history_length
#             self.pred_length = 24
#             self.test_sample = 7
#         elif datatype=='exchange':
#             # check this one
#             dataset = get_dataset("exchange_rate_nips", regenerate=False)
#             self.history_length = 90
#             self.pred_length = 30
#             self.test_sample = 7
#         elif datatype=='traffic':
#             dataset = get_dataset("traffic_nips", regenerate=False)
#             self.history_length = history_length
#             self.pred_length = 24
#             self.test_sample = 7
#         elif datatype=='solar':
#             dataset = get_dataset("solar_nips", regenerate=False)
#             self.history_length = history_length
#             self.pred_length = 24
#             self.test_sample = 7
#         elif datatype=='wiki':    
#             dataset = get_dataset("wiki2000_nips", regenerate=False)
#             self.history_length = 90
#             self.pred_length = 30
#             self.test_sample = 5
#         elif datatype=='taxi':    
#             dataset = get_dataset("taxi_30min", regenerate=False)
#             self.history_length = 48
#             self.pred_length = 24
#             self.test_sample = 5
        
#         self.test_length= self.pred_length*self.test_sample
#         self.valid_length = self.pred_length*self.val_sample
        
#         self.seq_length = self.history_length + self.pred_length
            
#         train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))
#         dataset_train = train_grouper(dataset.train)
#         self.main_data = np.transpose(dataset_train[0]['target'])
        
#         self.mean_data = np.mean(self.main_data, axis=0)
#         self.std_data = np.std(self.main_data, axis=0)
#         self.std_data = np.clip(self.std_data,a_min=0.001,a_max=1e+7)
#         self.mask_data = np.ones_like(self.main_data)
        
#         if datatype == 'electricity':
#             datafolder = './data/electricity_nips'
#             self.test_length= 24*7
#             self.valid_length = 24*5
                        
#             paths=datafolder+'/data.pkl' 
#             #shape: (T x N)
#             #mask_data is usually filled by 1
#             with open(paths, 'rb') as f:
#                 self.main_data, self.mask_data = pickle.load(f)
#             paths=datafolder+'/meanstd.pkl'
#             with open(paths, 'rb') as f:
#                 self.mean_data, self.std_data = pickle.load(f)
        
#         self.main_data = (self.main_data - self.mean_data) / self.std_data
#         total_length = len(self.main_data)
        
#         if mode == 'train': 
#             start = 0
#             end = total_length - self.seq_length - self.valid_length - self.test_length + 1
#             self.use_index = np.arange(start,end,1)
#         if mode == 'valid': #valid
#             start = total_length - self.seq_length - self.valid_length - self.test_length + self.pred_length
#             end = total_length - self.seq_length - self.test_length + self.pred_length
#             self.use_index = np.arange(start,end,self.pred_length)
#         if mode == 'test': #test
#             start = total_length - self.seq_length - self.test_length + self.pred_length
#             end = total_length - self.seq_length + self.pred_length
#             self.use_index = np.arange(start,end,self.pred_length)
        
#     def __getitem__(self, orgindex):
#         index = self.use_index[orgindex]
#         target_mask = self.mask_data[index:index+self.seq_length].copy()
#         target_mask[-self.pred_length:] = 0. #pred mask for test pattern strategy
#         s = {
#             'observed_data': self.main_data[index:index+self.seq_length],
#             'observed_mask': self.mask_data[index:index+self.seq_length],
#             'gt_mask': target_mask,
#             'timepoints': np.arange(self.seq_length) * 1.0, 
#             'feature_id': np.arange(self.main_data.shape[1]) * 1.0, 
#         }

#         return s
#     def __len__(self):
#         return len(self.use_index)

# def get_dataloader(datatype,device,batch_size=8,history_length=168):
#     dataset = Forecasting_Dataset(datatype,mode='train',history_length=history_length)
#     train_loader = DataLoader(
#         dataset, batch_size=batch_size, shuffle=1)
#     valid_dataset = Forecasting_Dataset(datatype,mode='valid',history_length=history_length)
#     valid_loader = DataLoader(
#         valid_dataset, batch_size=batch_size, shuffle=0)
#     test_dataset = Forecasting_Dataset(datatype,mode='test',history_length=history_length)
#     test_loader = DataLoader(
#         test_dataset, batch_size=batch_size, shuffle=0)

#     scaler = torch.from_numpy(dataset.std_data).to(device).float()
#     mean_scaler = torch.from_numpy(dataset.mean_data).to(device).float()

#     return train_loader, valid_loader, test_loader, scaler, mean_scaler

import pickle
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
import torch
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.multivariate_grouper import MultivariateGrouper

def NormMinMax(data):
    """Min-max scale ``data`` along axis 0.

    Args:
      - data: raw data (anything ``np.min`` / ``np.max`` accept)

    Returns:
      - norm_data: ``(data - min) / (range + 1e-7)``
      - min_val: minimum of the raw data along axis 0
      - max_val: maximum of the *min-shifted* data along axis 0, i.e. the
        range ``max - min`` (this, together with ``min_val``, is what is
        needed to undo the scaling)
    """
    lowest = np.min(data, axis=0)
    shifted = data - lowest
    spread = np.max(shifted, axis=0)

    # The small epsilon keeps the division finite when a column is constant.
    return shifted / (spread + 1e-7), lowest, spread

class Forecasting_Dataset(Dataset):
    """Sliding-window forecasting dataset built on GluonTS repository datasets.

    The grouped training series and the rolling test windows are stitched into
    one time-major array, standardised with the mean / std taken along axis 0,
    and then cut into windows of ``seq_length = history_length + pred_length``
    consecutive steps. For ``datatype == 'electricity'`` the data, mask, mean
    and std are instead read from pickles under ``./data/electricity_nips``.

    Args:
        datatype: one of the keys of ``_DATASET_CONFIG``.
        mode: ``'train'``, ``'valid'`` or ``'test'``; selects which window
            start indices are exposed through ``__getitem__``.
        history_length: number of conditioning steps. Ignored for ``'wiki'``
            and ``'taxi'``, which use fixed history lengths.

    Raises:
        ValueError: if ``datatype`` or ``mode`` is not recognised.

    Each item is a dict with keys ``observed_data``, ``observed_mask``,
    ``gt_mask`` (a copy of the observed mask whose last ``pred_length`` rows
    are zeroed), ``timepoints`` and ``feature_id``.
    """

    # datatype -> (GluonTS dataset name,
    #              fixed history length, or None to use the ``history_length`` argument,
    #              prediction length,
    #              number of test windows)
    _DATASET_CONFIG = {
        # (original inline note on this entry: "168" - presumably the usual
        # history length for electricity; TODO confirm)
        'electricity': ("electricity_nips", None, 24, 7),
        # NOTE: the original author flagged this entry as "check this one".
        'exchange': ("exchange_rate_nips", None, 30, 7),
        'traffic': ("traffic_nips", None, 24, 7),
        'solar': ("solar_nips", None, 24, 7),
        'wiki': ("wiki2000_nips", 90, 30, 5),
        'taxi': ("taxi_30min", 48, 24, 56),
    }
    _MODES = ('train', 'valid', 'test')

    def __init__(self, datatype, mode="train", history_length=72):
        # Validate before touching GluonTS so that a typo fails fast with a
        # clear message instead of an unbound local / missing attribute later.
        if datatype not in self._DATASET_CONFIG:
            raise ValueError(
                "Unknown datatype %r; expected one of %s"
                % (datatype, sorted(self._DATASET_CONFIG))
            )
        if mode not in self._MODES:
            raise ValueError(
                "Unknown mode %r; expected one of %s" % (mode, list(self._MODES))
            )

        gluonts_name, fixed_history, pred_length, test_sample = self._DATASET_CONFIG[datatype]
        dataset = get_dataset(gluonts_name, regenerate=False)

        self.val_sample = 5
        self.history_length = history_length if fixed_history is None else fixed_history
        self.pred_length = pred_length
        self.test_sample = test_sample

        self.test_length = self.pred_length * self.test_sample
        self.valid_length = self.pred_length * self.val_sample
        self.seq_length = self.history_length + self.pred_length

        self.main_data = self._build_series(dataset)

        self.mean_data = np.mean(self.main_data, axis=0)
        # Clip the std so that (near-)constant columns do not blow up the
        # standardisation below.
        self.std_data = np.clip(np.std(self.main_data, axis=0), a_min=0.001, a_max=1e+7)
        self.mask_data = np.ones_like(self.main_data)

        if datatype == 'electricity':
            # Electricity uses pre-computed pickles instead of the series
            # assembled above. The split lengths computed earlier (24*7 test,
            # 24*5 valid) already equal the values that used to be
            # hard-coded here.
            # SECURITY: pickle.load executes arbitrary code from the file;
            # these files must be trusted local artifacts.
            datafolder = './data/electricity_nips'

            # shape: (T x N); mask_data is usually filled by 1
            with open(datafolder + '/data.pkl', 'rb') as f:
                self.main_data, self.mask_data = pickle.load(f)
            with open(datafolder + '/meanstd.pkl', 'rb') as f:
                self.mean_data, self.std_data = pickle.load(f)

        self.main_data = (self.main_data - self.mean_data) / self.std_data
        self.use_index = self._window_starts(mode, len(self.main_data))

    def _build_series(self, dataset):
        """Concatenate the grouped train series with the rolling test windows."""
        max_target_dim = min(2000, int(dataset.metadata.feat_static_cat[0].cardinality))

        train_grouper = MultivariateGrouper(max_target_dim=max_target_dim)
        dataset_train = train_grouper(dataset.train)
        # Transposed so that axis 0 is the one sliced into windows later.
        train_series = np.transpose(dataset_train[0]['target'])

        test_grouper = MultivariateGrouper(
            num_test_dates=len(dataset.test) // len(dataset.train),
            max_target_dim=max_target_dim,
        )

        # Number of trailing steps kept from every test window.
        record_hist = self.seq_length
        test_windows = np.stack(
            [entry['target'] for entry in test_grouper(dataset.test)], axis=0
        ).transpose(0, 2, 1)[:, -record_hist:]

        # Keep the first window whole; from every later window take only its
        # last ``pred_length`` steps.
        pieces = [test_windows[0]]
        pieces.extend(window[-self.pred_length:] for window in test_windows[1:])
        test_series = np.concatenate(pieces, axis=0)

        # Drop the tail of the train series that the first test window
        # re-supplies. An explicit end index is used because ``[:-0]`` would
        # discard the whole train series when the overlap is zero.
        overlap = int(record_hist - self.pred_length)
        keep = max(len(train_series) - overlap, 0)
        return np.concatenate([train_series[:keep], test_series], axis=0)

    def _window_starts(self, mode, total_length):
        """Return the window start indices for ``mode`` given the series length."""
        test_start = total_length - self.seq_length - self.test_length + self.pred_length
        valid_start = test_start - self.valid_length

        if mode == 'train':
            # Dense (stride 1) windows; the last one ends where the
            # validation region begins.
            train_end = total_length - self.seq_length - self.valid_length - self.test_length + 1
            return np.arange(0, train_end, 1)
        if mode == 'valid':
            return np.arange(valid_start, test_start, self.pred_length)
        # mode == 'test' (validated in __init__)
        return np.arange(test_start, test_start + self.test_length, self.pred_length)

    def __getitem__(self, orgindex):
        """Return the window whose start is ``self.use_index[orgindex]``."""
        index = self.use_index[orgindex]
        window = slice(index, index + self.seq_length)

        observed_mask = self.mask_data[window]
        target_mask = observed_mask.copy()
        target_mask[-self.pred_length:] = 0.  # pred mask for test pattern strategy

        return {
            'observed_data': self.main_data[window],
            'observed_mask': observed_mask,
            'gt_mask': target_mask,
            'timepoints': np.arange(self.seq_length) * 1.0,
            'feature_id': np.arange(self.main_data.shape[1]) * 1.0,
        }

    def __len__(self):
        """Number of windows available in the selected mode."""
        return len(self.use_index)

def get_dataloader(datatype,device,batch_size=8,history_length=168):
    """Build the train / valid / test loaders plus the normalisation tensors.

    A separate ``Forecasting_Dataset`` is constructed for each split; only
    the training loader shuffles.

    Returns:
        ``(train_loader, valid_loader, test_loader, scaler, mean_scaler)``
        where ``scaler`` and ``mean_scaler`` are float tensors on ``device``
        made from the training dataset's ``std_data`` and ``mean_data``.
    """
    split_sets = {}
    split_loaders = {}
    for split in ('train', 'valid', 'test'):
        split_sets[split] = Forecasting_Dataset(
            datatype, mode=split, history_length=history_length
        )
        split_loaders[split] = DataLoader(
            split_sets[split], batch_size=batch_size, shuffle=(split == 'train')
        )

    train_set = split_sets['train']
    scaler = torch.from_numpy(train_set.std_data).to(device).float()
    mean_scaler = torch.from_numpy(train_set.mean_data).to(device).float()

    return (
        split_loaders['train'],
        split_loaders['valid'],
        split_loaders['test'],
        scaler,
        mean_scaler,
    )