import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pickle
import re 
from g2p_en import G2p
import numpy as np
import os
from config import AUDIO_ASSETS

import random


# Grapheme-to-phoneme converter from g2p_en, instantiated once when this module is imported.
g2p = G2p()
# Phoneme inventory; a symbol's position in the list is its 0-based class index.
PHONE_DEF = [
    'AA', 'AE', 'AH', 'AO', 'AW',
    'AY', 'B',  'CH', 'D', 'DH',
    'EH', 'ER', 'EY', 'F', 'G',
    'HH', 'IH', 'IY', 'JH', 'K',
    'L', 'M', 'N', 'NG', 'OW',
    'OY', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UW', 'V',
    'W', 'Y', 'Z', 'ZH'
]
# Same inventory with the silence token appended as the last class.
PHONE_DEF_SIL = PHONE_DEF + ['SIL']

def phoneToId(p):
    """
    Return the 0-based position of phoneme ``p`` in ``PHONE_DEF_SIL``.

    Raises:
        ValueError: if ``p`` is not one of the known phoneme symbols.
    """
    return PHONE_DEF_SIL.index(p)

# Forward (symbol -> index) and reverse (index -> symbol) lookup tables.
phoneToIdDict = {p: i for i, p in enumerate(PHONE_DEF_SIL)}
idToPhone = {i: p for p, i in phoneToIdDict.items()}

def idsToPhonemes(seqClassIDs, idToPhone = idToPhone):
    """
    Converts a sequence of phoneme IDs back to their phoneme representations.

    IDs are expected to be stored with a +1 offset, so that 0 can act as
    padding; entries that are 0 (or negative) are skipped.

    Args:
        seqClassIDs (iterable): The numerical sequence of phoneme IDs. May be a
            list, a numpy array or a 1-D torch tensor.
        idToPhone (dict): A dictionary mapping 0-based phoneme indices to phonemes.

    Returns:
        list: The corresponding phoneme sequence.

    Raises:
        KeyError: if an ID (after removing the +1 offset) is not in ``idToPhone``.
    """
    phonemeSeq = []
    for class_id in seqClassIDs:
        # Coerce to a plain int: torch tensor elements hash by identity, so using
        # them directly as dict keys would always miss.
        class_id = int(class_id)
        if class_id > 0:
            phonemeSeq.append(idToPhone[class_id - 1])  # -1 because IDs were stored with +1
    return phonemeSeq

class SpeechDataset(Dataset):
    """
    Flat, trial-indexed view over per-day recordings.

    ``data`` is a sequence with one entry per day. Each entry is indexed with the
    keys "sentenceDat", "phonemes" and "phoneLens", each holding one item per trial.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
        self.n_days = len(data)
        self.n_trials = sum(len(day_block["sentenceDat"]) for day_block in data)

        # Parallel lists: position i in every list describes the same trial.
        self.neural_feats = []
        self.phone_seqs = []
        self.neural_time_bins = []
        self.phone_seq_lens = []
        self.days = []
        for day_idx in range(self.n_days):
            day_block = data[day_idx]
            for trial_idx in range(len(day_block["sentenceDat"])):
                trial_feats = day_block["sentenceDat"][trial_idx]
                self.neural_feats.append(trial_feats)
                self.phone_seqs.append(day_block["phonemes"][trial_idx])
                # The first axis of the feature array is taken as the number of time bins.
                self.neural_time_bins.append(trial_feats.shape[0])
                self.phone_seq_lens.append(day_block["phoneLens"][trial_idx])
                self.days.append(day_idx)

    def __len__(self):
        return self.n_trials

    def __getitem__(self, idx):
        """Return (features, phone_seq, n_time_bins, phone_seq_len, day) as tensors."""
        feats = torch.tensor(self.neural_feats[idx], dtype=torch.float32)
        if self.transform:
            feats = self.transform(feats)

        phones = torch.tensor(self.phone_seqs[idx], dtype=torch.int32)
        n_bins = torch.tensor(self.neural_time_bins[idx], dtype=torch.int32)
        n_phones = torch.tensor(self.phone_seq_lens[idx], dtype=torch.int32)
        day = torch.tensor(self.days[idx], dtype=torch.int64)
        return (feats, phones, n_bins, n_phones, day)


# NOTE: the two tables and idsToPhonemes below repeat definitions made earlier in this
# module. Being the later ones, they are the bindings in effect once the module is imported.
phoneToIdDict = {p:phoneToId(p) for p in PHONE_DEF_SIL}
idToPhone = {v: k for k, v in phoneToIdDict.items()}

def idsToPhonemes(seqClassIDs, idToPhone = idToPhone):
    """
    Converts a sequence of phoneme IDs back to their phoneme representations.

    IDs are expected to be stored with a +1 offset, so that 0 can act as
    padding; entries that are 0 (or negative) are skipped.

    Args:
        seqClassIDs (iterable): The numerical sequence of phoneme IDs. May be a
            list, a numpy array or a 1-D torch tensor.
        idToPhone (dict): A dictionary mapping 0-based phoneme indices to phonemes.

    Returns:
        list: The corresponding phoneme sequence.

    Raises:
        KeyError: if an ID (after removing the +1 offset) is not in ``idToPhone``.
    """
    phonemeSeq = []
    for class_id in seqClassIDs:
        # Coerce to a plain int: torch tensor elements hash by identity, so using
        # them directly as dict keys would always miss.
        class_id = int(class_id)
        if class_id > 0:
            phonemeSeq.append(idToPhone[class_id - 1])  # -1 because IDs were stored with +1
    return phonemeSeq

## Dataset variant that also returns the sentence transcription

class SpeechSentenceDataset(Dataset):
    """
    Flat, trial-indexed view over per-day recordings that also carries the sentence text.

    ``data`` is a sequence with one entry per day. Each entry is indexed with the
    keys "sentenceDat", "phonemes", "phoneLens" and "transcriptions", each holding
    one item per trial.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
        self.n_days = len(data)
        self.n_trials = sum(len(day_block["sentenceDat"]) for day_block in data)

        # Parallel lists: position i in every list describes the same trial.
        self.neural_feats = []
        self.phone_seqs = []
        self.neural_time_bins = []
        self.phone_seq_lens = []
        self.days = []
        self.sentences = []
        for day_idx in range(self.n_days):
            day_block = data[day_idx]
            for trial_idx in range(len(day_block["sentenceDat"])):
                trial_feats = day_block["sentenceDat"][trial_idx]
                self.neural_feats.append(trial_feats)
                self.phone_seqs.append(day_block["phonemes"][trial_idx])
                # The first axis of the feature array is taken as the number of time bins.
                self.neural_time_bins.append(trial_feats.shape[0])
                self.phone_seq_lens.append(day_block["phoneLens"][trial_idx])
                self.days.append(day_idx)
                self.sentences.append(day_block["transcriptions"][trial_idx])

    def __len__(self):
        return self.n_trials

    def __getitem__(self, idx):
        """Return (features, phone_seq, n_time_bins, phone_seq_len, day, sentence)."""
        feats = torch.tensor(self.neural_feats[idx], dtype=torch.float32)
        if self.transform:
            feats = self.transform(feats)

        phones = torch.tensor(self.phone_seqs[idx], dtype=torch.int32)
        n_bins = torch.tensor(self.neural_time_bins[idx], dtype=torch.int32)
        n_phones = torch.tensor(self.phone_seq_lens[idx], dtype=torch.int32)
        day = torch.tensor(self.days[idx], dtype=torch.int64)
        # The sentence is passed through untouched (not converted to a tensor).
        return (feats, phones, n_bins, n_phones, day, self.sentences[idx])




class SpeechSentenceDataset_MFCC_assets(Dataset):
    """
    Sentence dataset that additionally returns a per-trial MFCC entry.

    ``data`` is a sequence with one entry per day. Each entry is indexed with the
    keys "sentenceDat", "phonemes", "phoneLens" and "transcriptions", each holding
    one item per trial.

    Args:
        data: per-day recordings as described above.
        transform (callable, optional): applied to the float32 feature tensor.
        asset_path (str, optional): path of a pickle file whose content is indexed
            by flat trial index in ``__getitem__``. When None, every trial gets
            ``None`` as its MFCC entry.
    """

    def __init__(self, data, transform=None,asset_path=None):
        self.data = data
        self.transform = transform
        self.n_days = len(data)
        self.n_trials = sum(len(d["sentenceDat"]) for d in data)

        # Parallel lists: position i in every list describes the same trial.
        self.neural_feats = []
        self.phone_seqs = []
        self.neural_time_bins = []
        self.phone_seq_lens = []
        self.days = []
        self.sentences = []
        for day in range(self.n_days):
            day_block = data[day]
            for trial in range(len(day_block["sentenceDat"])):
                trial_feats = day_block["sentenceDat"][trial]
                self.neural_feats.append(trial_feats)
                self.phone_seqs.append(day_block["phonemes"][trial])
                self.neural_time_bins.append(trial_feats.shape[0])
                self.phone_seq_lens.append(day_block["phoneLens"][trial])
                self.days.append(day)
                self.sentences.append(day_block["transcriptions"][trial])

        if asset_path is None:
            self.mfcc = [None] * self.n_trials
        else:
            # Use a context manager so the file handle is closed once loading is done.
            # NOTE: unpickling can execute arbitrary code; only load trusted asset files.
            with open(str(asset_path), "rb") as handle:
                self.mfcc = pickle.load(handle)

    def __len__(self):
        return self.n_trials

    def __getitem__(self, idx):
        """Return (features, phone_seq, n_time_bins, phone_seq_len, day, sentence, mfcc)."""
        neural_feats = torch.tensor(self.neural_feats[idx], dtype=torch.float32)

        if self.transform:
            neural_feats = self.transform(neural_feats)

        # The MFCC entry is returned as stored (no tensor conversion).
        mfcc = self.mfcc[idx]
        return (
            neural_feats,
            torch.tensor(self.phone_seqs[idx], dtype=torch.int32),
            torch.tensor(self.neural_time_bins[idx], dtype=torch.int32),
            torch.tensor(self.phone_seq_lens[idx], dtype=torch.int32),
            torch.tensor(self.days[idx], dtype=torch.int64),
            self.sentences[idx],
            mfcc
        )




class SpliceMixAugmentedDataset(Dataset):
    """
    Wraps a sentence dataset and, with probability ``mix_prob``, replaces a sample
    by a splice of it with another sample recorded on the same day.

    The splice keeps the start of the first sample up to one of its internal
    silence tokens and appends the remainder of the second sample after one of its
    internal silence tokens. Neural features are cut at
    ``(phone cut index + 1) * time_per_phoneme`` frames, which is only an
    approximate alignment. Whenever splicing is not possible the original sample
    is returned.

    The wrapped dataset must yield 6-tuples
    ``(neural_feats, phone_seq, neural_len, phone_len, day, sentence)``.
    """

    def __init__(
        self,
        base_dataset,
        mix_prob=0.5,
        silence_idx=40,
        time_per_phoneme=14
    ):
        """
        :param base_dataset: Your original dataset
        :param mix_prob: Probability of applying the splice augmentation
        :param silence_idx: The phoneme index that indicates silence (SIL). The default
            40 is the 0-based position of 'SIL' in PHONE_DEF_SIL plus the +1 offset
            that idsToPhonemes undoes.
        :param time_per_phoneme: Approx # time frames per phoneme for neural data
        """
        self.base_dataset = base_dataset
        self.mix_prob = mix_prob
        self.silence_idx = silence_idx
        self.time_per_phoneme = time_per_phoneme

        # Precompute indices by day for fast same-day sampling
        self.indices_by_day = {}
        for i in range(len(base_dataset)):
            _, _, _, _, day_i, _ = base_dataset[i]
            self.indices_by_day.setdefault(int(day_i), []).append(i)

    def __len__(self):
        return len(self.base_dataset)

    def _internal_silences(self, seq):
        """Positions of the silence token in ``seq``, excluding the first and last element."""
        return [i for i in range(1, len(seq) - 1) if seq[i] == self.silence_idx]

    def __getitem__(self, idx):
        """
        Return ``(neural_feats, phone_seq, neural_len, phone_len, day, sentence)``,
        either unchanged from the wrapped dataset or spliced with a same-day sample.
        """
        # Fetch the primary sample
        neural_feats, phone_seq, neural_len, phone_len, day, sentence = self.base_dataset[idx]
        day = int(day)

        # Single fallback value used by every "cannot splice" exit below.
        # The day is emitted as int64, the dtype the wrapped sentence datasets in this
        # module use, so augmented and non-augmented loaders agree.
        original = (
            neural_feats,
            torch.tensor(phone_seq.tolist(), dtype=torch.int32),
            neural_len,
            phone_len,
            torch.tensor(day, dtype=torch.int64),
            sentence
        )

        # If we don't splice, just return original
        if random.random() > self.mix_prob:
            return original

        # Must have at least 2 same-day samples to splice
        same_day_indices = self.indices_by_day[day]
        if len(same_day_indices) < 2:
            return original

        # Pick a different same-day sample
        idx2 = idx
        while idx2 == idx:
            idx2 = random.choice(same_day_indices)

        nf2, ps2, _, pl2, _, sent2 = self.base_dataset[idx2]

        # Only the first phone_len entries are real targets. Anything after that is
        # presumably padding (the wrapped dataset reports the length separately) and
        # must neither be spliced into the new target nor counted in its length.
        seq1 = phone_seq.tolist()[: int(phone_len)]
        seq2 = ps2.tolist()[: int(pl2)]
        words1 = sentence.split()
        words2 = sent2.split()

        # Find internal silences for each phone sequence
        sil_idx_1 = self._internal_silences(seq1)
        sil_idx_2 = self._internal_silences(seq2)

        # We also need at least 1 actual word boundary in each sentence
        num_bound_1 = len(words1) - 1
        num_bound_2 = len(words2) - 1
        if not sil_idx_1 or not sil_idx_2 or num_bound_1 < 1 or num_bound_2 < 1:
            # Not enough silences or not enough words to splice
            return original

        # Choose random boundary index for sample1 and sample2; the i-th internal
        # silence is treated as the boundary after the i-th word.
        i_1 = random.randint(0, min(len(sil_idx_1), num_bound_1) - 1)
        i_2 = random.randint(0, min(len(sil_idx_2), num_bound_2) - 1)

        # Phone boundary indices
        cut_1_phone = sil_idx_1[i_1]
        cut_2_phone = sil_idx_2[i_2]

        # Keep everything up to and including the silence from the first sample,
        # and everything after the silence from the second sample.
        new_phone_seq = seq1[: cut_1_phone + 1] + seq2[cut_2_phone + 1:]
        new_sentence = words1[: i_1 + 1] + words2[i_2 + 1:]

        # For neural data, cut at (cut_1_phone+1)*time_per_phoneme in the first sample
        # and from (cut_2_phone+1)*time_per_phoneme onward in the second sample
        cut_1_n = min((cut_1_phone + 1) * self.time_per_phoneme, len(neural_feats))
        cut_2_n = min((cut_2_phone + 1) * self.time_per_phoneme, len(nf2))
        new_neural = torch.cat(
            [neural_feats[:cut_1_n], nf2[cut_2_n:]],
            dim=0
        )

        # Validate lengths for CTC: need a non-empty input and target, and at least
        # as many input frames as target labels.
        n_frames = new_neural.shape[0]
        n_phones = len(new_phone_seq)
        if n_frames == 0 or n_phones == 0 or n_frames < n_phones:
            return original

        # If all good, return the augmented sample
        return (
            new_neural,
            torch.tensor(new_phone_seq, dtype=torch.int32),
            torch.tensor(n_frames, dtype=torch.int32),
            torch.tensor(n_phones, dtype=torch.int32),
            torch.tensor(day, dtype=torch.int64),
            " ".join(new_sentence)
        )


## UTILS FUNCTIONS
def getDatasetLoaders(
    datasetName,
    batchSize,
    shuffle_train=True,
    use_splice_mix=False,
    splice_mix_prob=0.5,
    silence_idx=40,
    time_per_phoneme=14,
):
    """
    Load a pickled dataset and build train / test / competition DataLoaders.

    Args:
        datasetName: path of the pickle file; the loaded object is indexed with
            the keys "train", "test" and "competition".
        batchSize: batch size used for all three loaders.
        shuffle_train: whether the training loader shuffles.
        use_splice_mix: wrap the training set in SpliceMixAugmentedDataset.
        splice_mix_prob, silence_idx, time_per_phoneme: forwarded to
            SpliceMixAugmentedDataset when ``use_splice_mix`` is set.

    Returns:
        (train_loader, test_loader, competition_loader, loadedData)
    """
    with open(datasetName, "rb") as handle:
        loadedData = pickle.load(handle)

    def _padding(batch):
        # Transpose the list of samples into per-field tuples, zero-pad the two
        # variable-length fields and stack the scalar ones.
        feats, targets, feat_lens, target_lens, day_ids, sentences = zip(*batch)
        return (
            pad_sequence(feats, batch_first=True, padding_value=0),
            pad_sequence(targets, batch_first=True, padding_value=0),
            torch.stack(feat_lens),
            torch.stack(target_lens),
            torch.stack(day_ids),
            sentences,
        )

    def _make_loader(dataset, shuffle):
        # All loaders share every setting except the dataset and shuffling.
        return DataLoader(
            dataset,
            batch_size=batchSize,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=True,
            collate_fn=_padding,
        )

    train_ds = SpeechSentenceDataset(loadedData["train"])
    if use_splice_mix:
        # Only the training split is augmented.
        train_ds = SpliceMixAugmentedDataset(
            train_ds,
            mix_prob=splice_mix_prob,
            silence_idx=silence_idx,
            time_per_phoneme=time_per_phoneme,
        )
    test_ds = SpeechSentenceDataset(loadedData["test"])
    competition_ds = SpeechSentenceDataset(loadedData["competition"])

    train_loader = _make_loader(train_ds, shuffle_train)
    test_loader = _make_loader(test_ds, False)
    competition_loader = _make_loader(competition_ds, False)
    return train_loader, test_loader, competition_loader, loadedData
















############### DIPHONEMES


# import torch
# from torch.utils.data import Dataset, DataLoader
# from torch.nn.utils.rnn import pad_sequence
# import pickle

class DiphonemeDataset(Dataset):
    def __init__(self, data, valid_diphonemes, transform=None, return_both=False):
        """
        Dataset class that processes phoneme sequences and converts them to diphonemes.

        Args:
            data (list): The dataset containing phoneme sequences, transcriptions, and neural features.
                One entry per day, indexed with "sentenceDat", "phonemes" and "transcriptions".
            valid_diphonemes (set): A set of allowed diphonemes.
            transform (callable, optional): A function to apply to the neural features.
            return_both (bool): If True, returns both phoneme and diphoneme sequences.
        """
        self.data = data
        self.transform = transform
        self.valid_diphonemes = valid_diphonemes
        self.return_both = return_both
        self.n_days = len(data)
        self.n_trials = sum(len(d["sentenceDat"]) for d in data)

        # Parallel lists: position i in every list describes the same trial.
        self.neural_feats = []
        self.diphoneme_seqs = []
        self.phoneme_seqs = []  # Store phonemes as well
        self.neural_time_bins = []
        self.diphoneme_seq_lens = []
        self.phoneme_seq_lens = []
        self.days = []
        self.sentences = []

        for day in range(self.n_days):
            day_block = data[day]
            for trial in range(len(day_block["sentenceDat"])):
                trial_feats = day_block["sentenceDat"][trial]
                phoneme_seq = day_block["phonemes"][trial]
                diphoneme_seq = self.convert_to_diphonemes(phoneme_seq)

                self.neural_feats.append(trial_feats)
                self.phoneme_seqs.append(phoneme_seq)
                self.diphoneme_seqs.append(diphoneme_seq)
                self.neural_time_bins.append(trial_feats.shape[0])
                # Lengths are measured on the stored sequences, not read from the data.
                self.phoneme_seq_lens.append(len(phoneme_seq))
                self.diphoneme_seq_lens.append(len(diphoneme_seq))
                self.days.append(day)
                self.sentences.append(day_block["transcriptions"][trial])

    def __len__(self):
        return self.n_trials

    def __getitem__(self, idx):
        """
        Return a 6-tuple (features, diphonemes, n_time_bins, diphoneme_len, day, sentence),
        or, when ``return_both`` is set, an 8-tuple that also carries the phoneme
        sequence and its length.
        """
        neural_feats = torch.tensor(self.neural_feats[idx], dtype=torch.float32)

        if self.transform:
            neural_feats = self.transform(neural_feats)

        # Convert phonemes and diphonemes to tensors.
        # NOTE(review): convert_to_diphonemes builds string labels ("A_B"), and
        # torch.tensor cannot convert strings, so this presumably only works when
        # the stored sequences are numeric or empty. TODO confirm the intended encoding.
        phoneme_tensor = torch.tensor(self.phoneme_seqs[idx], dtype=torch.int32)
        diphoneme_tensor = torch.tensor(self.diphoneme_seqs[idx], dtype=torch.int32)

        n_bins = torch.tensor(self.neural_time_bins[idx], dtype=torch.int32)
        diphoneme_len = torch.tensor(self.diphoneme_seq_lens[idx], dtype=torch.int32)
        day = torch.tensor(self.days[idx], dtype=torch.int64)

        # Return either diphonemes only, or both phonemes and diphonemes
        if self.return_both:
            return (
                neural_feats,
                phoneme_tensor,
                diphoneme_tensor,
                n_bins,
                torch.tensor(self.phoneme_seq_lens[idx], dtype=torch.int32),
                diphoneme_len,
                day,
                self.sentences[idx]
            )
        return (
            neural_feats,
            diphoneme_tensor,
            n_bins,
            diphoneme_len,
            day,
            self.sentences[idx]
        )

    def convert_to_diphonemes(self, phoneme_seq):
        """
        Convert a phoneme sequence into diphonemes, keeping 'SIL' separate.

        For each adjacent pair: if either member is "SIL", the first member of the
        pair is emitted on its own; otherwise "<first>_<second>" is emitted, but only
        when it is in ``self.valid_diphonemes`` (other pairs are dropped). A trailing
        "SIL" is appended at the end. An empty input yields an empty list.
        """
        diphoneme_seq = []
        for current, following in zip(phoneme_seq, phoneme_seq[1:]):
            if current == "SIL" or following == "SIL":
                diphoneme_seq.append(current)  # Keep SIL separate
            else:
                diphone = f"{current}_{following}"
                if diphone in self.valid_diphonemes:
                    diphoneme_seq.append(diphone)

        # Add last phoneme if it's "SIL"; the length guard avoids indexing an empty sequence.
        if len(phoneme_seq) > 0 and phoneme_seq[-1] == "SIL":
            diphoneme_seq.append("SIL")

        return diphoneme_seq


def getDiphonemeDatasetLoaders(datasetName, valid_diphonemes, batchSize, shuffle_train=True, return_both=False):
    """
    Loads data and creates dataset loaders with diphonemes and optionally phonemes.

    The pickle at ``datasetName`` is indexed with "train", "test" and "competition".
    Batches are 6-tuples, or 8-tuples when ``return_both`` is True (phoneme and
    diphoneme targets plus both length vectors).

    Returns:
        (train_loader, test_loader, competition_loader, loadedData)
    """
    with open(datasetName, "rb") as handle:
        loadedData = pickle.load(handle)

    def _zero_pad(seqs):
        return pad_sequence(seqs, batch_first=True, padding_value=0)

    def _padding(batch):
        # Transpose the list of samples into per-field tuples.
        columns = tuple(zip(*batch))
        if return_both:
            feats, phones, diphones, feat_lens, phone_lens, diphone_lens, day_ids, sentences = columns
            return (
                _zero_pad(feats),
                _zero_pad(phones),
                _zero_pad(diphones),
                torch.stack(feat_lens),
                torch.stack(phone_lens),
                torch.stack(diphone_lens),
                torch.stack(day_ids),
                sentences,
            )
        feats, targets, feat_lens, target_lens, day_ids, sentences = columns
        return (
            _zero_pad(feats),
            _zero_pad(targets),
            torch.stack(feat_lens),
            torch.stack(target_lens),
            torch.stack(day_ids),
            sentences,
        )

    def _make_loader(dataset, shuffle):
        # All loaders share every setting except the dataset and shuffling.
        return DataLoader(
            dataset,
            batch_size=batchSize,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=True,
            collate_fn=_padding,
        )

    train_ds = DiphonemeDataset(loadedData["train"], valid_diphonemes, transform=None, return_both=return_both)
    test_ds = DiphonemeDataset(loadedData["test"], valid_diphonemes, return_both=return_both)
    competition_ds = DiphonemeDataset(loadedData["competition"], valid_diphonemes, return_both=return_both)

    train_loader = _make_loader(train_ds, shuffle_train)
    test_loader = _make_loader(test_ds, False)
    competition_loader = _make_loader(competition_ds, False)

    return train_loader, test_loader, competition_loader, loadedData






####### V2


def getDatasetLoaders_V2(
    datasetName,
    batchSize,
    shuffle_train=True,
    roi="sm", # "sm" or "broca"
    mode="both", # "spike count", "spikepower" or "both"
    include_prego=False, # if True, include pre-go neural data
    mfcc_assets=None # could None, trimmed or interpolated
):
    """
    Load a pickled dataset and build DataLoaders whose batches are dictionaries.

    Train and test batches carry the keys "neural_feats", "phone_seq",
    "neural_time_bins", "phone_seq_len", "day", "sentence" and "mfcc"; competition
    batches carry the same keys except "mfcc".

    Args:
        datasetName: path of the pickle file; the loaded object is indexed with
            the keys "train", "test" and "competition".
        batchSize: batch size used for all three loaders.
        shuffle_train: whether the training loader shuffles.
        roi, mode, include_prego: accepted but not used by this function.
        mfcc_assets: None (every "mfcc" entry is None), "trimmed" or "interpolated";
            the latter two select pickle files under AUDIO_ASSETS.

    Returns:
        (train_loader, test_loader, competition_loader, loadedData)

    Raises:
        ValueError: if ``mfcc_assets`` is not one of the supported values.
    """
    # Resolve the MFCC option first so that a bad value fails immediately with a
    # clear message, before the dataset is read from disk.
    if mfcc_assets is None:
        train_mfcc_assets = None
        test_mfcc_assets = None
    elif mfcc_assets == "trimmed":
        train_mfcc_assets = f"{AUDIO_ASSETS}/train_mfcc_trimmed.pkl"
        test_mfcc_assets = f"{AUDIO_ASSETS}/test_mfcc_trimmed.pkl"
    elif mfcc_assets == "interpolated":
        train_mfcc_assets = f"{AUDIO_ASSETS}/interpolated_train_mfcc.pkl"
        test_mfcc_assets = f"{AUDIO_ASSETS}/interpolated_test_mfcc.pkl"
    else:
        raise ValueError(
            f"mfcc_assets must be None, 'trimmed' or 'interpolated', got {mfcc_assets!r}"
        )

    with open(datasetName, "rb") as handle:
        loadedData = pickle.load(handle)

    def _collate_common(X, y, X_lens, y_lens, days, sentence):
        # Shared by both collate functions: zero-pad the variable-length fields,
        # stack the scalar ones, and keep sentences as a tuple of strings.
        batch_out = {}
        batch_out["neural_feats"] = pad_sequence(X, batch_first=True, padding_value=0)
        batch_out["phone_seq"] = pad_sequence(y, batch_first=True, padding_value=0)
        batch_out["neural_time_bins"] = torch.stack(X_lens)
        batch_out["phone_seq_len"] = torch.stack(y_lens)
        batch_out["day"] = torch.stack(days)
        batch_out["sentence"] = sentence
        return batch_out

    def collate_fn_simple(batch):
        # For datasets that yield 6-tuples (no MFCC field).
        X, y, X_lens, y_lens, days, sentence = zip(*batch)
        return _collate_common(X, y, X_lens, y_lens, days, sentence)

    def collate_fn(batch):
        # For datasets that yield 7-tuples; MFCC entries are kept as an unpadded tuple.
        X, y, X_lens, y_lens, days, sentence, mfcc = zip(*batch)
        batch_out = _collate_common(X, y, X_lens, y_lens, days, sentence)
        batch_out["mfcc"] = mfcc
        return batch_out

    def _make_loader(dataset, shuffle, collate):
        return DataLoader(
            dataset,
            batch_size=batchSize,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=True,
            collate_fn=collate,
        )

    train_ds = SpeechSentenceDataset_MFCC_assets(loadedData["train"],asset_path=train_mfcc_assets)
    test_ds = SpeechSentenceDataset_MFCC_assets(loadedData["test"],asset_path=test_mfcc_assets)
    # The competition split uses the plain sentence dataset, hence the simpler collate.
    competition_ds = SpeechSentenceDataset(loadedData["competition"])

    train_loader = _make_loader(train_ds, shuffle_train, collate_fn)
    test_loader = _make_loader(test_ds, False, collate_fn)
    competition_loader = _make_loader(competition_ds, False, collate_fn_simple)

    return train_loader, test_loader, competition_loader, loadedData




























### NEW VERSION THAT INCLUDES PRE-GO



class SpeechSentenceDataset_v3(Dataset):
    """
    Sentence dataset that can drop the pre-go part of each trial and select
    features by brain area and feature type. Items are dictionaries.

    Args:
        data: mapping indexed by split name; ``data[split]`` is a sequence with one
            entry per day. Each day is indexed with "tx1", "spikePow", "phonemeIDs",
            "phonemes", "phoneLens", "transcriptions", "speakingMode",
            "goPeriodOnset", "mfcc" and "speechLabel", one item per trial.
        split: "train" or "test".
        transform (callable, optional): applied to the float32 feature tensor.
        roi: "both", "broca" or "sm" -> broca is area 44, sm is area 6v. In tx1 and
            spikePow the first 128 channels relate to Brodmann area 6v and the
            last 128 to area 44.
        mode: "both", "spike_count" (also accepted as "spike count") or "spikepower".
        only_attempted: keep only trials whose speaking mode is "attempted speaking".
        indices: flat trial indices to keep (applied after ``only_attempted``).
            When None, they are loaded from AUDIO_ASSETS for "train" / "test".
        audio_files: list of audio paths indexed by item position. When None, it is
            generated from the AUDIO_ASSETS directory listing for "train" / "test".
        include_prego: if False, items only contain features from the go onset onwards.

    Raises:
        ValueError: if ``roi`` or ``mode`` is not one of the supported values.
    """

    # Names of the parallel per-trial lists. Every filter has to be applied to all
    # of them, otherwise they stop being index-aligned.
    _PER_TRIAL_FIELDS = (
        "neural_feats", "phone_seqs", "neural_time_bins", "phone_seq_lens", "days",
        "sentences", "speaking_modes", "phonemes", "go_onset", "mfcc", "speech_labels",
    )
    # Channel (column) range selected for each region of interest.
    _ROI_COLUMNS = {"both": slice(None), "sm": slice(None, 128), "broca": slice(128, None)}
    _MODES = ("both", "spike_count", "spikepower")

    def __init__(self, data, split="train", transform=None, roi="both", mode="both", only_attempted=False, indices=None, audio_files=None, include_prego=False):
        self.include_prego = include_prego
        self.split = split

        # Validate up front: an unknown roi/mode used to add no features at all,
        # silently leaving neural_feats shorter than the other per-trial lists.
        if roi not in self._ROI_COLUMNS:
            raise ValueError(f"roi must be one of {sorted(self._ROI_COLUMNS)}, got {roi!r}")
        mode_key = mode.replace(" ", "_") if isinstance(mode, str) else mode
        if mode_key not in self._MODES:
            raise ValueError(f"mode must be one of {list(self._MODES)}, got {mode!r}")
        cols = self._ROI_COLUMNS[roi]

        if split == "train" and indices is None:
            indices_path = f"{AUDIO_ASSETS}/train_audio/train_indices.pkl"
            with open(indices_path, "rb") as handle:
                indices = pickle.load(handle)
        elif split == "test" and indices is None:
            indices_path = f"{AUDIO_ASSETS}/test_audio/test_indices.pkl"
            with open(indices_path, "rb") as handle:
                indices = pickle.load(handle)

        if audio_files is None:
            # NOTE(review): the count comes from the raw directory listing, so any
            # non-audio entry (the indices pickle above is read from the same
            # directory) adds a trailing path with no file behind it.
            if split == "train":
                audio_files = [f"{AUDIO_ASSETS}/train_audio/generated_"+str(i)+".wav" for i in range(len(os.listdir(f"{AUDIO_ASSETS}/train_audio")))]
            elif split == "test":
                audio_files = [f"{AUDIO_ASSETS}/test_audio/generated_"+str(i)+".wav" for i in range(len(os.listdir(f"{AUDIO_ASSETS}/test_audio")))]

        self.data = data[split]
        data = data[split]
        self.transform = transform
        self.n_days = len(data)
        # Counted before any filtering; see __len__ for the number of usable items.
        self.n_trials = sum(len(d["transcriptions"]) for d in data)
        self.indices = indices
        self.audio_files = audio_files

        print("Number of trials: ", self.n_trials)
        print("Number of days: ", self.n_days)

        for name in self._PER_TRIAL_FIELDS:
            setattr(self, name, [])

        for day in range(self.n_days):
            day_data = data[day]
            for trial in range(len(day_data["transcriptions"])):
                self.neural_feats.append(self._select_features(day_data, trial, cols, mode_key))
                self.phone_seqs.append(day_data["phonemeIDs"][trial])
                self.phonemes.append(day_data["phonemes"][trial])
                self.neural_time_bins.append(day_data["tx1"][trial].shape[0])
                self.phone_seq_lens.append(day_data["phoneLens"][trial])
                self.days.append(day)
                self.sentences.append(day_data["transcriptions"][trial])
                self.speaking_modes.append(day_data["speakingMode"][trial])
                self.go_onset.append(day_data["goPeriodOnset"][trial])
                self.mfcc.append(day_data["mfcc"][trial])
                self.speech_labels.append(day_data["speechLabel"][trial])

        if only_attempted:
            # Filter out trials that were not attempted
            attempted_indices = [
                i for i, speaking_mode in enumerate(self.speaking_modes)
                if speaking_mode == "attempted speaking"
            ]
            self._keep(attempted_indices)
            print("Number of trials after filtering: ", len(self.neural_feats))

        if self.indices is not None:
            # Filter the dataset based on the provided indices
            self._keep(self.indices)
            print("Number of trials after filtering by indices: ", len(self.neural_feats))

    @staticmethod
    def _select_features(day_data, trial, cols, mode):
        """Return one trial's feature array for the given channel range and feature type."""
        if mode == "spike_count":
            return day_data["tx1"][trial][:, cols]
        if mode == "spikepower":
            return day_data["spikePow"][trial][:, cols]
        # mode == "both": spike counts followed by spike power along the channel axis.
        return np.concatenate(
            [day_data["tx1"][trial][:, cols], day_data["spikePow"][trial][:, cols]],
            axis=1,
        )

    def _keep(self, keep_indices):
        """Restrict every per-trial list to ``keep_indices``, in that order."""
        for name in self._PER_TRIAL_FIELDS:
            values = getattr(self, name)
            setattr(self, name, [values[i] for i in keep_indices])

    def __len__(self):
        # Length of the filtered data; also valid when no indices were supplied.
        return len(self.neural_feats)

    def __getitem__(self, idx):
        """Return one trial as a dictionary of tensors and raw per-trial values."""
        neural_feats = torch.tensor(self.neural_feats[idx], dtype=torch.float32)

        if self.transform:
            neural_feats = self.transform(neural_feats)

        X_len = torch.tensor(self.neural_time_bins[idx], dtype=torch.int32)
        if not self.include_prego:
            #only return neural features from go onset onwards
            neural_feats = neural_feats[self.go_onset[idx]:, :]
            X_len = X_len - self.go_onset[idx]

        return {
            "neural_feats": neural_feats,
            "phone_seq": torch.tensor(self.phone_seqs[idx], dtype=torch.int32),
            "neural_time_bins": X_len,
            "phone_seq_len": torch.tensor(self.phone_seq_lens[idx], dtype=torch.int32),
            "day": torch.tensor(self.days[idx], dtype=torch.int64),
            "sentence": self.sentences[idx],
            "audio_file": self.audio_files[idx] if self.audio_files is not None else None,
            "mfcc": self.mfcc[idx] if self.mfcc is not None else None,
            "go_onset": self.go_onset[idx] if self.go_onset is not None else None,
            "speech_label": self.speech_labels[idx] if self.speech_labels is not None else None,
        }


def getDatasetLoaders_V3(
    datasetName,
    batchSize,
    shuffle_train=True,
    roi="sm", # "sm" or "broca"
    mode="both", # "spike count", "spikepower" or "both"
    include_prego=False, # if True, include pre-go neural data
):
    """
    Load a pickled dataset and build train / test DataLoaders over
    SpeechSentenceDataset_v3, with dictionary batches.

    ``roi``, ``mode`` and ``include_prego`` are forwarded to the dataset.

    Returns:
        (train_loader, test_loader, competition_loader, loadedData), where
        competition_loader is always None.
    """
    with open(datasetName, "rb") as handle:
        loadedData = pickle.load(handle)

    def collate_fn(batch):
        def column(key):
            # Gather one field from every sample, in batch order.
            return [item[key] for item in batch]

        batch_out = {}

        # Pad neural_feats and phone_seq
        for key in ("neural_feats", "phone_seq"):
            batch_out[key] = pad_sequence(column(key), batch_first=True, padding_value=0)

        # Stack simple fields
        for key in ("neural_time_bins", "phone_seq_len", "day"):
            batch_out[key] = torch.stack(column(key))

        # Keep list for variable-length/string fields
        for key in ("sentence", "audio_file", "mfcc", "go_onset", "speech_label"):
            batch_out[key] = column(key)

        return batch_out

    def _make_loader(dataset, shuffle):
        # Both loaders share every setting except the dataset and shuffling.
        return DataLoader(
            dataset,
            batch_size=batchSize,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=True,
            collate_fn=collate_fn,
        )

    train_ds = SpeechSentenceDataset_v3(loadedData,split="train", roi=roi, mode=mode,include_prego=include_prego)
    test_ds = SpeechSentenceDataset_v3(loadedData,split="test", roi=roi, mode=mode,include_prego=include_prego)

    train_loader = _make_loader(train_ds, shuffle_train)
    test_loader = _make_loader(test_ds, False)

    # No competition split is built in this version.
    competition_loader = None
    return train_loader, test_loader, competition_loader, loadedData
