"""Load data from a specific recording in Armeni et al. (2022)"""

import typing as tp

import numpy as np
import torch
from torch.utils.data import Dataset

import pnpl.datasets.armeni2022.constants as constants
from pnpl.datasets.dataset_utils import read_cache_file, read_events_file, arpabet, arpabet_voiceless

def get_phoneme(description: str) -> tp.Optional[str]:
    """Return the ARPABET symbol for a phoneme_onset descriptor, or None.

    A descriptor that is already an ARPABET symbol is returned unchanged.
    A three-character descriptor ending in a digit (e.g. ``DH0``) is reduced
    to its two-character prefix (``DH``) if that prefix is an ARPABET symbol.
    Anything else yields ``None``.
    """
    if description in arpabet:
        return description

    has_digit_suffix = len(description) == 3 and description[2].isnumeric()
    if has_digit_suffix and description[:2] in arpabet:
        return description[:2]

    return None


def get_speech_labels(
    events,
    raw,
    offset,
):
    """Build a per-sample speech/silence label track from word events.

    Every event whose ``type`` contains ``"word_onset"`` labels the samples it
    spans as speech (1.0). The exception is events whose ``value`` is ``"sp"``,
    which are marked explicitly as silence (0.0). Samples not covered by any
    word event stay 0.0. Later events overwrite earlier ones where they overlap.

    Args:
        events: events table (DataFrame-like; uses ``iterrows``) with ``type``,
            ``onset``, ``duration`` and ``value`` columns; onsets/durations are
            in seconds (they are multiplied by the sampling rate).
        raw: recording object exposing ``info["sfreq"]`` and ``len(raw)``
            (number of samples).
        offset: delay in seconds added to every label so labels line up with
            the brain response.

    Returns:
        ``np.ndarray`` of length ``len(raw)`` holding 0.0 / 1.0 labels.
    """
    sample_freq = raw.info["sfreq"]
    offset_samples = int(sample_freq * offset)

    word_events = events[["word_onset" in c for c in list(events["type"])]]
    labels = np.zeros(len(raw))
    for _, word_event in word_events.iterrows():
        onset = float(word_event["onset"])
        duration = float(word_event["duration"])
        # Delay labels so they occur at the same time as the brain response.
        t_start = int(onset * sample_freq) + offset_samples
        t_end = int((onset + duration) * sample_freq) + offset_samples

        if t_end < 0:
            # The whole (shifted) event lies before the recording starts.
            continue
        # Clamp: a negative slice start would wrap around to the end of the
        # array and label the wrong samples (or select nothing at all).
        t_start = max(t_start, 0)

        # Decision rule: an "sp" event is explicitly marked as silence.
        labels[t_start : t_end + 1] = 0.0 if word_event["value"] == "sp" else 1.0

    return labels


def get_voicing_labels(
    events,
    raw,
    offset,
) -> tp.Tuple[tp.List[int], tp.List[float]]:
    """Collect shifted phoneme onsets (in samples) and their voicing labels.

    Only events whose ``type`` contains ``"phoneme_onset"``, whose ``value``
    maps to an ARPABET symbol via ``get_phoneme``, and whose onset is strictly
    positive are kept. Each kept phoneme is labelled 0.0 if its symbol is in
    ``arpabet_voiceless`` and 1.0 otherwise.

    Args:
        events: events table (DataFrame-like; uses ``iterrows``) with ``type``,
            ``onset`` (seconds) and ``value`` columns.
        raw: recording object exposing ``info["sfreq"]``.
        offset: delay in seconds added to every onset.

    Returns:
        ``(onsets, labels)``: parallel lists of onset sample indices and
        voicing labels.
    """
    sfreq = raw.info["sfreq"]
    shift = int(sfreq * offset)
    is_phoneme_row = ["phoneme_onset" in kind for kind in list(events["type"])]

    onsets: tp.List[int] = []
    voicing: tp.List[float] = []
    for _, row in events[is_phoneme_row].iterrows():
        symbol = get_phoneme(row["value"])
        if symbol is None:
            continue
        onset_s = float(row["onset"])
        if not (onset_s > 0):  # written this way so NaN onsets are skipped too
            continue
        start = int(onset_s * sfreq) + shift
        voicing.append(0.0 if symbol in arpabet_voiceless else 1.0)
        onsets.append(start)

    return onsets, voicing

def get_phoneme_labels(
    events,
    raw,
    offset,
) -> tp.Tuple[tp.List[int], tp.List[int]]:
    """Collect shifted phoneme onsets (in samples) and their phoneme class ids.

    Only events whose ``type`` contains ``"phoneme_onset"``, whose ``value``
    maps to an ARPABET symbol via ``get_phoneme``, and whose onset is strictly
    positive are kept. Each kept phoneme is labelled with the integer index of
    its symbol in ``arpabet``.

    Args:
        events: events table (DataFrame-like; uses ``iterrows``) with ``type``,
            ``onset`` (seconds) and ``value`` columns.
        raw: recording object exposing ``info["sfreq"]``.
        offset: delay in seconds added to every onset.

    Returns:
        ``(onsets, labels)``: parallel lists of onset sample indices and
        integer phoneme ids (positions in ``arpabet``).
    """
    sample_freq = raw.info["sfreq"]
    offset_samples = int(sample_freq * offset)

    phoneme_events = events[["phoneme_onset" in c for c in list(events["type"])]]

    phoneme_onsets: tp.List[int] = []
    labels: tp.List[int] = []
    for _, phoneme_event in phoneme_events.iterrows():
        value = get_phoneme(phoneme_event["value"])
        if value is not None:
            onset = float(phoneme_event["onset"])
            if onset > 0:
                t_start = int(onset * sample_freq) + offset_samples
                # Class id = position of the symbol in the ARPABET inventory.
                labels.append(arpabet.index(value))
                phoneme_onsets.append(t_start)

    return phoneme_onsets, labels

class ArmeniSubloader(Dataset):
    """Dataset over one recording (subject, session, task) of Armeni et al. (2022).

    With ``label`` of ``None`` or ``"speech"``, items are consecutive,
    non-overlapping fixed-length windows of the recording. With ``"voicing"``
    or ``"phoneme"``, items are fixed-length windows starting at each
    (delayed) phoneme onset.
    """

    # Labels this loader knows how to produce; ``None`` means "data only".
    _SUPPORTED_LABELS = (None, "speech", "voicing", "phoneme")

    def __init__(
        self,
        subject: str,
        task: str,
        session: str,
        window_len: float,
        label: tp.Optional[str],
        info: tp.List[str],
        preproc_path: str,
        bids_root: str,
        preload: bool,
        label_delay: float,
    ):
        """Load the specified recording and its events.

        Args:
            subject: subject identifier (also looked up in
                ``constants.SUBJECTS`` when ``"subject_id"`` is requested).
            task: task identifier passed to the file readers.
            session: session identifier passed to the file readers.
            window_len: window length in seconds.
            label: one of ``None``, ``"speech"``, ``"voicing"``, ``"phoneme"``.
            info: names of extra per-item fields to attach; recognised values
                are ``"dataset"``, ``"subject"``, ``"subject_id"``,
                ``"session"`` and ``"sensor_xyz"``.
            preproc_path: forwarded to ``read_cache_file``.
            bids_root: forwarded to ``read_events_file``.
            preload: forwarded to ``read_cache_file``.
            label_delay: delay in seconds applied to labels/onsets.

        Raises:
            ValueError: if ``label`` is unsupported or ``window_len`` is
                shorter than one sample.
        """
        # Validate before touching the disk so configuration errors fail fast.
        if label not in self._SUPPORTED_LABELS:
            raise ValueError(
                f"Unsupported label {label!r}; expected one of {self._SUPPORTED_LABELS}"
            )

        raw = read_cache_file(
            preproc_path=preproc_path,
            subject=subject,
            session=session,
            task=task,
            preload=preload,
        )
        events = read_events_file(
            bids_root=bids_root,
            subject=subject,
            session=session,
            task=task,
        )

        self.subject = subject
        self.session = session
        self.label = label

        sfreq = float(raw.info["sfreq"])
        self.samples_per_slice = int(sfreq * window_len)
        if self.samples_per_slice <= 0:
            raise ValueError(
                f"window_len={window_len} s is shorter than one sample at {sfreq} Hz"
            )
        self.num_slices = len(raw) // self.samples_per_slice

        if label == "speech":
            self.labels = get_speech_labels(events, raw, label_delay)
        elif label in ("voicing", "phoneme"):
            extract = get_voicing_labels if label == "voicing" else get_phoneme_labels
            onsets, labels = extract(events, raw, label_delay)
            # Windows that start before the recording or run past its end
            # would be wrapped or truncated, so they are dropped.
            self.phoneme_onsets, self.labels = self._drop_out_of_range(
                onsets, labels, len(raw), self.samples_per_slice
            )
            self.num_slices = len(self.phoneme_onsets)

        # Add optional additional information
        self.info = {}
        if "dataset" in info:
            self.info["dataset"] = "armeni2022"
        if "subject" in info:
            self.info["subject"] = subject
        if "subject_id" in info:
            self.info["subject_id"] = constants.SUBJECTS.index(subject)
        if "session" in info:
            self.info["session"] = session
        if "sensor_xyz" in info:
            # The first three entries of each channel's "loc" are X, Y, Z.
            sensor_positions = [ch["loc"][:3] for ch in raw.info["chs"]]
            self.info["sensor_xyz"] = torch.tensor(np.array(sensor_positions))

        self.raw = raw

    @staticmethod
    def _drop_out_of_range(onsets, labels, n_samples, window):
        """Keep only (onset, label) pairs whose window lies fully inside the recording."""
        kept = [
            (onset, lab)
            for onset, lab in zip(onsets, labels)
            if onset >= 0 and onset + window <= n_samples
        ]
        return [onset for onset, _ in kept], [lab for _, lab in kept]

    def __len__(self):
        return self.num_slices

    def __getitem__(self, idx):
        """Return a dict with ``data``, ``times``, ``info`` and, if set, the label.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        return_dict = {}

        if self.label in ("voicing", "phoneme"):
            start = self.phoneme_onsets[idx]
            data, times = self.raw[:, start : start + self.samples_per_slice]
            return_dict[self.label] = self.labels[idx]
        elif self.label is None or self.label == "speech":
            # Mirror sequence semantics: negative indices count from the end,
            # and out-of-range indices raise instead of yielding empty windows.
            if idx < 0:
                idx += self.num_slices
            if not 0 <= idx < self.num_slices:
                raise IndexError(f"index {idx} out of range for {self.num_slices} slices")
            start = idx * self.samples_per_slice
            stop = start + self.samples_per_slice
            data, times = self.raw[:, start:stop]
            if self.label == "speech":
                return_dict["speech"] = self.labels[start:stop]
        else:
            raise ValueError(f"Unsupported label {self.label!r}")

        return_dict["data"] = data
        return_dict["times"] = times
        return_dict["info"] = self.info

        return return_dict
