"""Load data from a specific recording in Schoffelen et al. (2019)"""

import typing as tp

import mne
import numpy as np
import torch
from torch.utils.data import Dataset

import pnpl.datasets.schoffelen2019.constants as constants
from pnpl.datasets.dataset_utils import read_cache_file


class SchoffelenSubloader(Dataset):
    """Dataset for a specific subject and task of Schoffelen et al. (2019).

    The cached recording is cut into consecutive, non-overlapping windows of
    ``window_len`` seconds; item ``i`` is the ``i``-th window. A trailing
    partial window is dropped.
    """

    def __init__(
        self,
        subject: str,
        task: str,
        window_len: float,
        label: tp.Optional[str],
        info: tp.List[str],
        preproc_path: str,
        bids_root: str,
        preload: bool,
    ):
        """
        Load the specified recording from the preprocessing cache.

        Args:
            subject: Subject identifier, forwarded to ``read_cache_file``. Must
                appear in ``constants.SUBJECTS`` if ``"subject_id"`` is in
                ``info``.
            task: Task name, forwarded to ``read_cache_file``.
            window_len: Window length in seconds.
            label: Per-window label to attach. Only ``None`` (unlabelled
                windows) is supported; any other value makes ``__getitem__``
                raise ``NotImplementedError``.
            info: Names of optional entries to include in every item's
                ``"info"`` dict. Recognised: ``"dataset"``, ``"subject"``,
                ``"subject_id"``, ``"session"``, ``"sensor_xyz"``.
            preproc_path: Location of the preprocessed cache, forwarded to
                ``read_cache_file``.
            bids_root: Currently unused by this class.
            preload: Forwarded to ``read_cache_file``.

        Raises:
            ValueError: If ``window_len`` spans no samples at the recording's
                sampling frequency.
        """
        raw = read_cache_file(
            preproc_path=preproc_path,
            subject=subject,
            session=None,  # recordings are loaded without a session level
            task=task,
            preload=preload,
        )

        # Cached channel names appear to share a common suffix after the last
        # '-'; append it to the canonical names so they match. If there is no
        # '-', use the canonical names unchanged instead of treating the whole
        # name as a suffix.
        first_name = raw.info["ch_names"][0]
        if "-" in first_name:
            suffix = first_name.split("-")[-1]
            picks_channel_names = [f"{ch}-{suffix}" for ch in constants.CHANNELS]
        else:
            picks_channel_names = list(constants.CHANNELS)
        raw = raw.pick_channels(picks_channel_names)
        # Record channels *after* picking so they match the rows of "data".
        self.channels = list(raw.info["ch_names"])

        self.subject = subject
        self.label = label

        sfreq = float(raw.info["sfreq"])
        self.samples_per_slice = int(sfreq * window_len)
        if self.samples_per_slice <= 0:
            raise ValueError(
                f"window_len={window_len} s spans no samples at sfreq={sfreq} Hz"
            )
        # Integer division: only complete windows are exposed.
        self.num_slices = len(raw) // self.samples_per_slice

        # Optional per-recording metadata, attached to every window.
        self.info = {}
        if "dataset" in info:
            self.info["dataset"] = "schoffelen2019"
        if "subject" in info:
            self.info["subject"] = subject
        if "subject_id" in info:
            self.info["subject_id"] = constants.SUBJECTS.index(subject)
        if "session" in info:
            # No session is used when loading (session=None above); a string
            # placeholder is stored.
            self.info["session"] = "None"
        if "sensor_xyz" in info:
            # The first three entries of each channel's "loc" are X, Y, Z.
            sensor_positions = np.array([ch["loc"][:3] for ch in raw.info["chs"]])
            self.info["sensor_xyz"] = torch.tensor(sensor_positions)

        self.raw = raw

    def __len__(self):
        """Return the number of complete windows in the recording."""
        return self.num_slices

    def __getitem__(self, idx):
        """
        Return the ``idx``-th window of the recording.

        Args:
            idx: Window index; negative values count from the end.

        Returns:
            dict with ``"data"`` and ``"times"`` (as returned by
            ``raw[:, start:stop]``) and ``"info"`` (a shallow copy of
            ``self.info``, so callers cannot add or remove keys in the dict
            shared by all windows).

        Raises:
            NotImplementedError: If the dataset was built with a label.
            IndexError: If ``idx`` is out of range. This also lets plain
                ``for item in dataset`` iteration terminate.
        """
        if self.label is not None:
            # Labels are not implemented for this dataset; fail loudly instead
            # of reaching the return with no data.
            raise NotImplementedError(
                f"SchoffelenSubloader does not support labels (got label={self.label!r})"
            )

        n_windows = self.num_slices
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            raise IndexError(f"window index out of range for {n_windows} windows")

        start = idx * self.samples_per_slice
        data, times = self.raw[:, start : start + self.samples_per_slice]

        return {
            "data": data,
            "times": times,
            "info": dict(self.info),
        }
