from torch.utils.data import DataLoader
from .data_info import batch_dict, dataclass_dict
from torch.utils.data.distributed import DistributedSampler

def data_provider(args, flag):
    """Build the dataset for one split and wrap it in a ``DataLoader``.

    The dataset class is looked up in ``dataclass_dict`` under ``args.data``.

    Split handling:
        * ``flag == 'test'``: no shuffling, batch size from ``batch_dict``.
        * ``flag == 'val'``: shuffling on, batch size from ``batch_dict``.
        * any other flag: shuffling on, batch size from ``args.batch_size``.

    The last incomplete batch is dropped for every split.

    Args:
        args: Configuration object. Reads ``data``, ``root_path``,
            ``data_path``, ``seq_len``, ``label_len``, ``pred_len`` and
            ``distributed``; also ``batch_size`` for flags other than
            'test'/'val', ``world_size`` and ``num_workers`` when
            ``distributed`` is truthy, and ``split_num`` when running
            non-distributed on the 'electricity' data.
        flag: Split name, forwarded unchanged to the dataset constructor.

    Returns:
        A 2-tuple ``(data_set, data_loader)``. In distributed mode the first
        element is the ``DistributedSampler`` rather than the dataset.

    Side effects:
        Prints the chosen batch size, then the flag together with the length
        of the first returned element and the number of batches.
    """
    dataset_cls = dataclass_dict[args.data]

    # Only the test split keeps its original ordering.
    shuffle_flag = flag != 'test'
    # Partial trailing batches are discarded for every split.
    drop_last = True
    # Evaluation splits use the per-dataset table; other flags use the
    # value carried on ``args``.
    batch_size = (
        batch_dict[args.data] if flag in ('test', 'val') else args.batch_size
    )
    print(batch_size)

    dataset = dataset_cls(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        batch_size=batch_size,
    )

    if not args.distributed:
        # Special case: this override only affects the loader; the dataset
        # above was already constructed with the earlier batch size.
        if args.data == 'electricity' and args.split_num == 321:
            batch_size = 16
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle_flag,
            num_workers=0,
            drop_last=drop_last,
        )
        print(flag, len(dataset), len(loader))
        return dataset, loader

    # Distributed run: shuffling is delegated to the sampler, so the loader
    # itself is created with shuffle=False. The batch size and worker count
    # are divided by the world size and truncated to int.
    sampler = DistributedSampler(dataset, shuffle=shuffle_flag)
    per_rank_batch = int(batch_size / args.world_size)
    per_rank_workers = int(args.num_workers / args.world_size)
    loader = DataLoader(
        dataset=dataset,
        batch_size=per_rank_batch,
        shuffle=False,
        num_workers=per_rank_workers,
        sampler=sampler,
        pin_memory=True,
        drop_last=drop_last,
    )
    # NOTE(review): the sampler is returned in place of the dataset --
    # presumably so callers can reach the sampler (e.g. for per-epoch
    # reseeding); confirm against the call sites.
    print(flag, len(sampler), len(loader))
    return sampler, loader
