"""Load data from a specific recording in Gwilliams et al. (2022)"""

import ast
import typing as tp
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

import pnpl.datasets.gwilliams2022.constants as constants
from pnpl.datasets.dataset_utils import read_cache_file, read_events_file, timit_to_arpabet, arpabet

def get_speech_labels(
    events,
    raw,
    offset,
):
    """Build a per-sample binary speech mask from word events.

    Args:
        events: Events table with ``onset`` and ``duration`` columns (seconds)
            and a ``trial_type`` column of stringified dicts; rows whose
            ``trial_type`` contains ``'kind': 'word'`` are treated as words.
        raw: Recording exposing ``info["sfreq"]`` and ``len(raw)`` (samples).
        offset: Shift in seconds added to every word's position before it is
            written into the mask.

    Returns:
        ``np.ndarray`` of length ``len(raw)`` holding 1.0 on every sample
        covered by a (shifted) word, end sample included, and 0.0 elsewhere.
        Parts of a word shifted outside the recording are dropped.
    """
    sample_freq = raw.info["sfreq"]
    offset_samples = int(sample_freq * offset)
    labels = np.zeros(len(raw))

    # A boolean ndarray (not a plain list) keeps this a row filter even when
    # ``events`` is empty.
    is_word = np.array(
        ["'kind': 'word'" in trial_type for trial_type in events["trial_type"]],
        dtype=bool,
    )
    word_events = events.loc[is_word]

    for onset, duration in zip(word_events["onset"], word_events["duration"]):
        onset = float(onset)
        duration = float(duration)
        t_start = int(onset * sample_freq) + offset_samples
        t_end = int((onset + duration) * sample_freq) + offset_samples

        # A negative slice bound would count from the END of ``labels`` and
        # mark the wrong samples, so clip the start of the span at sample 0.
        start = max(t_start, 0)
        stop = t_end + 1
        if stop <= start:
            # The word lies entirely before the recording after shifting.
            continue
        labels[start:stop] = 1.0

    return labels


def get_voicing_labels(
    events,
    raw,
    phoneme_codes,
    offset,
) -> tp.Tuple[tp.List[int], tp.List[float]]:
    """Use events to determine aligned phoneme onsets and their voicing labels.

    Args:
        events: Events table with an ``onset`` column (seconds) and a
            ``trial_type`` column of stringified dicts; rows whose
            ``trial_type`` contains ``'kind': 'phoneme'`` must carry a
            ``"phoneme"`` entry (any ``_``-suffix is stripped).
        raw: Recording exposing ``info["sfreq"]``.
        phoneme_codes: Table with ``phoneme`` and ``phonation`` columns.
        offset: Shift in seconds added to every onset.

    Returns:
        ``(phoneme_onsets, labels)``: onset sample indices and, for each one,
        1.0 if the phoneme's phonation is ``"v"`` or 0.0 if it is ``"uv"``.
        Phonemes with any other phonation are skipped.

    Raises:
        ValueError: If a phoneme does not appear exactly once in
            ``phoneme_codes``.
    """
    sample_freq = raw.info["sfreq"]
    offset_samples = int(sample_freq * offset)

    # Build the phoneme -> phonation lookup once, instead of filtering the
    # table for every event.
    phonation_of = dict(zip(phoneme_codes["phoneme"], phoneme_codes["phonation"]))
    ambiguous = set(
        phoneme_codes.loc[phoneme_codes["phoneme"].duplicated(), "phoneme"]
    )

    # A boolean ndarray keeps this a row filter even when ``events`` is empty.
    is_phoneme = np.array(
        ["'kind': 'phoneme'" in trial_type for trial_type in events["trial_type"]],
        dtype=bool,
    )
    phoneme_events = events.loc[is_phoneme]

    phoneme_onsets = []
    labels = []

    for onset, trial_type in zip(phoneme_events["onset"], phoneme_events["trial_type"]):
        phoneme = ast.literal_eval(trial_type)["phoneme"].split("_")[0]  # Remove BIE indicators
        if phoneme not in phonation_of or phoneme in ambiguous:
            raise ValueError(
                f"Expected exactly one row for phoneme {phoneme!r} in phoneme_codes"
            )
        phonation = phonation_of[phoneme]
        onset_samples = int(float(onset) * sample_freq) + offset_samples

        # Label as voiced or unvoiced
        if phonation == "v":
            labels.append(1.0)
            phoneme_onsets.append(onset_samples)
        elif phonation == "uv":
            labels.append(0.0)
            phoneme_onsets.append(onset_samples)

    return phoneme_onsets, labels

def get_phoneme_labels(
    events,
    raw,
    phoneme_codes,
    offset,
) -> tp.Tuple[tp.List[int], tp.List[float]]:
    """Locate phoneme onsets and map each phoneme to its ARPAbet class index.

    Args:
        events: Events table with an ``onset`` column (seconds) and a
            ``trial_type`` column of stringified dicts; rows whose
            ``trial_type`` contains ``'kind': 'phoneme'`` must carry a
            ``"phoneme"`` entry (any ``_``-suffix is stripped).
        raw: Recording exposing ``info["sfreq"]``.
        phoneme_codes: Not used by this function; accepted so the call
            signature matches ``get_voicing_labels``.
        offset: Shift in seconds added to every onset.

    Returns:
        ``(phoneme_onsets, labels)``: onset sample indices and, for each one,
        the position of the phoneme's ARPAbet symbol (via
        ``timit_to_arpabet``) within ``arpabet``. Phonemes whose symbol is not
        in ``arpabet`` (e.g. silences) are skipped.
    """
    sfreq = raw.info["sfreq"]
    shift = int(sfreq * offset)

    keep = np.array(
        ["'kind': 'phoneme'" in tt for tt in events["trial_type"]],
        dtype=bool,
    )
    selected = events.loc[keep]

    onsets_out = []
    classes_out = []
    for onset_sec, encoded in zip(selected["onset"], selected["trial_type"]):
        description = ast.literal_eval(encoded)
        timit_symbol = description["phoneme"].split("_")[0]  # strip BIE marker
        sample = int(float(onset_sec) * sfreq) + shift

        symbol = timit_to_arpabet[timit_symbol]
        if symbol not in arpabet:
            continue  # silences and other non-speech symbols
        classes_out.append(arpabet.index(symbol))
        onsets_out.append(sample)

    return onsets_out, classes_out

class GwilliamsSubloader(Dataset):
    """Dataset for a specific subject, session, and task.

    Every item is a window of ``int(sfreq * window_len)`` samples. With
    ``label`` ``None`` or ``"speech"`` the recording is cut into consecutive,
    non-overlapping windows, and a trailing partial window is dropped. With
    ``"voicing"`` or ``"phoneme"`` one window starts at each labelled phoneme
    onset, and windows that do not fit inside the recording are dropped.
    """

    # Label kinds understood by ``__init__`` and ``__getitem__``.
    _SUPPORTED_LABELS = (None, "speech", "voicing", "phoneme")

    def __init__(
        self,
        subject: str,
        task: str,
        session: str,
        window_len: float,
        label: tp.Optional[str],
        info: tp.List[str],
        preproc_path: str,
        bids_root: str,
        preload: bool,
        label_delay: float,
    ):
        """
        Load specified recording from cache.

        Args:
            subject, task, session: Recording identifiers forwarded to
                ``read_cache_file`` and ``read_events_file``.
            window_len: Window length in seconds.
            label: ``None``, ``"speech"``, ``"voicing"`` or ``"phoneme"``.
            info: Names of extra metadata to attach to every item; any of
                ``"dataset"``, ``"subject"``, ``"subject_id"``, ``"session"``,
                ``"sensor_xyz"``.
            preproc_path: Forwarded to ``read_cache_file``.
            bids_root: Forwarded to ``read_events_file``; also the directory
                containing ``phoneme_info.csv``.
            preload: Forwarded to ``read_cache_file``.
            label_delay: Shift in seconds applied to label timing.

        Raises:
            ValueError: If ``label`` is unsupported or ``window_len`` is
                shorter than one sample.
        """
        # Validate before any I/O. Otherwise a typo would only surface as an
        # UnboundLocalError on the first ``__getitem__``.
        if label not in self._SUPPORTED_LABELS:
            raise ValueError(
                f"Unsupported label {label!r}; expected one of "
                f"{self._SUPPORTED_LABELS}"
            )

        raw = read_cache_file(
            preproc_path=preproc_path,
            subject=subject,
            session=session,
            task=task,
            preload=preload,
        )
        events = read_events_file(
            bids_root=bids_root,
            subject=subject,
            session=session,
            task=task,
        )

        self.subject = subject
        self.session = session
        self.label = label

        sfreq = float(raw.info["sfreq"])
        self.samples_per_slice = int(sfreq * window_len)
        if self.samples_per_slice <= 0:
            raise ValueError(
                f"window_len={window_len} s is shorter than one sample at {sfreq} Hz"
            )
        n_times = len(raw)
        self.num_slices = n_times // self.samples_per_slice

        if label == "speech":
            self.labels = get_speech_labels(events, raw, label_delay)
        elif label in ("voicing", "phoneme"):
            # The phoneme table is only needed for phoneme-locked labels.
            phoneme_codes = pd.read_csv(Path(bids_root) / "phoneme_info.csv")
            labeller = get_voicing_labels if label == "voicing" else get_phoneme_labels
            onsets, labels = labeller(events, raw, phoneme_codes, label_delay)
            # Keep only windows lying entirely inside the recording. Others
            # would yield truncated or out-of-range data whose shape differs
            # from every other item.
            last_start = n_times - self.samples_per_slice
            kept = [
                (onset, lab)
                for onset, lab in zip(onsets, labels)
                if 0 <= onset <= last_start
            ]
            self.phoneme_onsets = [onset for onset, _ in kept]
            self.labels = [lab for _, lab in kept]
            self.num_slices = len(self.phoneme_onsets)

        # Add optional additional information
        self.info = {}
        if "dataset" in info:
            self.info["dataset"] = "gwilliams2022"
        if "subject" in info:
            self.info["subject"] = subject
        if "subject_id" in info:
            self.info["subject_id"] = constants.SUBJECTS.index(subject)
        if "session" in info:
            self.info["session"] = session
        if "sensor_xyz" in info:
            # First three elements of each channel's ``loc``: X, Y, Z.
            self.info["sensor_xyz"] = torch.tensor(
                np.array([ch["loc"][:3] for ch in raw.info["chs"]])
            )

        self.raw = raw

    def __len__(self):
        return self.num_slices

    def __getitem__(self, idx):
        """Return window ``idx`` as a dict.

        Keys: ``"data"`` and ``"times"`` (as returned by
        ``self.raw[:, start:stop]``) and ``"info"`` (the shared metadata
        dict). The label, if any, is stored under the label name. For
        ``"speech"`` it is the per-sample slice of the speech mask; for
        ``"voicing"``/``"phoneme"`` it is the window's scalar label.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)
        position = idx + n_items if idx < 0 else idx
        if not 0 <= position < n_items:
            raise IndexError(
                f"index {idx} out of range for dataset of length {n_items}"
            )

        if self.label in ("voicing", "phoneme"):
            start = self.phoneme_onsets[position]
        else:
            start = position * self.samples_per_slice
        stop = start + self.samples_per_slice
        data, times = self.raw[:, start:stop]

        return_dict = {}
        if self.label == "speech":
            return_dict["speech"] = self.labels[start:stop]
        elif self.label is not None:
            return_dict[self.label] = self.labels[position]

        return_dict["data"] = data
        return_dict["times"] = times

        return_dict["info"] = self.info

        return return_dict
