import torch
from torch.utils.data import DataLoader, random_split

from pnpl.datasets import Shafto2014

# Raw and preprocessed data are read from / written to the same BIDS directory,
# so the path is spelled once and passed for both arguments.
_BIDS_ROOT = "/data/<anonymised>/<anonymised>/shafto2014/cc700/meg/pipeline/release005/BIDSsep"

# Signal-processing options forwarded verbatim to the dataset constructor.
# Frequencies are presumably in Hz and the window length in seconds -- confirm
# against the pnpl documentation.
_PREPROC_OPTIONS = {
    "l_freq": 0.5,
    "h_freq": 125,
    "resample_freq": 250,
    "notch_freq": 50,
    "interpolate_bad_channels": True,
    "window_len": 0.5,
}

# NOTE(review): the variable is called `armeni_data` but holds a Shafto2014
# dataset -- presumably a leftover from a script written for another dataset.
# The name is kept because later statements in this script refer to it.
armeni_data = Shafto2014(
    data_path=_BIDS_ROOT,
    preproc_path=_BIDS_ROOT,
    info=["sensor_xyz", "subject", "dataset"],
    include_subjects=["CC723395"],  # restrict loading to this single subject
    **_PREPROC_OPTIONS,
)

# Seed for the train/val/test partition. Without an explicit generator,
# random_split draws from the global RNG, so the held-out sets would differ on
# every run and metrics could not be compared between runs.
SPLIT_SEED = 0
BATCH_SIZE = 32

# 80% / 10% / 10% split of the dataset items, reproducible via SPLIT_SEED.
train, val, test = random_split(
    armeni_data,
    [0.8, 0.1, 0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; validation and test iterate in a fixed
# order so that evaluation passes see the batches in the same sequence.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

# Pause in the debugger once everything above has been built -- presumably so
# the dataset, splits and loaders can be inspected interactively.
# NOTE(review): remove before running this script non-interactively.
breakpoint()
