
import pdb
import numpy as np

from data_process.data_loader import Dataset_Custom5
from torch.utils.data import DataLoader
import torch

# Registry: maps the ``args.data`` string to the dataset class that
# ``data_provider`` looks up for it.
data_dict = dict(custom5=Dataset_Custom5)

def _custom_collate_fn(batch):
    batch_x = torch.tensor([item[0] for item in batch])
    batch_y = torch.tensor([item[1] for item in batch])
    timestamp = [item[2] for item in batch]
    timestamp = np.repeat(timestamp, 7, axis=0)
    return batch_x, batch_y, timestamp, None

def _custom_collate_fn2(batch):
    batch_x = torch.tensor([item[0] for item in batch])
    batch_y = torch.tensor([item[1] for item in batch])
    size = batch_x.shape[0]
    timestamp = [item[2] for item in batch]
    #timestamp = np.repeat(timestamp, size, axis=0)
    return batch_x, batch_y, timestamp, None

def data_provider(args, flag):
    """Build the dataset for one split and wrap it in a ``DataLoader``.

    Args:
        args: experiment config namespace. Fields read here: ``data`` (key into
            ``data_dict``), ``embed``, ``batch_size``, ``freq``, and the
            dataset-construction fields ``root_path``, ``data_path``,
            ``seq_len``, ``label_len``, ``pred_len``, ``features``, ``target``.
        flag: split name. ``'test'`` disables shuffling; any other value shuffles.

    Returns:
        ``(data_set, data_loader)``.

    Raises:
        KeyError: if ``args.data`` is not a key of ``data_dict``.
    """
    print(args.data)
    try:
        Data = data_dict[args.data]
    except KeyError:
        raise KeyError(
            f"unknown dataset {args.data!r}; expected one of {sorted(data_dict)}"
        ) from None
    # timeenc is 1 only for the 'timeF' embedding, 0 otherwise.
    timeenc = 0 if args.embed != 'timeF' else 1

    # Splits differ only in shuffling: the test split keeps sample order.
    shuffle_flag = flag != 'test'
    drop_last = False
    batch_size = args.batch_size
    freq = args.freq

    # Previously data_set was never assigned, so every call raised NameError.
    # NOTE(review): these keywords follow the Autoformer/Time-Series-Library
    # Dataset_Custom convention — confirm against Dataset_Custom5's signature.
    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        freq=freq,
    )

    # These dataset classes get _custom_collate_fn2, which keeps the third sample
    # field (a timestamp) as a plain list. Dataset_Custom7 is matched by name
    # because it is not imported in this module.
    use_custom_collate = (
        Data is Dataset_Custom5
        or getattr(Data, '__name__', None) == 'Dataset_Custom7'
    )
    # collate_fn=None makes DataLoader use its default collation.
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        drop_last=drop_last,
        collate_fn=_custom_collate_fn2 if use_custom_collate else None,
    )
    return data_set, data_loader