"""Master dataloader for Armeni et al. (2022)"""

import collections
import glob
import os
import typing as tp
import warnings
from pathlib import Path

from osl import preprocessing
from torch.utils.data import ConcatDataset, Dataset

import pnpl.datasets.armeni2022.constants as constants
from pnpl.datasets.armeni2022.subloader import ArmeniSubloader
from pnpl.datasets.dataset_utils import get_cache_path


class Armeni2022(Dataset):
    """Torch Dataset for the entire Armeni BIDS dataset.

    Builds one ``ArmeniSubloader`` per selected (subject, session, task) and
    concatenates them with ``ConcatDataset``. Recordings whose preprocessed
    cache file is missing are first preprocessed with OSL and written under
    ``<preproc_path>/preproc/sub-<subject>``.
    """

    subjects = constants.SUBJECTS

    def __init__(
        self,
        data_path: str,
        l_freq: float,
        h_freq: float,
        notch_freq: float,
        resample_freq: float,
        interpolate_bad_channels: bool,
        window_len: float,
        label: tp.Optional[str],
        info: tp.List[str],
        label_delay: float = 0.0,
        preload: bool = False,
        preproc_path: tp.Optional[str] = None,
        exclude_subjects: tp.Optional[tp.List[str]] = None,
        exclude_sessions: tp.Optional[tp.Dict[str, tp.List[str]]] = None,
        include_subjects: tp.Optional[tp.List[str]] = None,
        include_sessions: tp.Optional[tp.Dict[str, tp.List[str]]] = None,
    ):
        """
        Load (and preprocess) data and labels.

        exclude_subjects/include_subjects are mutually exclusive, as are
        exclude_sessions/include_sessions.

        Keyword arguments:
        data_path -- path to BIDS dataset
        l_freq -- high-pass filter frequency
        h_freq -- low-pass filter frequency
        notch_freq -- notch filter base frequency; all its multiples strictly
            below h_freq are notched (no notch step if there are none)
        resample_freq -- resampling frequency
        interpolate_bad_channels -- detect bad channels, interpolating them with nearby sensors
        window_len -- length in seconds of recording windows to return in getter
        label -- supervised label type to provide: speech | voicing | phoneme
        info -- additional information to provide: sensor_xyz | subject | session | dataset | subject_id
        label_delay -- delay labels by this amount of time in seconds
        preload -- load datasets into RAM
        preproc_path -- path to preprocessed data directory (to load from or cache to);
            defaults to data_path
        exclude_subjects -- subjects to avoid loading
        exclude_sessions -- {subject: [sessions]} to avoid loading; subjects not
            listed keep all their sessions
        include_subjects -- subjects to load (will not load any others)
        include_sessions -- {subject: [sessions]} to load (will not load any others);
            when given, subjects not listed load no sessions

        Raises:
        AssertionError -- on invalid label, conflicting include/exclude
            arguments, non-list session values, or an empty selection
        """
        # Normalise optional containers (avoids shared mutable defaults).
        exclude_subjects = list(exclude_subjects or [])
        include_subjects = list(include_subjects or [])
        exclude_sessions = dict(exclude_sessions or {})
        include_sessions = dict(include_sessions or {})

        assert label in ["speech", "voicing", "phoneme", None], f"Unsupported label type {label}"

        assert (
            len(exclude_subjects) == 0 or len(include_subjects) == 0
        ), "Can specify only one of exclude_subjects or include_subjects"

        assert (
            len(exclude_sessions) == 0 or len(include_sessions) == 0
        ), "Can specify only one of exclude_sessions or include_sessions"

        assert all(
            isinstance(x, list) for x in exclude_sessions.values()
        ), "exclude_sessions must be a dictionary of lists"
        assert all(
            isinstance(x, list) for x in include_sessions.values()
        ), "include_sessions must be a dictionary of lists"

        if preproc_path is None:
            preproc_path = data_path

        # Find subjects (sorted so dataset indexing is reproducible across filesystems)
        self.subjects = self._list_entities(data_path, "sub-")
        if set(self.subjects) != set(constants.SUBJECTS):
            warnings.warn(
                f"Your dataset's subjects do not match the expected subjects for this dataset. Expecting {constants.SUBJECTS} but found {self.subjects}."
            )

        if exclude_subjects:
            self.subjects = sorted(set(self.subjects) - set(exclude_subjects))
        elif include_subjects:
            self.subjects = sorted(set(include_subjects) & set(self.subjects))

        # Find sessions
        self.sessions = {}
        for subject in self.subjects:
            sessions = self._list_entities(os.path.join(data_path, f"sub-{subject}"), "ses-")
            if set(sessions) != set(constants.SESSIONS):
                warnings.warn(
                    f"Your dataset's sessions do not match the expected sessions for subject {subject}. Expecting {constants.SESSIONS} but found {sessions}."
                )

            if exclude_sessions:
                sessions = sorted(set(sessions) - set(exclude_sessions.get(subject, [])))
            elif include_sessions:
                sessions = sorted(set(include_sessions.get(subject, [])) & set(sessions))
            self.sessions[subject] = sessions

        # Preprocess only the selected recordings that have no cached output.
        missing = self._find_missing_cache(preproc_path)
        if missing:
            print(
                "Preprocessed cache missing for subjects",
                sorted(missing),
                "Preprocessing them online",
            )
            config = self._build_preproc_config(
                l_freq, h_freq, notch_freq, resample_freq, interpolate_bad_channels
            )
            self._preprocess_missing(data_path, preproc_path, missing, config)

        # Load preprocessed and cached data
        datasets = []
        for subject, sessions in self.sessions.items():
            for session in sessions:
                for task in constants.TASKS:
                    dataset = ArmeniSubloader(
                        subject=subject,
                        task=task,
                        session=session,
                        window_len=window_len,
                        label=label,
                        info=info,
                        preproc_path=preproc_path,
                        bids_root=data_path,
                        preload=preload,
                        label_delay=label_delay,
                    )
                    datasets.append(dataset)

        # ConcatDataset rejects an empty list; give a more actionable message.
        assert datasets, (
            f"No recordings selected from {data_path}; check the path and the "
            "include/exclude arguments"
        )
        self.dataset = ConcatDataset(datasets)

    @staticmethod
    def _list_entities(parent, prefix: str) -> tp.List[str]:
        """Return sorted labels of ``<parent>/<prefix>*`` directories with ``prefix`` stripped."""
        # Escape the parent so paths containing glob metacharacters ([, *, ?) still match.
        pattern = os.path.join(glob.escape(str(parent)), f"{prefix}*")
        return sorted(
            os.path.basename(path)[len(prefix):]
            for path in glob.glob(pattern)
            if os.path.isdir(path)
        )

    def _find_missing_cache(self, preproc_path) -> tp.Dict[str, tp.List[tp.Tuple[str, str]]]:
        """Map each subject to the selected (session, task) pairs with no cached file.

        Warning: only file existence is checked; a cache produced with a
        different preprocessing configuration is not detected. If using a new
        preprocessing config, delete the cache first.
        """
        missing = collections.defaultdict(list)
        for subject, sessions in self.sessions.items():
            for session in sessions:
                for task in constants.TASKS:
                    cache_path = get_cache_path(preproc_path, subject, session, task)
                    if not os.path.exists(cache_path):
                        missing[subject].append((session, task))
        return dict(missing)

    @staticmethod
    def _notch_harmonics(notch_freq: float, h_freq: float) -> tp.List[float]:
        """Return multiples of ``notch_freq`` strictly below ``h_freq``.

        Returns an empty list when ``notch_freq`` is falsy or non-positive.
        Works for both int and float inputs (``range`` only accepts ints).
        """
        if not notch_freq or notch_freq <= 0:
            return []
        harmonics = []
        k = 1
        while k * notch_freq < h_freq:
            harmonics.append(k * notch_freq)
            k += 1
        return harmonics

    @classmethod
    def _build_preproc_config(
        cls,
        l_freq: float,
        h_freq: float,
        notch_freq: float,
        resample_freq: float,
        interpolate_bad_channels: bool,
    ) -> str:
        """Build the OSL YAML preprocessing config string."""
        steps = [
            "- pick_types: {meg: true, ref_meg: false}",
            f"- filter: {{l_freq: {l_freq}, h_freq: {h_freq}, method: iir, iir_params: {{order: 5, ftype: butter}}}}",
        ]
        harmonics = cls._notch_harmonics(notch_freq, h_freq)
        # Skip the notch step entirely rather than emitting an empty (null) freqs entry.
        if harmonics:
            steps.append("- notch_filter: {freqs: " + " ".join(str(f) for f in harmonics) + "}")
        steps.append(f"- resample: {{sfreq: {resample_freq}}}")
        if interpolate_bad_channels:
            steps.append("- bad_channels: {picks: mag}")
            steps.append("- interpolate_bads: {}")
        return "preproc:\n" + "\n".join(f"  {step}" for step in steps) + "\n"

    @staticmethod
    def _preprocess_missing(
        data_path,
        preproc_path,
        missing: tp.Dict[str, tp.List[tp.Tuple[str, str]]],
        config: str,
    ) -> None:
        """Run OSL preprocessing on the missing recordings, one batch per subject."""
        preproc_root = Path(preproc_path) / "preproc"
        preproc_root.mkdir(parents=True, exist_ok=True)

        bids_root = Path(data_path)  # data_path may be a str; `/` needs a Path
        for subject in sorted(missing):
            inputs = [
                str(
                    bids_root
                    / f"sub-{subject}"
                    / f"ses-{session}"
                    / "meg"
                    / f"sub-{subject}_ses-{session}_task-{task}_meg.ds"
                )
                for session, task in missing[subject]
            ]
            preprocessing.run_proc_batch(
                config,
                inputs,
                outdir=str(preproc_root / f"sub-{subject}"),
                overwrite=True,
                dask_client=False,
            )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]


if __name__ == "__main__":
    # Smoke test: load two subjects' first sessions and report what was loaded.
    # Run with `python -i` to inspect `armeni_data` interactively afterwards.
    armeni_data = Armeni2022(
        data_path="/data/<anonymised>/<anonymised>/armeni2022",
        preproc_path="/data/<anonymised>/<anonymised>/armeni2022",
        l_freq=0.5,
        h_freq=125,
        resample_freq=250,
        notch_freq=50,
        interpolate_bad_channels=True,
        window_len=0.5,
        label="speech",
        info=["sensor_xyz", "subject_id", "session", "dataset"],
        include_subjects=["001", "003"],
        include_sessions={"001": ["001"], "003": ["001"]},
    )
    print(f"Loaded {len(armeni_data)} windows from sessions {armeni_data.sessions}")
