from torch.utils.data import DataLoader

from pnpl.datasets import Gwilliams2022

# Raw data and the preprocessed cache both live under the same directory.
_GWILLIAMS_ROOT = "/data/<anonymised>/<anonymised>/gwilliams2022"

# Subject-level split. The three lists must be pairwise disjoint; otherwise the
# same participant's recordings leak across train / validation / test.
# NOTE(review): the test split previously read ["21", "22"], which put subject
# "22" in both validation and test. It was changed to ["20", "21"] to follow the
# descending, non-overlapping pattern of the other splits. Confirm this matches
# the intended held-out set.
TRAIN_SUBJECTS = ["24", "25", "26", "27"]
VAL_SUBJECTS = ["22", "23"]
TEST_SUBJECTS = ["20", "21"]


def _check_disjoint_splits(**splits):
    """Raise ValueError if any subject id appears in more than one named split."""
    owner = {}
    for split_name, subjects in splits.items():
        for subject in subjects:
            if subject in owner:
                raise ValueError(
                    f"subject {subject!r} appears in both the "
                    f"{owner[subject]!r} and {split_name!r} splits"
                )
            owner[subject] = split_name


def _make_gwilliams_split(subjects):
    """Build a Gwilliams2022 dataset for ``subjects`` using the shared config.

    Every split goes through this helper, so all splits get identical filtering,
    resampling, windowing and label settings. Fresh list objects are built on
    each call, so no two datasets share a mutable argument.
    """
    return Gwilliams2022(
        data_path=_GWILLIAMS_ROOT,
        preproc_path=_GWILLIAMS_ROOT,
        l_freq=0.5,
        h_freq=125,
        resample_freq=250,
        notch_freq=50,
        interpolate_bad_channels=True,
        window_len=0.5,
        label="speech",
        info=["sensor_xyz", "subject", "session", "dataset"],
        include_subjects=list(subjects),
    )


# Fail fast, before any data loading, if the subject lists ever overlap again.
_check_disjoint_splits(train=TRAIN_SUBJECTS, val=VAL_SUBJECTS, test=TEST_SUBJECTS)

gwilliams_train = _make_gwilliams_split(TRAIN_SUBJECTS)
gwilliams_val = _make_gwilliams_split(VAL_SUBJECTS)
gwilliams_test = _make_gwilliams_split(TEST_SUBJECTS)

# Only the training loader shuffles; validation and test are iterated in a
# fixed order.
train_loader = DataLoader(gwilliams_train, batch_size=32, shuffle=True)
val_loader = DataLoader(gwilliams_val, batch_size=32, shuffle=False)
test_loader = DataLoader(gwilliams_test, batch_size=32, shuffle=False)

if __name__ == "__main__":
    # Drop into the debugger for interactive inspection only when this file is
    # run directly; importing the module must not stop at a breakpoint.
    breakpoint()
