"""Master dataloader for the Schoffelen et al. (2019) MEG dataset."""

import glob
import os
import typing as tp
import warnings
from pathlib import Path

from osl import preprocessing
from torch.utils.data import ConcatDataset, Dataset

import pnpl.datasets.schoffelen2019.constants as constants
from pnpl.datasets.dataset_utils import get_cache_path
from pnpl.datasets.schoffelen2019.subloader import SchoffelenSubloader


class Schoffelen2019(Dataset):
    """Torch Dataset for the entire Schoffelen BIDS dataset.

    Builds one ``SchoffelenSubloader`` per applicable (subject, task) pair and
    exposes them as a single ``ConcatDataset``. Subjects for which no cached
    preprocessed output is found are first preprocessed with OSL.
    """

    subjects = constants.SUBJECTS

    def __init__(
        self,
        data_path: str,
        l_freq: float,
        h_freq: int,
        notch_freq: int,
        resample_freq: int,
        interpolate_bad_channels: bool,
        window_len: float,
        info: tp.List[str],
        preload: bool = False,
        label: tp.Optional[str] = None,
        preproc_path: tp.Optional[str] = None,
        exclude_subjects: tp.Optional[tp.List[str]] = None,
        include_subjects: tp.Optional[tp.List[str]] = None,
    ):
        """
        Load (and preprocess) data and labels.

        exclude* and include* arguments are mutually exclusive.

        Keyword arguments:
        data_path -- path to BIDS dataset
        l_freq -- high-pass filter frequency
        h_freq -- low-pass filter frequency (integer; also the exclusive upper
            bound for the generated notch frequencies)
        notch_freq -- notch filter frequency (integer; its multiples below
            h_freq are notched as well)
        resample_freq -- resampling frequency
        interpolate_bad_channels -- detect bad channels, interpolating them with nearby sensors
        window_len -- length in seconds of recording windows to return in getter
        info -- additional information to provide: sensor_xyz | subject | session | dataset | subject_id
        preload -- load datasets into RAM
        label -- must be None; labels are currently unsupported for this dataset
        preproc_path -- path to preprocessed data directory (to load from or cache to);
            defaults to data_path
        exclude_subjects -- subjects to avoid loading (None or empty: exclude nothing)
        include_subjects -- subjects to load (will not load any others; None or
            empty: no restriction)
        """

        assert label is None, "Labels are currently unsupported for this dataset"

        # ``None`` replaces the former mutable ``[]`` defaults; normalise both
        # filters to fresh lists so the rest of the constructor is unchanged.
        exclude_subjects = list(exclude_subjects or ())
        include_subjects = list(include_subjects or ())

        assert (
            not exclude_subjects or not include_subjects
        ), "Can specify only one of exclude_subjects or include_subjects"

        if preproc_path is None:
            preproc_path = data_path

        # Find subjects. glob returns entries in arbitrary (filesystem) order,
        # so sort to keep the index -> sample mapping reproducible.
        self.subjects = sorted(
            os.path.basename(path).replace("sub-", "")
            for path in glob.glob(f"{data_path}/sub-*")
        )
        if set(self.subjects) != set(constants.SUBJECTS):
            warnings.warn(
                f"Your dataset's subjects do not match the expected subjects for this dataset. Expecting {constants.SUBJECTS} but found {self.subjects}."
            )

        if exclude_subjects:
            self.subjects = sorted(set(self.subjects) - set(exclude_subjects))
        elif include_subjects:
            # Requested subjects that are not present on disk are dropped.
            self.subjects = sorted(set(include_subjects) & set(self.subjects))

        # Find subjects with missing preprocessed cached data
        # Warning: This check cannot determine if the preprocessing configuration was different.
        # If using a new preprocessing config, delete the cache first.
        cache_missing = {
            subject
            for subject, task in self._subject_tasks()
            if not os.path.exists(get_cache_path(preproc_path, subject, None, task))
        }

        if cache_missing:
            print(
                "Preprocessed cache missing for subjects",
                cache_missing,
                "Preprocessing them online",
            )

            config = self._build_preproc_config(
                l_freq=l_freq,
                h_freq=h_freq,
                notch_freq=notch_freq,
                resample_freq=resample_freq,
                interpolate_bad_channels=interpolate_bad_channels,
            )

            preproc_root = Path(preproc_path) / "preproc"
            preproc_root.mkdir(parents=True, exist_ok=True)

            # Sorted so subjects are preprocessed in a reproducible order.
            for subject in sorted(cache_missing):
                inputs = sorted(
                    glob.glob(str(data_path) + f"/sub-{subject}/meg/*_task-*_meg.ds")
                )

                preprocessing.run_proc_batch(
                    config,
                    inputs,
                    outdir=str(preproc_root / f"sub-{subject}"),
                    overwrite=True,
                    dask_client=False,
                )

        # Load preprocessed and cached data
        datasets = []
        for subject, task in self._subject_tasks():
            try:
                dataset = SchoffelenSubloader(
                    subject=subject,
                    task=task,
                    window_len=window_len,
                    label=label,
                    info=info,
                    preproc_path=preproc_path,
                    bids_root=data_path,
                    preload=preload,
                )
                datasets.append(dataset)
            except Exception as e:
                # Best effort: a recording that fails to open is reported and
                # skipped so the remaining recordings can still be used.
                print(f"Failed to open {subject} task-{task}", e)

        self.dataset = ConcatDataset(datasets)

    @staticmethod
    def _task_applies(subject: str, task: str) -> bool:
        """Return whether `task` should be loaded for `subject`.

        Subjects whose ID starts with "A" are skipped for the "visual" task and
        subjects whose ID starts with "V" are skipped for the "auditory" task
        (presumably each cohort only has recordings for its own modality).
        """
        if subject.startswith("A") and task == "visual":
            return False
        if subject.startswith("V") and task == "auditory":
            return False
        return True

    def _subject_tasks(self) -> tp.Iterator[tp.Tuple[str, str]]:
        """Yield every applicable (subject, task) pair for the selected subjects."""
        for subject in self.subjects:
            for task in constants.TASKS:
                if self._task_applies(subject, task):
                    yield subject, task

    @staticmethod
    def _build_preproc_config(
        l_freq: float,
        h_freq: int,
        notch_freq: int,
        resample_freq: int,
        interpolate_bad_channels: bool,
    ) -> str:
        """Return the OSL preprocessing config text for the given parameters."""
        # Notch the base frequency and each of its multiples below h_freq,
        # written as one space-separated string.
        # NOTE(review): this is empty when notch_freq >= h_freq -- confirm how
        # OSL handles an empty `freqs` entry.
        notch_filter_freqs = " ".join(
            str(f) for f in range(notch_freq, h_freq, notch_freq)
        )

        # Template text (including its indentation) is kept exactly as before;
        # the p* tokens are placeholders substituted below.
        config = """
                preproc:
                - pick_types: {meg: true, ref_meg: false}
                - filter: {l_freq: pl_freq, h_freq: ph_freq, method: iir, iir_params: {order: 5, ftype: butter}}
                - notch_filter: {freqs: pnotch_freqs}
                - resample: {sfreq: presample_freq}
                - bad_channels: {picks: mag}
                - interpolate_bads: {}
            """

        config = (
            config.replace("pl_freq", str(l_freq))
            .replace("ph_freq", str(h_freq))
            .replace("pnotch_freqs", notch_filter_freqs)
            .replace("presample_freq", str(resample_freq))
        )

        if not interpolate_bad_channels:
            # Drop the bad-channel detection and interpolation steps.
            config = config.replace("- bad_channels: {picks: mag}", "")
            config = config.replace("- interpolate_bads: {}", "")
            config = config.strip()

        return config

    def __len__(self):
        """Return the total number of samples across all loaded recordings."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the sample at `idx` from the concatenated dataset."""
        return self.dataset[idx]


if __name__ == "__main__":
    # Manual smoke test: build the dataset for a single subject, reading raw
    # data from and caching preprocessed data to the same directory.
    dataset_root = "/data/<anonymised>/<anonymised>/schoffelen2019"
    demo_kwargs = {
        "data_path": dataset_root,
        "preproc_path": dataset_root,
        "l_freq": 0.5,
        "h_freq": 125,
        "resample_freq": 250,
        "notch_freq": 50,
        "interpolate_bad_channels": True,
        "window_len": 0.5,
        "info": ["sensor_xyz", "subject", "dataset"],
        "include_subjects": ["V1117"],
    }
    schoffelen_data = Schoffelen2019(**demo_kwargs)

    # Drop into the debugger so the loaded dataset can be inspected by hand.
    breakpoint()