"""Master dataloader for MEG-MASC in Gwilliams et al. (2022)"""

import glob
import os
import typing as tp
from pathlib import Path

from osl import preprocessing
from torch.utils.data import ConcatDataset, Dataset

import pnpl.datasets.shafto2014.constants as constants
from pnpl.datasets.dataset_utils import get_cache_path
from pnpl.datasets.shafto2014.subloader import ShaftoSubloader


class Shafto2014(Dataset):
    """Torch Dataset for the entire Shafto et al. (2014) Cam-CAN MEG BIDS dataset.

    One ``ShaftoSubloader`` is built per (subject, task) pair and all of them are
    concatenated into a single ``ConcatDataset``; indexing delegates to it.
    """

    # Class-level default; each instance overwrites this with the subjects it
    # actually discovered on disk.
    subjects = constants.SUBJECTS

    def __init__(
        self,
        data_path: str,
        l_freq: int,
        h_freq: int,
        notch_freq: int,
        resample_freq: int,
        interpolate_bad_channels: bool,
        window_len: float,
        info: tp.List[str],
        preload: bool = False,
        label: tp.Optional[str] = None,
        preproc_path: tp.Optional[str] = None,
        exclude_subjects: tp.Optional[tp.List[str]] = None,
        include_subjects: tp.Optional[tp.List[str]] = None,
    ):
        """
        Load (and preprocess) data and labels.

        exclude* and include* arguments are mutually exclusive.

        Keyword arguments:
        data_path -- path to BIDS dataset
        l_freq -- high-pass filter frequency
        h_freq -- low-pass filter frequency (also the exclusive upper bound of notch harmonics)
        notch_freq -- notch filter base frequency; harmonics below h_freq are notched too
        resample_freq -- resampling frequency
        interpolate_bad_channels -- detect bad channels, interpolating them with nearby sensors
        window_len -- length in seconds of recording windows to return in getter
        info -- additional information to provide: sensor_xyz | subject | session | dataset | subject_id
        preload -- load datasets into RAM
        label -- must be None; supervised labels are not supported for this dataset yet
        preproc_path -- path to preprocessed data directory (to load from or cache to);
                        defaults to data_path
        exclude_subjects -- subjects to avoid loading
        include_subjects -- subjects to load (will not load any others)

        Raises:
        AssertionError -- if label is given, or both exclude_subjects and include_subjects are given
        FileNotFoundError -- if data_path contains no task directories
        """
        # None defaults avoid sharing one mutable list across all calls.
        exclude_subjects = list(exclude_subjects or [])
        include_subjects = list(include_subjects or [])

        assert label is None, "Labels are currently unsupported for this dataset"

        assert (
            len(exclude_subjects) == 0 or len(include_subjects) == 0
        ), "Can specify only one of exclude_subjects or include_subjects"

        if preproc_path is None:
            preproc_path = data_path

        # This dataset is formatted as
        # {data_path}/{task}/sub-{subject}/ses-{task}/meg/sub-{subject}_ses-{task}_task-{task}_meg.fif
        self.tasks = self._find_tasks(data_path)
        if not self.tasks:
            raise FileNotFoundError(f"No task directories found under {data_path}")

        self.subjects = self._select_subjects(
            data_path, self.tasks, exclude_subjects, include_subjects
        )
        print("Found subjects:", self.subjects)

        # Find subjects with missing preprocessed cached data.
        # Warning: This check cannot determine if the preprocessing configuration was different.
        # If using a new preprocessing config, delete the cache first.
        # Session and task share the same name in this layout, hence (task, task).
        cache_missing = sorted(
            {
                subject
                for subject in self.subjects
                for task in self.tasks
                if not os.path.exists(get_cache_path(preproc_path, subject, task, task))
            }
        )

        if cache_missing:
            print(
                "Preprocessed cache missing for subjects",
                cache_missing,
                "Preprocessing them online",
            )

            config = self._build_preproc_config(
                l_freq, h_freq, notch_freq, resample_freq, interpolate_bad_channels
            )

            preproc_root = Path(preproc_path) / "preproc"
            preproc_root.mkdir(parents=True, exist_ok=True)

            for subject in cache_missing:
                # One wildcard level for the task directory (matches the layout above).
                inputs = sorted(
                    glob.glob(f"{data_path}/*/sub-{subject}/ses-*/meg/*_task-*_meg.fif")
                )
                preprocessing.run_proc_batch(
                    config,
                    inputs,
                    outdir=str(preproc_root / f"sub-{subject}"),
                    overwrite=True,
                    dask_client=False,
                )

        # Load preprocessed and cached data
        datasets = []
        for subject in self.subjects:
            for task in self.tasks:
                try:
                    dataset = ShaftoSubloader(
                        subject=subject,
                        task=task,
                        window_len=window_len,
                        label=label,
                        info=info,
                        preproc_path=preproc_path,
                        bids_root=data_path,
                        preload=preload,
                    )
                except Exception as e:
                    # Best-effort: one unreadable recording should not abort the
                    # whole dataset, so report it and move on.
                    print(f"Failed to open {subject} task-{task}", e)
                else:
                    datasets.append(dataset)

        self.dataset = ConcatDataset(datasets)

    @staticmethod
    def _find_tasks(data_path: str) -> tp.List[str]:
        """Return the sorted task directory names directly under data_path."""
        # Only directories are tasks. A stray file (README, TSV, ...) would otherwise
        # be treated as a task with no subjects and empty the intersection.
        # The "preproc" directory holds cached outputs, not a task.
        return sorted(
            os.path.basename(path)
            for path in glob.glob(f"{data_path}/*")
            if os.path.isdir(path) and os.path.basename(path) != "preproc"
        )

    @staticmethod
    def _select_subjects(
        data_path: str,
        tasks: tp.List[str],
        exclude_subjects: tp.List[str],
        include_subjects: tp.List[str],
    ) -> tp.List[str]:
        """Return sorted subject IDs present for every task, after include/exclude filtering."""
        prefix = "sub-"
        per_task = [
            {
                os.path.basename(path)[len(prefix):]
                for path in glob.glob(f"{data_path}/{task}/{prefix}*")
            }
            for task in tasks
        ]
        # All subjects *should* have performed all tasks, so take the intersection.
        available = set.intersection(*per_task) if per_task else set()

        if exclude_subjects:
            available -= set(exclude_subjects)
        elif include_subjects:
            available &= set(include_subjects)

        return sorted(available)

    @staticmethod
    def _build_preproc_config(
        l_freq: int,
        h_freq: int,
        notch_freq: int,
        resample_freq: int,
        interpolate_bad_channels: bool,
    ) -> str:
        """Return the OSL YAML preprocessing config for the given filter settings."""
        # Notch the base frequency and every harmonic strictly below h_freq.
        notch_freqs = " ".join(str(f) for f in range(notch_freq, h_freq, notch_freq))

        steps = [
            "- pick_types: {meg: true, ref_meg: false}",
            f"- filter: {{l_freq: {l_freq}, h_freq: {h_freq}, method: iir, "
            f"iir_params: {{order: 5, ftype: butter}}}}",
        ]
        if notch_freqs:
            # Skip the step entirely rather than emit a null `freqs:` entry.
            steps.append(f"- notch_filter: {{freqs: {notch_freqs}}}")
        steps.append(f"- resample: {{sfreq: {resample_freq}}}")

        if interpolate_bad_channels:
            steps += [
                "- bad_channels: {picks: mag, significance_level: 0.1}",
                "- bad_channels: {picks: grad, significance_level: 0.1}",
                "- interpolate_bads: {}",
            ]

        return "\n".join(["preproc:", *steps])

    def __len__(self):
        """Return the total number of windows across all loaded recordings."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the sample at global index idx from the concatenated recordings."""
        return self.dataset[idx]


if __name__ == "__main__":
    camcan_data = Shafto2014(
        data_path="/data/<anonymised>/<anonymised>/shafto2014/cc700/meg/pipeline/release005/BIDSsep",
        preproc_path="/data/<anonymised>/<anonymised>/shafto2014/cc700/meg/pipeline/release005/BIDSsep",
        l_freq=0.5,
        h_freq=125,
        resample_freq=250,
        notch_freq=50,
        interpolate_bad_channels=True,
        window_len=0.5,
        info=["sensor_xyz", "subject", "dataset"],
        include_subjects=["CC723395"],
    )

    import matplotlib.pyplot as plt

    sample = camcan_data[0]

    data, times = sample["data"], sample["times"]

    for sensor in range(data.shape[0]):
        plt.plot(times, data[sensor, :])
    plt.savefig("shafto2014.png")