from pathlib import Path

import torch
from torch.utils.data import DataLoader, random_split

from pnpl.datasets import Armeni2022

# Build fixed ("canonical") train/val/test splits of the Armeni 2022 dataset and
# write each split to disk as a single stacked tensor.

# Root directory used for both the raw data and the preprocessing cache.
ARMENI_ROOT = Path("/data/<anonymised>/<anonymised>/armeni2022")
# The canonical split tensors are written here as train.pt / val.pt / test.pt.
OUTPUT_DIR = ARMENI_ROOT / "canonical" / "speech"
# Fractions of the dataset assigned to train / val / test, in that order.
SPLIT_FRACTIONS = [0.8, 0.1, 0.1]
# random_split draws from torch's unseeded global RNG unless a generator is given,
# which would make the "canonical" split differ on every run. Pin it with a seed.
# NOTE: files produced by earlier, unseeded runs will not match this partition.
SPLIT_SEED = 42

armeni_data = Armeni2022(
    data_path=str(ARMENI_ROOT),
    preproc_path=str(ARMENI_ROOT),
    l_freq=0.5,
    h_freq=125,
    resample_freq=250,
    notch_freq=50,
    interpolate_bad_channels=True,
    window_len=0.5,
    label="speech",
    info=["subject_id", "dataset"],
    preload=False,
)

# Create canonical (reproducible) splits.
train, val, test = random_split(
    armeni_data,
    SPLIT_FRACTIONS,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Materialise every item of each split and stack along a new leading dim.
# NOTE(review): torch.stack requires each item to be a tensor of identical shape.
# With `label` and `info` configured, Armeni2022 items may be tuples
# (e.g. data, label, info) rather than bare tensors — confirm against pnpl.
train = torch.stack(list(train))
val = torch.stack(list(val))
test = torch.stack(list(test))

# torch.save does not create missing parent directories, so ensure it exists.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
for split_name, split_tensor in (("train", train), ("val", val), ("test", test)):
    torch.save(split_tensor, OUTPUT_DIR / f"{split_name}.pt")