from data_provider.data_loader import (Dataset_mekong_target_station,
                                       Dataset_mekong_for_gnn, Dataset_mekong_stations_list)
from torch.utils.data import DataLoader, Dataset

# Registry of dataset classes, keyed by the name used in ``args.data`` or
# passed explicitly as ``data_loader_name`` to ``data_provider``.
data_dict = dict(
    diff_range_for_target=Dataset_mekong_target_station,
    for_gnn=Dataset_mekong_for_gnn,
    all_in_list=Dataset_mekong_stations_list,
)

def data_provider(args, flag, verbose=True, data_loader_name=None, stations_list=None):
    """Build the dataset selected by name and wrap it in a ``DataLoader``.

    The dataset class is looked up in ``data_dict`` using ``data_loader_name``
    when it is given, otherwise ``args.data``.

    Args:
        args: Configuration namespace. This function reads ``data``, ``embed``,
            ``batch_size``, ``freq``, ``dataset_path``, ``target``,
            ``data_time_path``, ``data_root_path``, ``data_path``, ``seq_len``,
            ``label_len``, ``pred_len``, ``features``, ``scaler_dict`` and
            ``num_workers``; the whole object is also handed to the dataset.
        flag: Split identifier forwarded to the dataset. Shuffling is turned
            off only for ``'test'`` / ``'TEST'``.
        verbose: When true, print ``flag`` and the dataset length.
        data_loader_name: Optional key into ``data_dict`` that overrides
            ``args.data``.
        stations_list: Forwarded to the dataset as ``stations_list`` when the
            selected dataset is ``'all_in_list'``; ignored for other datasets.

    Returns:
        A ``(data_set, data_loader)`` tuple.

    Raises:
        KeyError: If the selected name is not a key of ``data_dict``.
        ValueError: If ``data_loader_name`` is ``'all_in_list'`` and
            ``stations_list`` is ``None``.
    """
    # An explicit ``data_loader_name`` takes precedence over ``args.data``.
    dataset_name = args.data if data_loader_name is None else data_loader_name
    dataset_cls = data_dict[dataset_name]

    timeenc = 1 if args.embed == 'timeF' else 0
    shuffle_flag = flag not in ('test', 'TEST')
    batch_size = args.batch_size
    freq = args.freq

    # Plain string concatenation: the only separator inserted here is the '/'
    # after ``args.target``, so the other components must carry their own.
    root_path = args.dataset_path + args.target + '/' + args.data_time_path + args.data_root_path

    # Decide on the *resolved* dataset name so that ``stations_list`` is also
    # forwarded when 'all_in_list' was selected through ``args.data`` (it used
    # to be silently dropped in that case).
    extra_kwargs = {}
    if dataset_name == 'all_in_list':
        if stations_list is not None:
            extra_kwargs['stations_list'] = stations_list
        elif data_loader_name is not None:
            # Explicitly requesting the station-list dataset without a list
            # is an error; the ``args.data`` path keeps its previous call
            # signature for backward compatibility.
            raise ValueError('stations_list is empty')

    data_set = dataset_cls(
        args=args,
        root_path=root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        scaler_dict=args.scaler_dict,
        timeenc=timeenc,
        freq=freq,
        **extra_kwargs,
    )

    if verbose:
        print(flag, len(data_set))

    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=False,
    )
    return data_set, data_loader
