# -*- coding: utf-8 -*-
# @File : datahandler.py
# @Author : 王军
# @Time : 2022/10/7 10:05
# @Software : PyCharm
import torch
import numpy as np
import os
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from lib.utils import StandardScaler
def load_st_data(dataset_dir, device, batch_size=64, valid_batch_size=None, test_batch_size=None):
    """Load train/val/test splits from ``.npz`` files and wrap them in DataLoaders.

    Each split is read from ``<dataset_dir>/<split>.npz`` and must contain
    arrays under the keys ``'x'`` and ``'y'``. Feature channel 0 (last axis)
    of every ``x`` array is standardized with a scaler fitted on the training
    split only; the ``y`` arrays are not transformed.

    Args:
        dataset_dir: Directory holding ``train.npz``, ``val.npz`` and ``test.npz``.
        device: Device the dataset tensors are moved to (forwarded to ``st_dataset``).
        batch_size: Batch size for the (shuffled) training loader.
        valid_batch_size: Batch size for the validation loader; defaults to ``batch_size``.
        test_batch_size: Batch size for the test loader; defaults to ``batch_size``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, scaler, y_test)`` where
        ``scaler`` is the fitted ``StandardScaler`` and ``y_test`` is the
        untransformed NumPy test-target array.
    """
    data = {}
    for category in ('train', 'val', 'test'):
        # An NpzFile keeps its underlying file open until closed; the context
        # manager releases it. Indexing an NpzFile reads the member fully into
        # memory, so the arrays stay valid after the file is closed.
        with np.load(os.path.join(dataset_dir, category + '.npz')) as cat_data:
            data['x_' + category] = cat_data['x']
            data['y_' + category] = cat_data['y']

    # Fit normalization statistics on the training inputs only (channel 0) so
    # that val/test statistics do not leak into preprocessing.
    scaler = StandardScaler(mean=data['x_train'][..., 0].mean(),
                            std=data['x_train'][..., 0].std())
    # Normalize channel 0 of every split in place; other channels are kept as-is.
    # NOTE(review): assumes x arrays have a floating dtype — an integer array
    # would silently truncate the normalized values on assignment.
    for category in ('train', 'val', 'test'):
        data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0])

    valid_batch_size = batch_size if valid_batch_size is None else valid_batch_size
    test_batch_size = batch_size if test_batch_size is None else test_batch_size

    # Only the training loader is shuffled; val/test keep file order so that
    # predictions line up with the returned y_test.
    train_dataloader = DataLoader(dataset=st_dataset(data['x_train'], data['y_train'], device=device),
                                  batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(dataset=st_dataset(data['x_val'], data['y_val'], device=device),
                                batch_size=valid_batch_size, shuffle=False)
    test_dataloader = DataLoader(dataset=st_dataset(data['x_test'], data['y_test'], device=device),
                                 batch_size=test_batch_size, shuffle=False)
    return train_dataloader, val_dataloader, test_dataloader, scaler, data['y_test']

class st_dataset(Dataset):
    """In-memory dataset of paired (input, target) samples held as tensors.

    Both arrays are converted with ``torch.Tensor`` (the default tensor type)
    and moved to ``device`` once, at construction time, so ``__getitem__``
    only indexes into already-placed tensors.

    Args:
        X: Array-like of inputs; the first axis indexes samples.
        Y: Array-like of targets, indexed along its first axis in step with X.
        device: Target device for both tensors (default ``'cuda:0'``).
    """

    def __init__(self, X, Y, device='cuda:0'):
        def _to_device(arr):
            return torch.Tensor(arr).to(device)

        self.X, self.Y = _to_device(X), _to_device(Y)

    def __len__(self):
        # The number of samples is the size of the inputs' leading axis.
        return self.X.size(0)

    def __getitem__(self, idx):
        sample, target = self.X[idx], self.Y[idx]
        return sample, target


if __name__ == '__main__':
    # Smoke check: load the splits under datadir and print the shapes of one
    # training batch.
    datadir = r'data/METR-LA/12'
    # load_st_data requires a device argument; fall back to CPU without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    train_loader, val_loader, test_loader, scaler, y_test = load_st_data(datadir, device)
    x_batch, y_batch = next(iter(train_loader))
    print(x_batch.shape, y_batch.shape)
