import torch
from torch.utils.data import DataLoader, random_split

from pnpl.datasets import Armeni2022

# Exploratory script: load one Armeni2022 session through pnpl, split it into
# train/val/test subsets, wrap each subset in a DataLoader, then pause in the
# debugger so the loaders can be inspected interactively.

# Raw data and the preprocessed cache currently share the same root directory.
ARMENI_ROOT = "/data/<anonymised>/<anonymised>/armeni2022"
BATCH_SIZE = 32
# Fixed seed so the train/val/test partition is identical on every run.
# Without an explicit generator, random_split draws from the global RNG and
# the held-out sets change between runs.
SPLIT_SEED = 42

armeni_data = Armeni2022(
    data_path=ARMENI_ROOT,
    preproc_path=ARMENI_ROOT,
    # Frequencies below are presumably in Hz — confirm against pnpl docs.
    # NOTE(review): h_freq equals the Nyquist frequency of resample_freq
    # (250 / 2); confirm pnpl low-passes/anti-aliases before resampling.
    l_freq=0.5,
    h_freq=125,
    resample_freq=250,
    notch_freq=50,
    interpolate_bad_channels=True,
    window_len=0.5,  # assumed to be seconds per sample window — TODO confirm
    label="speech",
    info=["sensor_xyz", "subject", "session", "dataset"],
    include_subjects=["001"],
    include_sessions={"001": ["001"]},
    preload=True,
)

# Fractional lengths summing to 1 are converted to integer subset sizes by
# random_split; the seeded generator makes the assignment deterministic.
train, val, test = random_split(
    armeni_data,
    [0.8, 0.1, 0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

# Intentional inspection hook: drops into the debugger with the loaders in
# scope. Run with PYTHONBREAKPOINT=0 to skip it in non-interactive runs.
breakpoint()
