"""
Dataset Creation Script for ECoG-Phoneme-Transformer

This script processes ECoG data recordings from speech experiments,
extracts features, and creates a structured dataset for training and testing
speech decoding models.
"""

import os
import re
import pickle
import numpy as np
import scipy.io
from g2p_en import G2p
from collections import Counter

# --- 1. PHONEME UTILITIES ---

# Define phoneme mappings
# Grapheme-to-phoneme converter, instantiated once at import time and reused
# by phonemesFromSentence().
g2p = G2p()
# Phoneme inventory without stress digits. The position of a symbol in this
# list is its numeric ID, so the order must stay stable.
PHONE_DEF = [
    'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B',  'CH', 'D', 'DH',
    'EH', 'ER', 'EY', 'F', 'G',  'HH', 'IH', 'IY', 'JH', 'K',
    'L', 'M', 'N', 'NG', 'OW', 'OY', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 'ZH'
]
PHONE_DEF_SIL = PHONE_DEF + ['SIL']  # Add silence token (it receives the last ID)

# Create lookup dictionaries
# phoneToIdDict maps symbol -> 0-based ID; idToPhone is the inverse mapping.
# NOTE: phonemesFromSentence() stores these IDs shifted by +1 so that 0 can
# be used as padding in its fixed-length output array.
phoneToIdDict = {p: i for i, p in enumerate(PHONE_DEF_SIL)}
idToPhone = {i: p for i, p in enumerate(PHONE_DEF_SIL)}

def phoneToId(phoneme):
    """
    Look up the 0-based numeric ID of a phoneme symbol.

    Args:
        phoneme: Phoneme symbol such as 'AA' or 'SIL'

    Returns:
        int: Position of the symbol in PHONE_DEF_SIL

    Raises:
        KeyError: If the symbol is not part of the phoneme inventory
    """
    phoneme_id = phoneToIdDict[phoneme]
    return phoneme_id

def phonemesFromSentence(sentence, max_seq_len=100):
    """
    Convert a text sentence to a sequence of phonemes

    The sentence is reduced to letters, hyphens, spaces and apostrophes,
    lower-cased, and passed through the module-level ``g2p`` converter.
    Every ' ' token emitted by g2p becomes 'SIL', stress digits are stripped
    from the remaining tokens, and one trailing 'SIL' is always appended.

    Args:
        sentence: Text sentence to convert
        max_seq_len: Length of the zero-padded ID array (default: 100)

    Returns:
        tuple: (phoneme_list, phoneme_ids_array) where phoneme_ids_array is an
            int32 array of length max_seq_len holding ``phoneToId(p) + 1`` for
            each phoneme, followed by zeros as padding

    Raises:
        ValueError: If the sentence produces more than max_seq_len phonemes
    """
    # Clean and normalize the transcription
    clean_text = str(sentence).strip()
    clean_text = re.sub(r'[^a-zA-Z\- \']', '', clean_text)
    clean_text = clean_text.replace('--', '').lower()

    # Convert to phonemes
    phonemes = []
    for p in g2p(clean_text):
        if p == ' ':
            phonemes.append('SIL')  # Add silence between words
        else:
            p = re.sub(r'[0-9]', '', p)  # Remove stress markers
            if re.match(r'[A-Z]+', p):   # Only keep phoneme tokens
                phonemes.append(p)

    # Add final silence
    phonemes.append('SIL')

    # Fail with an explicit message instead of numpy's broadcast error when
    # the sequence does not fit in the padded array.
    seq_length = len(phonemes)
    if seq_length > max_seq_len:
        raise ValueError(
            f"Sentence produces {seq_length} phonemes, which exceeds "
            f"max_seq_len={max_seq_len}: {sentence!r}"
        )

    # Convert to IDs with fixed-length array
    phoneme_ids = np.zeros(max_seq_len, dtype=np.int32)
    phoneme_ids[:seq_length] = [phoneToId(p) + 1 for p in phonemes]  # +1 for zero padding

    return phonemes, phoneme_ids


def _clean_sentence_for_matching(sentence):
    """
    Normalize a sentence so that differently formatted copies compare equal.

    The input is converted to ``str``, stripped and lower-cased; every
    character other than a-z, hyphen, apostrophe and space is dropped; double
    hyphens are removed; runs of whitespace collapse to a single space.

    Args:
        sentence: Sentence text (any object accepted by ``str``)

    Returns:
        str: The normalized sentence
    """
    lowered = str(sentence).strip().lower()
    # Same character filter and '--' removal as used before phoneme conversion.
    filtered = re.sub(r'[^a-z\- \']', '', lowered).replace('--', '')
    # Removing characters can leave doubled or dangling spaces behind.
    return re.sub(r'\s+', ' ', filtered).strip()


# --- 2. FEATURE EXTRACTION UTILITIES ---

def extract_normalized_mfccs(raw_data, num_coeffs=14):
    """
    Z-score the leading audio-feature columns over the whole session.

    Args:
        raw_data: Dictionary containing session data; reads "audioFeatures"
        num_coeffs: Number of leading columns to keep (default: 14)

    Returns:
        tuple: (normalized array, {'mean': per-column mean, 'std': per-column std})
    """
    coeffs = raw_data["audioFeatures"][:, :num_coeffs]

    # Statistics are taken per column, i.e. per coefficient.
    stats = {
        'mean': np.mean(coeffs, axis=0),
        'std': np.std(coeffs, axis=0),
    }

    # The small epsilon keeps constant columns from dividing by zero.
    scaled = (coeffs - stats['mean']) / (stats['std'] + 1e-8)
    return scaled, stats

def extract_normalized_envelope(raw_data):
    """
    Z-score the audio envelope over the whole session.

    Args:
        raw_data: Dictionary containing session data; reads "audioEnvelope"

    Returns:
        tuple: (normalized squeezed envelope, {'mean': scalar mean, 'std': scalar std})
    """
    signal = raw_data["audioEnvelope"].squeeze()

    # One global mean/std pair for the entire recording.
    stats = {'mean': np.mean(signal), 'std': np.std(signal)}

    # The small epsilon keeps a flat envelope from dividing by zero.
    scaled = (signal - stats['mean']) / (stats['std'] + 1e-8)
    return scaled, stats


# --- 3. SPEECH LABEL GENERATION ---

def create_speech_labels(raw_data, threshold_factor=0.1, min_speech_duration=3,
                         min_silence_duration=3, merge_window=10):
    """
    Build per-sample speech/silence labels from the audio envelope.

    Two label tracks are produced for the session:
    - Restricted: detection runs on each go period only
    - Unrestricted: detection runs on each whole trial (delay start to go end)

    In both cases the threshold is ``threshold_factor`` times the maximum
    envelope value of the window being analysed, and the thresholded signal
    is cleaned with process_binary_labels(). Samples outside every analysed
    window stay 0.

    Args:
        raw_data: Dictionary containing session data; reads "audioEnvelope",
            "delayTrialEpochs" and "goTrialEpochs"
        threshold_factor: Fraction of max amplitude for threshold
        min_speech_duration: Minimum frames for speech segment
        min_silence_duration: Minimum frames for silence segment
        merge_window: Window for merging nearby speech segments

    Returns:
        tuple: (restricted_labels, unrestricted_labels)
    """
    envelope = raw_data["audioEnvelope"].squeeze()
    delay_epochs = raw_data["delayTrialEpochs"]
    go_epochs = raw_data["goTrialEpochs"]

    restricted_labels = np.zeros_like(envelope, dtype=int)
    unrestricted_labels = np.zeros_like(envelope, dtype=int)

    def _detect(first, last):
        # Threshold the inclusive window [first, last] relative to its own
        # peak, then apply the shared clean-up steps.
        window = envelope[first:last+1]
        cutoff = threshold_factor * np.max(window)
        above = (window > cutoff).astype(int)
        return process_binary_labels(above, min_speech_duration, min_silence_duration, merge_window)

    for trial_idx in range(len(go_epochs)):
        delay_start, _delay_end = delay_epochs[trial_idx]
        go_start, go_end = go_epochs[trial_idx]

        # Restricted: go period only
        restricted_labels[go_start:go_end+1] = _detect(go_start, go_end)

        # Unrestricted: delay + go
        unrestricted_labels[delay_start:go_end+1] = _detect(delay_start, go_end)

    return restricted_labels, unrestricted_labels

def process_binary_labels(binary, min_speech_duration, min_silence_duration, merge_window):
    """
    Clean a thresholded speech signal in three fixed steps.

    1. Speech runs shorter than min_speech_duration are removed.
    2. Silence runs shorter than min_silence_duration are filled.
    3. Silence gaps of at most merge_window samples between speech are merged.

    Returns:
        The cleaned label array
    """
    without_blips = remove_short_segments(binary, 1, min_speech_duration)
    without_dropouts = remove_short_segments(without_blips, 0, min_silence_duration)
    return merge_speech_segments(without_dropouts, merge_window)

def remove_short_segments(binary_array, segment_value, min_duration):
    """
    Remove segments shorter than min_duration

    A segment is a contiguous run of samples equal to segment_value. Runs
    with fewer than min_duration samples are overwritten: with 0 when
    segment_value is 1 (short speech is removed), otherwise with 1 (short
    silence is filled).

    Args:
        binary_array: Array of 0/1 labels (not modified)
        segment_value: Which value's runs to inspect
        min_duration: Minimum run length that is kept

    Returns:
        A copy of binary_array with the short runs overwritten
    """
    from scipy import ndimage

    result = binary_array.copy()

    # Find contiguous regions
    labeled, num_features = ndimage.label(binary_array == segment_value)
    if num_features == 0:
        return result

    # Measure every run in one pass rather than building one full-length mask
    # per run. run_lengths[k] is the sample count of label k; label 0 is the
    # background (samples that are not segment_value) and is never rewritten.
    run_lengths = np.bincount(labeled.ravel(), minlength=num_features + 1)
    is_short = run_lengths < min_duration
    is_short[0] = False

    replacement = 0 if segment_value == 1 else 1
    result[is_short[labeled]] = replacement

    return result

def merge_speech_segments(binary_array, max_gap):
    """
    Merge speech segments separated by short gaps

    A gap is a contiguous run of zeros. A gap is filled with ones when it is
    at most max_gap samples long, touches neither end of the array, and both
    neighbouring samples equal 1.

    Args:
        binary_array: 1-D array of 0/1 labels (not modified)
        max_gap: Longest gap, in samples, that is still filled

    Returns:
        A copy of binary_array with the qualifying gaps set to 1
    """
    from scipy import ndimage

    result = binary_array.copy()

    # Find silence segments
    labeled, _num_features = ndimage.label(binary_array == 0)
    last_index = len(binary_array) - 1

    # In a 1-D array every labelled run is exactly one slice, so the run
    # bounds can be read directly instead of scanning a mask per run.
    for (gap,) in ndimage.find_objects(labeled):
        first, last = gap.start, gap.stop - 1
        if last - first + 1 > max_gap:
            continue
        # Only fill gaps between speech (not at start/end)
        if first == 0 or last == last_index:
            continue
        if binary_array[first - 1] == 1 and binary_array[last + 1] == 1:
            result[gap] = 1

    return result


# --- 4. TRIAL EXTRACTION UTILITIES ---

def extract_trial_ranges(raw_data):
    """
    Build the combined (delay + go) index range of every trial.

    Args:
        raw_data: Dictionary containing session data; reads
            "delayTrialEpochs" and "goTrialEpochs"

    Returns:
        numpy.ndarray: Integer array with one [start, end] row per trial,
            where start is the first column of the delay epoch and end is
            the second column of the go epoch
    """
    delay_epochs = raw_data["delayTrialEpochs"]
    go_epochs = raw_data["goTrialEpochs"]

    # The number of trials is taken from the delay epochs.
    spans = [
        (delay_epochs[row, 0], go_epochs[row, 1])
        for row in range(len(delay_epochs))
    ]

    # reshape keeps the (0, 2) shape when there are no trials at all.
    return np.array(spans, dtype=int).reshape(-1, 2)

def calculate_go_onset(trial_start, go_start):
    """
    Convert the absolute go-period start into an offset within the trial.

    Args:
        trial_start: The absolute timestep where the trial begins
        go_start: The absolute timestep where the go period begins

    Returns:
        int: Number of timesteps between the trial start and the go start
    """
    onset_within_trial = go_start - trial_start
    return onset_within_trial


# --- 5. FEATURE LOADING FUNCTIONS ---

def loadFeatures(sessionPath, trial_part="full"):
    """
    Load a session .mat file and cut its recordings into per-trial features.

    The neural arrays ("tx1", "spikePow") and the raw audio are returned
    un-normalized. The MFCCs and the audio envelope are z-scored across the
    whole session before being cut into trials. The per-trial speech labels
    are the "unrestricted" labels from create_speech_labels().

    Args:
        sessionPath: Path to the .mat file; the session name is the part of
            the file name before the first underscore
        trial_part: Which part of trial to extract ("full" or "go")

    Returns:
        tuple: (dict containing all session features, dict containing audio norm stats).
            Every value in the first dict is a list with one entry per trial.

    Raises:
        ValueError: If trial_part is neither "full" nor "go" (raised when the
            first trial is processed)
    """
    dat = scipy.io.loadmat(sessionPath)

    # Get key info
    n_trials = dat["goTrialEpochs"].shape[0]
    # basename() honours the platform's path separators, so a path built with
    # os.path.join() yields the right session name on every OS.
    session_name = os.path.basename(sessionPath).split("_")[0]

    # Create consistent lists
    session_name_list = [session_name] * n_trials
    speakingMode_list = [dat["speakingMode"]] * n_trials

    # Extract trial boundaries
    goTrialEpochs = dat['goTrialEpochs']
    delayAndGoEpochs = extract_trial_ranges(dat)

    # Normalize audio features across session
    normalized_mfccs, mfcc_stats = extract_normalized_mfccs(dat)
    normalized_envelope, envelope_stats = extract_normalized_envelope(dat)

    # Generate speech labels; only the unrestricted track is stored per trial
    _, unrestricted_labels = create_speech_labels(
        dat,
        threshold_factor=0.1,
        min_speech_duration=3,
        min_silence_duration=3,
        merge_window=10
    )

    # Create a mapping from block number to block type string
    block_map = {}
    for i, block_row in enumerate(dat['blockList']):
        block_type = dat['blockTypes'][i][0][0]  # Extract string from nested array
        block_map[block_row[0]] = block_type

    # Lists to store trial data
    tx1s = []
    spikePows = []
    neuralLens = []
    sentences = []
    phonemes = []
    phonemeIDs_list = []
    phoneLens = []
    audioWaveforms = []
    mfccs = []
    audioEnvelopes = []
    speechLabels = []
    goPeriodOnset_list = []
    blockTypes_list = []  # Store block type strings
    blockIds_list = []    # Store block numeric IDs

    # Process each trial
    for i in range(n_trials):
        # Process sentence and phonemes
        sentence = dat['sentences'][i][0][0].strip()
        phoneme_list, phonemeIDs = phonemesFromSentence(sentence)

        # Determine time boundaries based on requested part
        if trial_part == "go":
            (start, end) = goTrialEpochs[i]
            goPeriodOnset = 0  # Go period starts at beginning of extracted data
        elif trial_part == "full":
            (start, end) = delayAndGoEpochs[i]
            # Calculate go period onset relative to trial start
            goPeriodOnset = goTrialEpochs[i][0] - delayAndGoEpochs[i][0]
        else:
            raise ValueError(f"Unknown trial_part: {trial_part}. Use 'full' or 'go'.")

        # Extract time series data for this trial (end index is inclusive)
        tx = dat["tx1"][start:end + 1, :]
        spikePow = dat["spikePow"][start:end + 1, :]
        mfcc = normalized_mfccs[start:end + 1, :]
        audioEnvelope = normalized_envelope[start:end + 1]
        speechLabel = unrestricted_labels[start:end + 1]
        audioWaveform = dat["audio"][start:end + 1]

        # Determine block type for this trial
        block_nums_in_trial = dat['blockNum'][start:end+1, 0]
        block_id = np.bincount(block_nums_in_trial).argmax()  # Most common block number
        block_type = block_map[block_id]

        # Store all features
        tx1s.append(tx)
        spikePows.append(spikePow)
        neuralLens.append(tx.shape[0])
        sentences.append(sentence)
        phonemes.append(phoneme_list)
        phonemeIDs_list.append(phonemeIDs)
        phoneLens.append(len(phoneme_list))
        audioWaveforms.append(audioWaveform)
        mfccs.append(mfcc)
        audioEnvelopes.append(audioEnvelope)
        speechLabels.append(speechLabel)
        goPeriodOnset_list.append(goPeriodOnset)
        blockTypes_list.append(block_type)
        blockIds_list.append(block_id)

    # Create session data dictionary
    session_data = {
        'tx1': tx1s,
        'spikePow': spikePows,
        'neuralLens': neuralLens,
        'transcriptions': sentences,
        "phonemes": phonemes,
        "phonemeIDs": phonemeIDs_list,
        'phoneLens': phoneLens,
        "audioWaveform": audioWaveforms,
        'mfcc': mfccs,
        'audioEnvelope': audioEnvelopes,
        'speechLabel': speechLabels,
        'goPeriodOnset': goPeriodOnset_list,
        'blockIdx': blockIds_list,
        'blockType': blockTypes_list,
        'sessionName': session_name_list,
        'speakingMode': speakingMode_list,
    }

    audio_stats = {
        'mfcc': mfcc_stats,
        'audioEnvelope': envelope_stats
    }
    return session_data, audio_stats

def normalizeFeatures(session_data, feature_key):
    """
    Z-score one per-trial feature separately within each block.

    All trials sharing a block ID are concatenated along axis 0 to compute
    that block's mean and standard deviation, which are then applied to each
    of those trials. The normalized list replaces ``session_data[feature_key]``;
    the original per-trial arrays themselves are left untouched.

    Args:
        session_data: Dictionary containing features and block information
            (reads ``feature_key`` and 'blockIdx')
        feature_key: Key to access the feature set to normalize

    Returns:
        tuple: (Updated session_data dict, dict mapping int block ID to
            {'mean': ..., 'std': ...} with keepdims-shaped statistics)
    """
    per_trial = session_data[feature_key].copy()
    trial_blocks = np.array(session_data['blockIdx'])

    normalization_stats = {}
    for block_id in np.unique(trial_blocks):
        # Positions of the trials that belong to this block
        members = np.where(trial_blocks == block_id)[0]
        if len(members) == 0:
            continue

        # Pool every timestep of the block to estimate its statistics
        pooled = np.concatenate([per_trial[k] for k in members], axis=0)
        block_mean = np.mean(pooled, axis=0, keepdims=True)
        block_std = np.std(pooled, axis=0, keepdims=True)
        normalization_stats[int(block_id)] = {'mean': block_mean, 'std': block_std}

        # Epsilon guards against zero-variance channels
        for k in members:
            per_trial[k] = (per_trial[k] - block_mean) / (block_std + 1e-8)

    session_data[feature_key] = per_trial
    return session_data, normalization_stats

def loadFeaturesAndPreprocess(sessionPath, trial_part="full"):
    """
    Load a session and apply block-wise normalization to its neural features.

    Args:
        sessionPath: Path to the .mat file
        trial_part: Which part of trial to extract ("full" or "go")

    Returns:
        tuple: (dict containing normalized features, dict containing all
            normalization stats under the keys 'audio' and 'neural')
    """
    session_data, audio_stats = loadFeatures(sessionPath, trial_part)

    # Normalize each neural feature block-wise, in this fixed order
    neural_stats = {}
    for feature_key in ("tx1", "spikePow"):
        session_data, neural_stats[feature_key] = normalizeFeatures(session_data, feature_key)

    return session_data, {'audio': audio_stats, 'neural': neural_stats}



# --- 6. TRAIN, TEST SET HANDLING FUNCTIONS ---

def _load_trial_definitions(mat_file_path, session_name, trial_type_str):
    """
    Read the (block ID, cleaned sentence) pairs that define a trial set.

    Args:
        mat_file_path: Path to the .mat file holding "blockIdx" and "sentenceText"
        session_name: Session name, used only in the warning message
        trial_type_str: Label such as "test" or "train", used only in warnings

    Returns:
        list: (np.int64 block ID, cleaned sentence) tuples; empty when the
            file is missing or lacks either required variable (a warning is
            printed in both cases)
    """
    if not os.path.exists(mat_file_path):
        print(f"Warning: {trial_type_str} trial definition file not found for session {session_name} at {mat_file_path}")
        return []

    contents = scipy.io.loadmat(mat_file_path)

    has_required_fields = "blockIdx" in contents and "sentenceText" in contents
    if not has_required_fields:
        print(f"Warning: 'blockIdx' or 'sentenceText' not found in {trial_type_str} trial file: {mat_file_path}")
        return []

    # Each blockIdx row is indexed at [0] to get the scalar block number.
    block_ids = [np.int64(row[0]) for row in contents["blockIdx"]]
    return [
        (block_id, _clean_sentence_for_matching(text))
        for block_id, text in zip(block_ids, contents["sentenceText"])
    ]

def get_defined_trials_dict(dataDir, sessionNames, definition_type):
    """
    Collect the trial definitions of one split for every session.

    Definitions are read from
    dataDir/competitionData/{definition_type}/{sessionName}.mat

    Args:
        dataDir: Base data directory
        sessionNames: List of session names
        definition_type: String, either "test" or "train" to specify subfolder

    Returns:
        dict: Dictionary mapping session names to a list of (block_id, sentence) tuples

    Raises:
        ValueError: If definition_type is neither "test" nor "train"
    """
    if definition_type not in ("test", "train"):
        raise ValueError("definition_type must be 'test' or 'train'")

    print(f"Loading {definition_type} trial definitions...")

    return {
        session: _load_trial_definitions(
            os.path.join(dataDir, "competitionData", definition_type, session + ".mat"),
            session,
            definition_type,
        )
        for session in sessionNames
    }

def identify_trials_from_definitions(session_idx_sentence_list, defined_trials, trial_type_str):
    """
    Find the session trials that correspond to a list of defined trials.

    Matching is on (block_id, cleaned sentence) and is one-to-one: a
    definition that appears k times claims at most the first k matching
    session trials, in session order.

    Args:
        session_idx_sentence_list: List of (block_id, sentence) tuples for all trials in the session.
        defined_trials: List of (block_id, sentence) tuples for the trials to identify (e.g., test or train);
            the sentences are expected to be cleaned already.
        trial_type_str: String descriptor for the trial type (e.g., "test", "train") for logging.

    Returns:
        identified_indices: List of indices in session_idx_sentence_list that match defined_trials.
    """
    if not defined_trials:
        print(f"No {trial_type_str} trials defined or loaded. Returning empty list of indices.")
        return []

    # Multiset of definitions still waiting for a matching session trial
    remaining = Counter(defined_trials)
    total_defined = sum(remaining.values())

    matched = Counter()
    identified_indices = []

    for position, (block_id, raw_sentence) in enumerate(session_idx_sentence_list):
        key = (block_id, _clean_sentence_for_matching(raw_sentence))

        # Skip trials that are undefined or whose definitions are used up
        if remaining[key] <= 0:
            continue

        remaining[key] -= 1
        matched[key] += 1
        identified_indices.append(position)

    # Report the outcome
    total_matched = sum(matched.values())

    if total_matched < total_defined:
        print(f"Warning: Not all {trial_type_str} trial definitions were found in session data.")
        print(f"  Found {total_matched} out of {total_defined} total {trial_type_str} definitions.")

        # Counter subtraction keeps only the definitions with a positive shortfall
        missing = Counter(defined_trials) - matched
        if missing:
            print(f"  Missing {trial_type_str} definitions (count, block, 'cleaned sentence'):")
            for (block, sent), count in sorted(missing.items()):
                print(f"    - {count}x, Block {block}, '{sent}'")
    elif total_defined > 0:
        print(f"Success: All {total_defined} {trial_type_str} trial definitions were found in session data.")

    # Sanity check: every identified index corresponds to one consumed definition
    if len(identified_indices) != total_matched:
         print(f"CRITICAL WARNING: The number of identified indices ({len(identified_indices)}) does not match "
               f"the number of found definitions ({total_matched}). This indicates a logic error.")

    return identified_indices


def split_session_data(session_data, test_trial_definitions, train_trial_definitions):
    """
    Split session data into training and testing sets based on explicit definitions.

    Every value in session_data that is a list or numpy array with one entry
    per trial is split by index; any other value is copied unchanged into
    both outputs.

    Args:
        session_data: Dictionary containing all data for a session.
        test_trial_definitions: List of (block_id, sentence) tuples for test trials.
        train_trial_definitions: List of (block_id, sentence) tuples for train trials.

    Returns:
        tuple: (train_data, test_data). Both are empty dicts when session_data
            lacks 'blockIdx' or 'transcriptions'.

    Raises:
        ValueError: If a session trial is matched by both the train and the
            test definitions.
    """
    if 'blockIdx' not in session_data or 'transcriptions' not in session_data:
        print("Error: 'blockIdx' or 'transcriptions' not found in session_data. Cannot split.")
        return {}, {}

    # (block_id, sentence) pairs for matching; sentences are kept uncleaned
    # here so that log output shows the original text.
    session_idx_sentence_list = [
        (idx, str(sentence).strip())
        for idx, sentence in zip(session_data["blockIdx"], session_data["transcriptions"])
    ]

    # Identify test trial indices
    print("\nIdentifying Test Trials:")
    test_indices = identify_trials_from_definitions(session_idx_sentence_list, test_trial_definitions, "test")

    # Identify train trial indices
    print("\nIdentifying Train Trials:")
    train_indices = identify_trials_from_definitions(session_idx_sentence_list, train_trial_definitions, "train")

    # Critical check: a trial must never end up in both sets. The intersection
    # is empty whenever either list is empty, so no separate guard is needed.
    overlap = sorted(set(test_indices) & set(train_indices))
    if overlap:
        print("\nCRITICAL WARNING: Overlap found between explicitly defined train and test trials!")
        print(f"  Overlap indices in session_data: {overlap}")
        print("  Overlapping trials (Block, Sentence):")
        for idx in overlap:
            print(f"    - {session_idx_sentence_list[idx]}")
        raise ValueError("Overlap detected between explicit train and test sets. Please check definitions.")

    train_data = {}
    test_data = {}
    num_session_trials = len(session_data['transcriptions'])

    for key, value in session_data.items():
        # A value is treated as per-trial data when its length matches the trial count
        is_per_trial_list = isinstance(value, list) and len(value) == num_session_trials
        is_per_trial_array = isinstance(value, np.ndarray) and value.shape[0] == num_session_trials

        if is_per_trial_list:
            train_data[key] = [value[i] for i in train_indices]
            test_data[key] = [value[i] for i in test_indices]
        elif is_per_trial_array:
            train_data[key] = value[train_indices]
            test_data[key] = value[test_indices]
        else:
            # Global (not per-trial) value: share it between both splits
            train_data[key] = value
            test_data[key] = value

    print(f"\nSplit complete: {len(train_indices)} train trials, {len(test_indices)} test trials.")
    if not train_indices and train_trial_definitions:
        print("Warning: Train set is empty despite train definitions being provided.")
    if not test_indices and test_trial_definitions:
        print("Warning: Test set is empty despite test definitions being provided.")

    return train_data, test_data




# --- 7. MAIN SCRIPT EXECUTION ---
import argparse

def main(dataDir="", trial_part="full", output_file="dataset_gen2"):
    """
    Build the train/test dataset for every session and pickle it to disk.

    For each session the sentences .mat file is loaded and preprocessed, then
    split with the train/test trial definitions found under
    dataDir/competitionData. Two pickle files are written to the current
    working directory: the dataset itself (``output_file``) and the
    per-session normalization statistics
    (``willet_dataset_gen2_{trial_part}_zscoring_info.pkl``).

    Args:
        dataDir: Base data directory
        trial_part: Which part of trial to extract ("full" for delay + go periods or "go" for go periods only)
        output_file: Name of the output dataset file
    """
    def _trial_count(split):
        # split_session_data() returns an empty dict when a session cannot be
        # split, so a missing 'transcriptions' key counts as zero trials.
        return len(split.get('transcriptions', []))

    # Define session names
    sessionNames = [
        't12.2022.04.28', 't12.2022.05.26', 't12.2022.06.21', 't12.2022.07.21', 't12.2022.08.13',
        't12.2022.05.05', 't12.2022.06.02', 't12.2022.06.23', 't12.2022.07.27', 't12.2022.08.18',
        't12.2022.05.17', 't12.2022.06.07', 't12.2022.06.28', 't12.2022.07.29', 't12.2022.08.23',
        't12.2022.05.19', 't12.2022.06.14', 't12.2022.07.05', 't12.2022.08.02', 't12.2022.08.25',
        't12.2022.05.24', 't12.2022.06.16', 't12.2022.07.14', 't12.2022.08.11'
    ]
    sessionNames.sort()

    print(f"Using data directory: {dataDir}")
    print(f"Trial part: {trial_part}")

    # Get test and train trial definitions for all sessions
    print("Loading test trial information...")
    test_trials_dict = get_defined_trials_dict(dataDir, sessionNames, "test")
    train_trials_dict = get_defined_trials_dict(dataDir, sessionNames, "train")

    # Create day index mapping for sessions (index in the sorted name list)
    sessionToDayIdxDict = {sessionName: idx for idx, sessionName in enumerate(sessionNames)}

    # Lists to store processed datasets
    trainDatasets = []
    testDatasets = []
    all_zscoring_info = {}

    # Process each session
    for idx, sessionName in enumerate(sessionNames):
        print(f"Processing session {idx+1}/{len(sessionNames)}: {sessionName}...")

        # Load and preprocess data
        fullDataPath = os.path.join(dataDir, "sentences", sessionName + "_sentences" + ".mat")
        session_data, session_zscoring_info = loadFeaturesAndPreprocess(fullDataPath, trial_part)
        all_zscoring_info[sessionName] = session_zscoring_info

        # Split into train/test sets
        test_defs_for_session = test_trials_dict.get(sessionName, [])
        train_defs_for_session = train_trials_dict.get(sessionName, [])

        train_data, test_data = split_session_data(session_data,
                                                test_defs_for_session,
                                                train_defs_for_session)

        # Add to dataset collections
        trainDatasets.append(train_data)
        testDatasets.append(test_data)

        print(f"Completed {sessionName}: {_trial_count(train_data)} train trials, "
              f"{_trial_count(test_data)} test trials\n")
        print()

    # Calculate and print final totals
    total_train_trials = sum(_trial_count(d) for d in trainDatasets)
    total_test_trials = sum(_trial_count(d) for d in testDatasets)

    print("-" * 50)
    print("Final Count Across All Sessions:")
    print(f"  Total Train Trials: {total_train_trials}")
    print(f"  Total Test Trials:  {total_test_trials}")
    print("-" * 50)
    print()

    # Combine all datasets into a single structure
    print("Creating final dataset structure...")
    allDatasets = {
        'train': trainDatasets,
        'test': testDatasets,
        'phonemeToId': phoneToIdDict,
        'idToPhoneme': idToPhone,
        'sessionToDayIdx': sessionToDayIdxDict
    }

    # Save the dataset
    print(f"Saving dataset to disk as '{output_file}'...")
    with open(output_file, 'wb') as handle:
        pickle.dump(allDatasets, handle)

    # Save z-scoring info
    zscoring_output_file = f"willet_dataset_gen2_{trial_part}_zscoring_info.pkl"
    print(f"Saving z-scoring info to disk as '{zscoring_output_file}'...")
    with open(zscoring_output_file, 'wb') as handle:
        pickle.dump(all_zscoring_info, handle)

    print("Dataset creation complete!")


# Parse command line arguments
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Create the ECoG phoneme dataset from session .mat files."
    )
    parser.add_argument(
        "--data-dir", dest="dataDir", default="",
        help="Base data directory (default: current directory)",
    )
    parser.add_argument(
        "--trial-part", dest="trial_part", choices=["full", "go"], default="full",
        help='"full" means delay + go periods, "go" means go periods only (default: full)',
    )
    parser.add_argument(
        "--output-file", dest="output_file", default=None,
        help="Output dataset file (default: willet_dataset_gen2_<trial-part>_fixed_trials)",
    )
    args = parser.parse_args()

    # With no options given, these values match the previously hard-coded ones.
    output_file = args.output_file
    if output_file is None:
        output_file = f"willet_dataset_gen2_{args.trial_part}_fixed_trials"

    # Call main with parsed arguments
    main(args.dataDir, args.trial_part, output_file)
