import torch
from torch.utils.data import DataLoader, random_split

from pnpl.dataloaders import MultiDataLoader
from pnpl.datasets import Armeni2022, Gwilliams2022, Schoffelen2019

# Preprocessing settings shared by all three datasets, so every source is
# filtered, resampled and windowed identically.
# NOTE(review): h_freq=125 equals the Nyquist frequency of resample_freq=250,
# leaving no transition band below Nyquist — confirm that pnpl applies its own
# anti-aliasing filter when resampling.
PREPROC_KWARGS = {
    "l_freq": 0.5,
    "h_freq": 125,
    "resample_freq": 250,
    "notch_freq": 50,
    "interpolate_bad_channels": True,
    "window_len": 0.5,
}

# Train / validation / test fractions applied to every dataset.
SPLIT_FRACTIONS = [0.8, 0.1, 0.1]
# Fixed seed: without an explicit generator, random_split draws from the global
# RNG and yields a different partition on every run, which makes experiments
# irreproducible and lets test samples drift into training across runs.
SPLIT_SEED = 42
BATCH_SIZE = 32


def _seeded_split(dataset, seed=SPLIT_SEED):
    """Split *dataset* into train/val/test subsets reproducibly.

    NOTE(review): random_split partitions individual dataset indices —
    presumably individual windows — so windows from the same recording can end
    up in both train and test. Consider splitting by session/subject if that
    leakage matters for the evaluation.
    """
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, SPLIT_FRACTIONS, generator=generator)


def _make_loaders(train, val, test, batch_size=BATCH_SIZE):
    """Wrap the three subsets in DataLoaders; only the training loader shuffles."""
    return (
        DataLoader(train, batch_size=batch_size, shuffle=True),
        DataLoader(val, batch_size=batch_size, shuffle=False),
        DataLoader(test, batch_size=batch_size, shuffle=False),
    )


# Armeni 2022: subjects 001 and 003, restricted to the listed sessions
# (include_sessions is keyed by the subject IDs in include_subjects).
armeni_data = Armeni2022(
    data_path="/data/<anonymised>/<anonymised>/armeni2022",
    preproc_path="/data/<anonymised>/<anonymised>/armeni2022",
    **PREPROC_KWARGS,
    label="speech",
    info=["sensor_xyz", "subject", "session", "dataset"],
    include_subjects=["001", "003"],
    include_sessions={"001": ["001"], "003": ["004", "010"]},
)
armeni_train, armeni_val, armeni_test = _seeded_split(armeni_data)
armeni_train_loader, armeni_val_loader, armeni_test_loader = _make_loaders(
    armeni_train, armeni_val, armeni_test
)

# Gwilliams 2022: subject 27 only.
gwilliams_data = Gwilliams2022(
    data_path="/data/<anonymised>/<anonymised>/gwilliams2022",
    preproc_path="/data/<anonymised>/<anonymised>/gwilliams2022",
    **PREPROC_KWARGS,
    label="speech",
    info=["sensor_xyz", "subject", "session", "dataset"],
    include_subjects=["27"],
)
gwilliams_train, gwilliams_val, gwilliams_test = _seeded_split(gwilliams_data)
gwilliams_train_loader, gwilliams_val_loader, gwilliams_test_loader = _make_loaders(
    gwilliams_train, gwilliams_val, gwilliams_test
)

# Schoffelen 2019: subjects A2002 and V1117. Unlike the other two datasets, no
# `label` is passed and "session" is not requested in `info`.
schoffelen_data = Schoffelen2019(
    data_path="/data/<anonymised>/<anonymised>/schoffelen2019",
    preproc_path="/data/<anonymised>/<anonymised>/schoffelen2019",
    **PREPROC_KWARGS,
    info=["sensor_xyz", "subject", "dataset"],
    include_subjects=["A2002", "V1117"],
)
schoffelen_train, schoffelen_val, schoffelen_test = _seeded_split(schoffelen_data)
schoffelen_train_loader, schoffelen_val_loader, schoffelen_test_loader = _make_loaders(
    schoffelen_train, schoffelen_val, schoffelen_test
)

# Combine the three per-dataset loaders into one MultiDataLoader per split.
# How MultiDataLoader mixes its inputs (and what its `shuffle` flag does) is
# defined in pnpl.dataloaders, not here. Only the training split requests
# shuffling; validation and test pass shuffle=False.
_train_parts = [armeni_train_loader, gwilliams_train_loader, schoffelen_train_loader]
_val_parts = [armeni_val_loader, gwilliams_val_loader, schoffelen_val_loader]
_test_parts = [armeni_test_loader, gwilliams_test_loader, schoffelen_test_loader]

train_loader = MultiDataLoader(_train_parts, shuffle=True)
val_loader = MultiDataLoader(_val_parts, shuffle=False)
test_loader = MultiDataLoader(_test_parts, shuffle=False)

# Drop into the debugger to inspect the loaders interactively, but only when
# this file is run as a script: a bare module-level breakpoint() would also
# halt any code that merely imports this module (e.g. to reuse train_loader).
# Setting PYTHONBREAKPOINT=0 disables it even when run as a script.
if __name__ == "__main__":
    breakpoint()
