from data_provider.data_loader import UCIloader, SHARloader, EMGloader, OPPloader, DSADSloader, USCHADloader, PAMAPloader, WESADloader, EEGloader, UEAloader
from data_provider.uea import collate_fn
from torch.utils.data import DataLoader
import os
import torch

# Registry of dataset loaders. data_provider() looks up ``args.data`` in this
# mapping to pick the loader class it instantiates, so the keys below are the
# accepted values of ``args.data``.
data_dict = {
    "UCIHAR": UCIloader,
    "SHAR": SHARloader,
    "EMG": EMGloader,
    "OPP": OPPloader,
    "DSADS": DSADSloader,
    "USCHAD": USCHADloader,
    "PAMAP": PAMAPloader,
    "WESAD": WESADloader,
    "EEG": EEGloader,
    # data_provider() special-cases this key: the dataset is read from
    # args.root_path directly and a "VAL" request is served as "TEST".
    "UEA": UEAloader,
}


class _PaddedCollate:
    """Picklable replacement for ``lambda x: collate_fn(x, max_len=args.len_seq)``.

    A lambda defined inside ``data_provider`` cannot be pickled, which breaks
    ``DataLoader`` worker start-up when ``num_workers > 0`` and the
    multiprocessing start method is ``spawn``. An instance of a module-level
    class can be pickled. ``args.len_seq`` is read on every call, exactly as
    the lambda did.
    """

    def __init__(self, args):
        self.args = args

    def __call__(self, batch):
        return collate_fn(batch, max_len=self.args.len_seq)


def data_provider(args, flag):
    """Build the dataset and its ``DataLoader`` for one split.

    Args:
        args: Run configuration. Reads ``data`` (key into ``data_dict``),
            ``root_path``, ``batch_size``, ``num_workers``, ``len_seq`` and,
            for the ``"TRAIN"`` split only, ``balance``.
        flag: Split name. ``"TRAIN"`` gets either a weighted random sampler
            (``args.balance == True``) or plain shuffling; any other value is
            iterated without shuffling. For the ``"UEA"`` dataset a ``"VAL"``
            request is mapped to ``"TEST"``.

    Returns:
        A ``(data_set, data_loader)`` tuple.

    Raises:
        KeyError: If ``args.data`` is not a key of ``data_dict``.
    """
    try:
        dataset_cls = data_dict[args.data]
    except KeyError:
        # Same exception type as the plain lookup, but name the valid choices.
        raise KeyError(
            "Unknown dataset {!r}; expected one of {}".format(
                args.data, sorted(data_dict)
            )
        ) from None

    data_path = os.path.join(args.root_path, args.data)
    if args.data == "UEA":
        # UEA is read from root_path itself, and "VAL" is served as "TEST".
        data_path = args.root_path
        if flag == "VAL":
            flag = "TEST"

    data_set = dataset_cls(args=args, data_path=data_path, flag=flag)

    # Options shared by every split; only shuffle/sampler vary below.
    loader_kwargs = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "num_workers": args.num_workers,
        "drop_last": False,
        "collate_fn": _PaddedCollate(args),
    }

    if flag == "TRAIN":
        # balance the data
        print("balance = ", args.balance)
        # Equality (not truthiness) is kept on purpose: truthy non-bool values
        # such as a non-empty string must not enable the sampler, as before.
        if args.balance == True:  # noqa: E712
            sample_weights = data_set.get_sample_weights()
            # sampler and shuffle=True are mutually exclusive in DataLoader,
            # so shuffle stays False and the sampler does the randomisation.
            loader_kwargs["sampler"] = torch.utils.data.sampler.WeightedRandomSampler(
                weights=sample_weights,
                num_samples=len(sample_weights),
                replacement=True,
            )
        else:
            loader_kwargs["shuffle"] = True

    data_loader = DataLoader(data_set, **loader_kwargs)
    return data_set, data_loader
